aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/openai/cli/_api/image.py
blob: 3e2a0a90f1b41a1e0a39e401932e0f5ccfc3e3b0 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast
from argparse import ArgumentParser

from .._utils import get_client, print_model
from ..._types import NOT_GIVEN, NotGiven, NotGivenOr
from .._models import BaseModel
from .._progress import BufferReader

if TYPE_CHECKING:
    from argparse import _SubParsersAction


def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
    """Attach the ``images.*`` subcommands to *subparser*.

    Three subcommands are registered: ``images.generate``, ``images.edit`` and
    ``images.create_variation``. Each one binds, via ``set_defaults``, the
    ``CLIImage`` handler to run (``func``) and the model describing its
    arguments (``args_model``).
    """

    def add_model_flag(cmd: ArgumentParser) -> None:
        cmd.add_argument("-m", "--model", type=str)

    def add_prompt_flag(cmd: ArgumentParser) -> None:
        cmd.add_argument("-p", "--prompt", type=str, required=True)

    def add_count_flag(cmd: ArgumentParser) -> None:
        cmd.add_argument("-n", "--num-images", type=int, default=1)

    def add_image_flag(cmd: ArgumentParser) -> None:
        # Source image for the edit / variation commands.
        cmd.add_argument(
            "-I",
            "--image",
            type=str,
            required=True,
            help="Image to modify. Should be a local path and a PNG encoded image.",
        )

    def add_output_flags(cmd: ArgumentParser) -> None:
        # Output size and response format are shared by all three commands.
        cmd.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
        cmd.add_argument("--response-format", type=str, default="url")

    # The helper call order below mirrors the flag order shown in --help.
    generate_cmd = subparser.add_parser("images.generate")
    add_model_flag(generate_cmd)
    add_prompt_flag(generate_cmd)
    add_count_flag(generate_cmd)
    add_output_flags(generate_cmd)
    generate_cmd.set_defaults(func=CLIImage.create, args_model=CLIImageCreateArgs)

    edit_cmd = subparser.add_parser("images.edit")
    add_model_flag(edit_cmd)
    add_prompt_flag(edit_cmd)
    add_count_flag(edit_cmd)
    add_image_flag(edit_cmd)
    add_output_flags(edit_cmd)
    edit_cmd.add_argument(
        "-M",
        "--mask",
        type=str,
        required=False,
        help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
    )
    edit_cmd.set_defaults(func=CLIImage.edit, args_model=CLIImageEditArgs)

    variation_cmd = subparser.add_parser("images.create_variation")
    add_model_flag(variation_cmd)
    add_count_flag(variation_cmd)
    add_image_flag(variation_cmd)
    add_output_flags(variation_cmd)
    variation_cmd.set_defaults(func=CLIImage.create_variation, args_model=CLIImageCreateVariationArgs)


class CLIImageCreateArgs(BaseModel):
    """Parsed arguments for ``images.generate`` (dispatched to ``CLIImage.create``)."""

    prompt: str  # -p/--prompt (required)
    num_images: int  # -n/--num-images; forwarded to the API as ``n``
    size: str  # -s/--size; passed through as-is (not validated against the API enum)
    response_format: str  # --response-format; passed through as-is
    # -m/--model; defaults to the NOT_GIVEN sentinel.
    # NOTE(review): argparse yields None for an unset -m; presumably the CLI
    # dispatcher drops None values so this default applies — confirm.
    model: NotGivenOr[str] = NOT_GIVEN


class CLIImageCreateVariationArgs(BaseModel):
    """Parsed arguments for ``images.create_variation`` (dispatched to ``CLIImage.create_variation``)."""

    image: str  # -I/--image; local path of the file to upload (required)
    num_images: int  # -n/--num-images; forwarded to the API as ``n``
    size: str  # -s/--size; passed through as-is (not validated against the API enum)
    response_format: str  # --response-format; passed through as-is
    # -m/--model; defaults to the NOT_GIVEN sentinel.
    # NOTE(review): argparse yields None for an unset -m; presumably the CLI
    # dispatcher drops None values so this default applies — confirm.
    model: NotGivenOr[str] = NOT_GIVEN


