# CLI sub-commands for the images API: images.generate, images.edit, images.create_variation.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast
from argparse import ArgumentParser

from .._utils import get_client, print_model
from ..._types import NOT_GIVEN, NotGiven, NotGivenOr
from .._models import BaseModel
from .._progress import BufferReader

if TYPE_CHECKING:
    from argparse import _SubParsersAction


def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
    """Attach the ``images.*`` sub-commands to ``subparser``.

    Three parsers are added: ``images.generate``, ``images.edit`` and
    ``images.create_variation``. Each one gets its command-line options and
    then has ``func`` (the ``CLIImage`` handler) and ``args_model`` (the
    matching ``CLIImage*Args`` class) stored as parser defaults.
    """

    def option(*flags: str, **settings: Any) -> tuple[tuple[str, ...], dict[str, Any]]:
        # Bundle one `add_argument` call so it can be reused by several commands.
        return flags, settings

    model = option("-m", "--model", type=str)
    prompt = option("-p", "--prompt", type=str, required=True)
    count = option("-n", "--num-images", type=int, default=1)
    size = option("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
    response_format = option("--response-format", type=str, default="url")
    image = option(
        "-I",
        "--image",
        type=str,
        required=True,
        help="Image to modify. Should be a local path and a PNG encoded image.",
    )
    mask = option(
        "-M",
        "--mask",
        type=str,
        required=False,
        help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
    )

    # (command name, handler, args model, options in the order they are added)
    commands = (
        (
            "images.generate",
            CLIImage.create,
            CLIImageCreateArgs,
            (model, prompt, count, size, response_format),
        ),
        (
            "images.edit",
            CLIImage.edit,
            CLIImageEditArgs,
            (model, prompt, count, image, size, response_format, mask),
        ),
        (
            "images.create_variation",
            CLIImage.create_variation,
            CLIImageCreateVariationArgs,
            (model, count, image, size, response_format),
        ),
    )

    for name, handler, args_model, options in commands:
        parser = subparser.add_parser(name)
        for flags, settings in options:
            parser.add_argument(*flags, **settings)
        parser.set_defaults(func=handler, args_model=args_model)


class CLIImageCreateArgs(BaseModel):
    """Arguments consumed by ``CLIImage.create``.

    ``register`` stores this class as the ``args_model`` of the
    ``images.generate`` sub-command; the fields mirror the options declared
    there.
    """

    # Value of -p/--prompt (required on the command line).
    prompt: str
    # Value of -n/--num-images; forwarded to the API as `n`.
    num_images: int
    # Value of -s/--size (command-line default "1024x1024").
    size: str
    # Value of --response-format (command-line default "url").
    response_format: str
    # Value of -m/--model; stays NOT_GIVEN when no value is supplied to the model.
    model: NotGivenOr[str] = NOT_GIVEN


class CLIImageCreateVariationArgs(BaseModel):
    """Arguments consumed by ``CLIImage.create_variation``.

    ``register`` stores this class as the ``args_model`` of the
    ``images.create_variation`` sub-command; the fields mirror the options
    declared there.
    """

    # Value of -I/--image: path of the local file that is opened and uploaded.
    image: str
    # Value of -n/--num-images; forwarded to the API as `n`.
    num_images: int
    # Value of -s/--size (command-line default "1024x1024").
    size: str
    # Value of --response-format (command-line default "url").
    response_format: str
    # Value of -m/--model; stays NOT_GIVEN when no value is supplied to the model.
    model: NotGivenOr[str] = NOT_GIVEN


class CLIImageEditArgs(BaseModel):
    """Arguments consumed by ``CLIImage.edit``.

    ``register`` stores this class as the ``args_model`` of the
    ``images.edit`` sub-command; the fields mirror the options declared there.
    """

    # Value of -I/--image: path of the local file that is opened and uploaded.
    image: str
    # Value of -n/--num-images; forwarded to the API as `n`.
    num_images: int
    # Value of -s/--size (command-line default "1024x1024").
    size: str
    # Value of --response-format (command-line default "url").
    response_format: str
    # Value of -p/--prompt (required on the command line).
    prompt: str
    # Value of -M/--mask: path of an optional mask file; NOT_GIVEN means no
    # mask is read or sent.
    mask: NotGivenOr[str] = NOT_GIVEN
    # Value of -m/--model; stays NOT_GIVEN when no value is supplied to the model.
    model: NotGivenOr[str] = NOT_GIVEN


class CLIImage:
    """Handlers behind the ``images.*`` CLI sub-commands.

    Every handler forwards its parsed arguments to the corresponding
    ``images`` method of the client returned by ``get_client()`` and prints
    the response with ``print_model``.
    """

    @staticmethod
    def create(args: CLIImageCreateArgs) -> None:
        """Call ``images.generate`` with ``args`` and print the response."""
        images = get_client().images
        response = images.generate(
            model=args.model,
            prompt=args.prompt,
            n=args.num_images,
            # The API types these two as enums. They are cast to `Any` and
            # deliberately left unvalidated so that values introduced later
            # remain usable from this CLI.
            size=cast(Any, args.size),
            response_format=cast(Any, args.response_format),
        )
        print_model(response)

    @staticmethod
    def create_variation(args: CLIImageCreateVariationArgs) -> None:
        """Upload the file at ``args.image`` to ``images.create_variation`` and print the response."""
        # The whole file is read up front; the handle is closed before the request is made.
        with open(args.image, "rb") as source:
            upload = BufferReader(source.read(), desc="Upload progress")

        images = get_client().images
        response = images.create_variation(
            model=args.model,
            image=("image", upload),
            n=args.num_images,
            # The API types these two as enums. They are cast to `Any` and
            # deliberately left unvalidated so that values introduced later
            # remain usable from this CLI.
            size=cast(Any, args.size),
            response_format=cast(Any, args.response_format),
        )
        print_model(response)

    @staticmethod
    def edit(args: CLIImageEditArgs) -> None:
        """Upload ``args.image`` (and ``args.mask`` when given) to ``images.edit`` and print the response."""
        with open(args.image, "rb") as source:
            upload = BufferReader(source.read(), desc="Image upload progress")

        # Start from "no mask" and only read a mask file when a path was supplied.
        mask_reader: NotGivenOr[BufferReader] = NOT_GIVEN
        if not isinstance(args.mask, NotGiven):
            with open(args.mask, "rb") as mask_source:
                mask_reader = BufferReader(mask_source.read(), desc="Mask progress")

        images = get_client().images
        response = images.edit(
            model=args.model,
            prompt=args.prompt,
            image=("image", upload),
            n=args.num_images,
            # Pass the NOT_GIVEN sentinel through untouched; otherwise send a
            # ("mask", reader) pair, mirroring how the image is passed.
            mask=mask_reader if isinstance(mask_reader, NotGiven) else ("mask", mask_reader),
            # The API types these two as enums. They are cast to `Any` and
            # deliberately left unvalidated so that values introduced later
            # remain usable from this CLI.
            size=cast(Any, args.size),
            response_format=cast(Any, args.response_format),
        )
        print_model(response)