diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/handler.py | 244 | ||||
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/transformation.py | 82 |
2 files changed, 326 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/handler.py new file mode 100644 index 00000000..f738115a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/handler.py @@ -0,0 +1,244 @@ +""" +OpenAI Image Variations Handler +""" + +from typing import Callable, Optional + +import httpx +from openai import AsyncOpenAI, OpenAI + +import litellm +from litellm.types.utils import FileTypes, ImageResponse, LlmProviders +from litellm.utils import ProviderConfigManager + +from ...base_llm.image_variations.transformation import BaseImageVariationConfig +from ...custom_httpx.llm_http_handler import LiteLLMLoggingObj +from ..common_utils import OpenAIError + + +class OpenAIImageVariationsHandler: + def get_sync_client( + self, + client: Optional[OpenAI], + init_client_params: dict, + ): + if client is None: + openai_client = OpenAI( + **init_client_params, + ) + else: + openai_client = client + return openai_client + + def get_async_client( + self, client: Optional[AsyncOpenAI], init_client_params: dict + ) -> AsyncOpenAI: + if client is None: + openai_client = AsyncOpenAI( + **init_client_params, + ) + else: + openai_client = client + return openai_client + + async def async_image_variations( + self, + api_key: str, + api_base: str, + organization: Optional[str], + client: Optional[AsyncOpenAI], + data: dict, + headers: dict, + model: Optional[str], + timeout: float, + max_retries: int, + logging_obj: LiteLLMLoggingObj, + model_response: ImageResponse, + optional_params: dict, + litellm_params: dict, + image: FileTypes, + provider_config: BaseImageVariationConfig, + ) -> ImageResponse: + try: + init_client_params = { + "api_key": api_key, + "base_url": api_base, + "http_client": litellm.client_session, + "timeout": timeout, + "max_retries": max_retries, # type: ignore + "organization": organization, + } + + client = self.get_async_client( + client=client, init_client_params=init_client_params + ) + + raw_response = await client.images.with_raw_response.create_variation(**data) # type: ignore + response = raw_response.parse() + response_json = response.model_dump() + + ## LOGGING + logging_obj.post_call( + api_key=api_key, + original_response=response_json, + additional_args={ + "headers": headers, + "api_base": api_base, + }, + ) + + ## RESPONSE OBJECT + return provider_config.transform_response_image_variation( + model=model, + model_response=ImageResponse(**response_json), + raw_response=httpx.Response( + status_code=200, + request=httpx.Request( + method="GET", url="https://litellm.ai" + ), # mock request object + ), + logging_obj=logging_obj, + request_data=data, + image=image, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=None, + api_key=api_key, + ) + except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise OpenAIError( + status_code=status_code, message=error_text, headers=error_headers + ) + + def image_variations( + self, + model_response: ImageResponse, + api_key: str, + api_base: str, + model: Optional[str], + image: FileTypes, + timeout: float, + custom_llm_provider: str, + logging_obj: LiteLLMLoggingObj, + optional_params: dict, + litellm_params: dict, + print_verbose: Optional[Callable] = None, + logger_fn=None, + client=None, + organization: Optional[str] = None, + headers: Optional[dict] = None, + ) -> ImageResponse: + try: + provider_config = ProviderConfigManager.get_provider_image_variation_config( + model=model or "", # openai defaults to dall-e-2 + provider=LlmProviders.OPENAI, + ) + + if provider_config is None: + raise ValueError( + f"image variation provider not found: {custom_llm_provider}." + ) + + max_retries = optional_params.pop("max_retries", 2) + + data = provider_config.transform_request_image_variation( + model=model, + image=image, + optional_params=optional_params, + headers=headers or {}, + ) + json_data = data.get("data") + if not json_data: + raise ValueError( + f"data field is required, for openai image variations. Got={data}" + ) + ## LOGGING + logging_obj.pre_call( + input="", + api_key=api_key, + additional_args={ + "headers": headers, + "api_base": api_base, + "complete_input_dict": data, + }, + ) + if litellm_params.get("async_call", False): + return self.async_image_variations( + api_base=api_base, + data=json_data, + headers=headers or {}, + model_response=model_response, + api_key=api_key, + logging_obj=logging_obj, + model=model, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + provider_config=provider_config, + image=image, + optional_params=optional_params, + litellm_params=litellm_params, + ) # type: ignore + + init_client_params = { + "api_key": api_key, + "base_url": api_base, + "http_client": litellm.client_session, + "timeout": timeout, + "max_retries": max_retries, # type: ignore + "organization": organization, + } + + client = self.get_sync_client( + client=client, init_client_params=init_client_params + ) + + raw_response = client.images.with_raw_response.create_variation(**json_data) # type: ignore + response = raw_response.parse() + response_json = response.model_dump() + + ## LOGGING + logging_obj.post_call( + api_key=api_key, + original_response=response_json, + additional_args={ + "headers": headers, + "api_base": api_base, + }, + ) + + ## RESPONSE OBJECT + return provider_config.transform_response_image_variation( + model=model, + model_response=ImageResponse(**response_json), + raw_response=httpx.Response( + status_code=200, + request=httpx.Request( + method="GET", url="https://litellm.ai" + ), # mock request object + ), + logging_obj=logging_obj, + request_data=json_data, + image=image, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=None, + api_key=api_key, + ) + except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise OpenAIError( + status_code=status_code, message=error_text, headers=error_headers + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/transformation.py new file mode 100644 index 00000000..96d1a302 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/transformation.py @@ -0,0 +1,82 @@ +from typing import Any, List, Optional, Union + +from aiohttp import ClientResponse +from httpx import Headers, Response + +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.llms.base_llm.image_variations.transformation import LiteLLMLoggingObj +from litellm.types.llms.openai import OpenAIImageVariationOptionalParams +from litellm.types.utils import FileTypes, HttpHandlerRequestFields, ImageResponse + +from ...base_llm.image_variations.transformation import BaseImageVariationConfig +from ..common_utils import OpenAIError + + +class OpenAIImageVariationConfig(BaseImageVariationConfig): + def get_supported_openai_params( + self, model: str + ) -> List[OpenAIImageVariationOptionalParams]: + return ["n", "size", "response_format", "user"] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + optional_params.update(non_default_params) + return optional_params + + def transform_request_image_variation( + self, + model: Optional[str], + image: FileTypes, + optional_params: dict, + headers: dict, + ) -> HttpHandlerRequestFields: + return { + "data": { + "image": image, + **optional_params, + } + } + + async def async_transform_response_image_variation( + self, + model: Optional[str], + raw_response: ClientResponse, + model_response: ImageResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + image: FileTypes, + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + ) -> ImageResponse: + return model_response + + def transform_response_image_variation( + self, + model: Optional[str], + raw_response: Response, + model_response: ImageResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + image: FileTypes, + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + ) -> ImageResponse: + return model_response + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, Headers] + ) -> BaseLLMException: + return OpenAIError( + status_code=status_code, + message=error_message, + headers=headers, + ) |