"""
OpenAI Image Variations Handler

Sync and async handling of image-variation requests sent through the OpenAI SDK.
"""

from typing import Callable, Optional

import httpx
from openai import AsyncOpenAI, OpenAI

import litellm
from litellm.types.utils import FileTypes, ImageResponse, LlmProviders
from litellm.utils import ProviderConfigManager

from ...base_llm.image_variations.transformation import BaseImageVariationConfig
from ...custom_httpx.llm_http_handler import LiteLLMLoggingObj
from ..common_utils import OpenAIError


class OpenAIImageVariationsHandler:
    """Run image-variation requests through the OpenAI SDK.

    ``image_variations`` is the entry point. It builds the request with the
    provider config and then either performs the call synchronously or, when
    ``litellm_params["async_call"]`` is truthy, returns the coroutine produced
    by ``async_image_variations``. Any exception raised while handling a
    request is re-raised as ``OpenAIError``.
    """

    def get_sync_client(
        self,
        client: Optional[OpenAI],
        init_client_params: dict,
    ) -> OpenAI:
        """Return ``client`` when provided, otherwise build a new ``OpenAI``.

        Args:
            client: Existing client to reuse, or ``None``.
            init_client_params: Keyword arguments passed to ``OpenAI(...)``;
                only used when ``client`` is ``None``.
        """
        if client is not None:
            return client
        return OpenAI(**init_client_params)

    def get_async_client(
        self, client: Optional[AsyncOpenAI], init_client_params: dict
    ) -> AsyncOpenAI:
        """Return ``client`` when provided, otherwise build a new ``AsyncOpenAI``.

        Args:
            client: Existing async client to reuse, or ``None``.
            init_client_params: Keyword arguments passed to
                ``AsyncOpenAI(...)``; only used when ``client`` is ``None``.
        """
        if client is not None:
            return client
        return AsyncOpenAI(**init_client_params)

    @staticmethod
    def _build_client_params(
        api_key: str,
        api_base: str,
        timeout: float,
        max_retries: int,
        organization: Optional[str],
    ) -> dict:
        """Assemble the constructor kwargs shared by the sync and async clients."""
        return {
            "api_key": api_key,
            "base_url": api_base,
            "http_client": litellm.client_session,
            "timeout": timeout,
            "max_retries": max_retries,  # type: ignore
            "organization": organization,
        }

    @staticmethod
    def _to_openai_error(exc: Exception) -> OpenAIError:
        """Translate an arbitrary exception into an ``OpenAIError``.

        ``status_code`` (default 500), ``text`` (default ``str(exc)``) and
        ``headers`` are read from ``exc`` when present. If ``exc`` carries no
        headers of its own, they are taken from ``exc.response`` instead.
        """
        status_code = getattr(exc, "status_code", 500)
        error_headers = getattr(exc, "headers", None)
        error_text = getattr(exc, "text", str(exc))
        error_response = getattr(exc, "response", None)
        # Compare against None instead of relying on truthiness: a response
        # object may define __bool__/__len__ and evaluate as falsy while still
        # carrying headers worth propagating.
        if error_headers is None and error_response is not None:
            error_headers = getattr(error_response, "headers", None)
        return OpenAIError(
            status_code=status_code, message=error_text, headers=error_headers
        )

    @staticmethod
    def _finalize_response(
        response_json: dict,
        *,
        api_key: str,
        api_base: str,
        headers: Optional[dict],
        logging_obj: LiteLLMLoggingObj,
        provider_config: BaseImageVariationConfig,
        model: Optional[str],
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
    ) -> ImageResponse:
        """Log the dumped SDK response and convert it with ``provider_config``."""
        ## LOGGING
        logging_obj.post_call(
            api_key=api_key,
            original_response=response_json,
            additional_args={
                "headers": headers,
                "api_base": api_base,
            },
        )

        ## RESPONSE OBJECT
        return provider_config.transform_response_image_variation(
            model=model,
            model_response=ImageResponse(**response_json),
            raw_response=httpx.Response(
                status_code=200,
                request=httpx.Request(
                    method="GET", url="https://litellm.ai"
                ),  # mock request object
            ),
            logging_obj=logging_obj,
            request_data=request_data,
            image=image,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=None,
            api_key=api_key,
        )

    async def async_image_variations(
        self,
        api_key: str,
        api_base: str,
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        data: dict,
        headers: dict,
        model: Optional[str],
        timeout: float,
        max_retries: int,
        logging_obj: LiteLLMLoggingObj,
        model_response: ImageResponse,
        optional_params: dict,
        litellm_params: dict,
        image: FileTypes,
        provider_config: BaseImageVariationConfig,
    ) -> ImageResponse:
        """Call ``images.create_variation`` asynchronously and transform the result.

        Args:
            api_key: Key used for the client and passed to the logging and
                transform calls.
            api_base: Base URL for the client; also recorded in the log.
            organization: Organization passed to the client constructor.
            client: Existing ``AsyncOpenAI`` client to reuse, or ``None`` to
                build one from the arguments above.
            data: Keyword arguments forwarded to ``create_variation``.
            headers: Recorded in the post-call log.
            model: Passed through to the response transform.
            timeout: Timeout passed to the client constructor.
            max_retries: Retry count passed to the client constructor.
            logging_obj: Receives the ``post_call`` log entry.
            model_response: Not used by this method; a fresh ``ImageResponse``
                is built from the SDK response.
            optional_params: Passed through to the response transform.
            litellm_params: Passed through to the response transform.
            image: Passed through to the response transform.
            provider_config: Supplies ``transform_response_image_variation``.

        Returns:
            The value returned by
            ``provider_config.transform_response_image_variation``.

        Raises:
            OpenAIError: For any exception raised while handling the request;
                the original exception is chained as ``__cause__``.
        """
        try:
            openai_client = self.get_async_client(
                client=client,
                init_client_params=self._build_client_params(
                    api_key=api_key,
                    api_base=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                ),
            )

            raw_response = await openai_client.images.with_raw_response.create_variation(**data)  # type: ignore
            response_json = raw_response.parse().model_dump()

            return self._finalize_response(
                response_json,
                api_key=api_key,
                api_base=api_base,
                headers=headers,
                logging_obj=logging_obj,
                provider_config=provider_config,
                model=model,
                request_data=data,
                image=image,
                optional_params=optional_params,
                litellm_params=litellm_params,
            )
        except Exception as e:
            raise self._to_openai_error(e) from e

    def image_variations(
        self,
        model_response: ImageResponse,
        api_key: str,
        api_base: str,
        model: Optional[str],
        image: FileTypes,
        timeout: float,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        litellm_params: dict,
        print_verbose: Optional[Callable] = None,
        logger_fn=None,
        client=None,
        organization: Optional[str] = None,
        headers: Optional[dict] = None,
    ) -> ImageResponse:
        """Build an image-variation request and execute it (sync or async).

        Args:
            model_response: Forwarded to the async path; the sync path builds
                a fresh ``ImageResponse`` from the SDK response.
            api_key: Key used for the client, logging and transform calls.
            api_base: Base URL for the client; also recorded in the logs.
            model: Model name; an empty string is used for the provider-config
                lookup when it is ``None``.
            image: Source image handed to the request/response transforms.
            timeout: Timeout passed to the client constructor.
            custom_llm_provider: Only used in the "provider not found" error
                message.
            logging_obj: Receives ``pre_call`` and ``post_call`` log entries.
            optional_params: Extra request parameters. ``max_retries`` is
                popped from this dict (mutating the caller's dict) before the
                request transform sees it; it defaults to 2.
            litellm_params: A truthy ``"async_call"`` entry selects the async
                path.
            print_verbose: Unused here.
            logger_fn: Unused here.
            client: Existing OpenAI client to reuse, or ``None`` to build one.
            organization: Organization passed to the client constructor.
            headers: Request headers; ``{}`` is substituted for ``None`` when
                passed to the request transform and to the async path.

        Returns:
            The transformed response, or — when ``litellm_params["async_call"]``
            is truthy — the un-awaited coroutine from
            ``async_image_variations``.

        Raises:
            OpenAIError: For any exception raised while handling the request,
                including the ``ValueError`` cases for a missing provider
                config or an empty ``data`` field; the original exception is
                chained as ``__cause__``.
        """
        try:
            provider_config = ProviderConfigManager.get_provider_image_variation_config(
                model=model or "",  # openai defaults to dall-e-2
                provider=LlmProviders.OPENAI,
            )

            if provider_config is None:
                raise ValueError(
                    f"image variation provider not found: {custom_llm_provider}."
                )

            # Taken out of optional_params before the request transform runs,
            # so it is only used as a client setting.
            max_retries = optional_params.pop("max_retries", 2)

            data = provider_config.transform_request_image_variation(
                model=model,
                image=image,
                optional_params=optional_params,
                headers=headers or {},
            )
            json_data = data.get("data")
            if not json_data:
                raise ValueError(
                    f"data field is required, for openai image variations. Got={data}"
                )
            ## LOGGING
            logging_obj.pre_call(
                input="",
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "complete_input_dict": data,
                },
            )
            if litellm_params.get("async_call", False):
                # Hand back the coroutine; the caller is responsible for
                # awaiting it, and it performs its own error translation.
                return self.async_image_variations(
                    api_base=api_base,
                    data=json_data,
                    headers=headers or {},
                    model_response=model_response,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    model=model,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    client=client,
                    provider_config=provider_config,
                    image=image,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                )  # type: ignore

            openai_client = self.get_sync_client(
                client=client,
                init_client_params=self._build_client_params(
                    api_key=api_key,
                    api_base=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                ),
            )

            raw_response = openai_client.images.with_raw_response.create_variation(**json_data)  # type: ignore
            response_json = raw_response.parse().model_dump()

            return self._finalize_response(
                response_json,
                api_key=api_key,
                api_base=api_base,
                headers=headers,
                logging_obj=logging_obj,
                provider_config=provider_config,
                model=model,
                request_data=json_data,
                image=image,
                optional_params=optional_params,
                litellm_params=litellm_params,
            )
        except Exception as e:
            raise self._to_openai_error(e) from e