class CLIImageEditArgs(BaseModel):
    """Parsed arguments for ``images.edit`` (dispatched to ``CLIImage.edit``)."""

    image: str  # -I/--image; local path of the image to edit (required)
    num_images: int  # -n/--num-images; forwarded to the API as ``n``
    size: str  # -s/--size; passed through as-is (not validated against the API enum)
    response_format: str  # --response-format; passed through as-is
    prompt: str  # -p/--prompt (required)
    # -M/--mask; optional local path. CLIImage.edit checks isinstance(..., NotGiven)
    # to decide whether a mask file is read and uploaded.
    mask: NotGivenOr[str] = NOT_GIVEN
    # -m/--model; defaults to the NOT_GIVEN sentinel.
    # NOTE(review): argparse yields None for unset -m/-M; presumably the CLI
    # dispatcher drops None values so these defaults apply — confirm.
    model: NotGivenOr[str] = NOT_GIVEN


class CLIImage:
    """Handlers for the ``images.*`` CLI subcommands.

    Each handler receives its parsed argument model, calls the matching
    ``images`` method on the client returned by ``get_client()``, and prints
    the response with ``print_model``.
    """

    @staticmethod
    def _read_upload(path: str, desc: str) -> tuple[str, BufferReader]:
        """Load the local file at *path* as a ``(filename, reader)`` upload pair.

        The whole file is read into a ``BufferReader`` labelled *desc*. The
        file handle is closed before returning.

        The filename sent is the basename of *path* rather than a fixed
        placeholder. The multipart layer then sees the real extension and can
        derive a matching content type: httpx guesses it from the filename and
        falls back to ``application/octet-stream`` when there is no extension.
        Some models reject octet-stream for non-PNG inputs.
        """
        import os  # local import: only the upload paths need it

        with open(path, "rb") as file_reader:
            reader = BufferReader(file_reader.read(), desc=desc)
        return os.path.basename(path), reader

    @staticmethod
    def create(args: CLIImageCreateArgs) -> None:
        """Call ``images.generate`` with the parsed arguments and print the response."""
        image = get_client().images.generate(
            model=args.model,
            prompt=args.prompt,
            n=args.num_images,
            # casts required because the API is typed for enums
            # but we don't want to validate that here for forwards-compat
            size=cast(Any, args.size),
            response_format=cast(Any, args.response_format),
        )
        print_model(image)

    @staticmethod
    def create_variation(args: CLIImageCreateVariationArgs) -> None:
        """Upload ``args.image`` to ``images.create_variation`` and print the response.

        Raises:
            OSError: if ``args.image`` cannot be opened or read.
        """
        image_file = CLIImage._read_upload(args.image, "Upload progress")

        image = get_client().images.create_variation(
            model=args.model,
            image=image_file,
            n=args.num_images,
            # casts required because the API is typed for enums
            # but we don't want to validate that here for forwards-compat
            size=cast(Any, args.size),
            response_format=cast(Any, args.response_format),
        )
        print_model(image)

    @staticmethod
    def edit(args: CLIImageEditArgs) -> None:
        """Upload ``args.image`` (plus the optional mask) to ``images.edit`` and print the response.

        Raises:
            OSError: if the image or mask file cannot be opened or read.
        """
        # Read the image before the mask, matching the order of the progress output.
        image_file = CLIImage._read_upload(args.image, "Image upload progress")

        # The mask is optional: keep the NOT_GIVEN sentinel when no path was supplied.
        mask: NotGivenOr[tuple[str, BufferReader]] = NOT_GIVEN
        if not isinstance(args.mask, NotGiven):
            mask = CLIImage._read_upload(args.mask, "Mask progress")

        image = get_client().images.edit(
            model=args.model,
            prompt=args.prompt,
            image=image_file,
            n=args.num_images,
            mask=mask,
            # casts required because the API is typed for enums
            # but we don't want to validate that here for forwards-compat
            size=cast(Any, args.size),
            response_format=cast(Any, args.response_format),
        )
        print_model(image)