aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/handler.py244
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/transformation.py82
2 files changed, 326 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/handler.py
new file mode 100644
index 00000000..f738115a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/handler.py
@@ -0,0 +1,244 @@
+"""
+OpenAI Image Variations Handler
+"""
+
+from typing import Callable, Optional
+
+import httpx
+from openai import AsyncOpenAI, OpenAI
+
+import litellm
+from litellm.types.utils import FileTypes, ImageResponse, LlmProviders
+from litellm.utils import ProviderConfigManager
+
+from ...base_llm.image_variations.transformation import BaseImageVariationConfig
+from ...custom_httpx.llm_http_handler import LiteLLMLoggingObj
+from ..common_utils import OpenAIError
+
+
+class OpenAIImageVariationsHandler:
+ def get_sync_client(
+ self,
+ client: Optional[OpenAI],
+ init_client_params: dict,
+ ):
+ if client is None:
+ openai_client = OpenAI(
+ **init_client_params,
+ )
+ else:
+ openai_client = client
+ return openai_client
+
+ def get_async_client(
+ self, client: Optional[AsyncOpenAI], init_client_params: dict
+ ) -> AsyncOpenAI:
+ if client is None:
+ openai_client = AsyncOpenAI(
+ **init_client_params,
+ )
+ else:
+ openai_client = client
+ return openai_client
+
+ async def async_image_variations(
+ self,
+ api_key: str,
+ api_base: str,
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ data: dict,
+ headers: dict,
+ model: Optional[str],
+ timeout: float,
+ max_retries: int,
+ logging_obj: LiteLLMLoggingObj,
+ model_response: ImageResponse,
+ optional_params: dict,
+ litellm_params: dict,
+ image: FileTypes,
+ provider_config: BaseImageVariationConfig,
+ ) -> ImageResponse:
+ try:
+ init_client_params = {
+ "api_key": api_key,
+ "base_url": api_base,
+ "http_client": litellm.client_session,
+ "timeout": timeout,
+ "max_retries": max_retries, # type: ignore
+ "organization": organization,
+ }
+
+ client = self.get_async_client(
+ client=client, init_client_params=init_client_params
+ )
+
+ raw_response = await client.images.with_raw_response.create_variation(**data) # type: ignore
+ response = raw_response.parse()
+ response_json = response.model_dump()
+
+ ## LOGGING
+ logging_obj.post_call(
+ api_key=api_key,
+ original_response=response_json,
+ additional_args={
+ "headers": headers,
+ "api_base": api_base,
+ },
+ )
+
+ ## RESPONSE OBJECT
+ return provider_config.transform_response_image_variation(
+ model=model,
+ model_response=ImageResponse(**response_json),
+ raw_response=httpx.Response(
+ status_code=200,
+ request=httpx.Request(
+ method="GET", url="https://litellm.ai"
+ ), # mock request object
+ ),
+ logging_obj=logging_obj,
+ request_data=data,
+ image=image,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=None,
+ api_key=api_key,
+ )
+ except Exception as e:
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_text = getattr(e, "text", str(e))
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ raise OpenAIError(
+ status_code=status_code, message=error_text, headers=error_headers
+ )
+
+ def image_variations(
+ self,
+ model_response: ImageResponse,
+ api_key: str,
+ api_base: str,
+ model: Optional[str],
+ image: FileTypes,
+ timeout: float,
+ custom_llm_provider: str,
+ logging_obj: LiteLLMLoggingObj,
+ optional_params: dict,
+ litellm_params: dict,
+ print_verbose: Optional[Callable] = None,
+ logger_fn=None,
+ client=None,
+ organization: Optional[str] = None,
+ headers: Optional[dict] = None,
+ ) -> ImageResponse:
+ try:
+ provider_config = ProviderConfigManager.get_provider_image_variation_config(
+ model=model or "", # openai defaults to dall-e-2
+ provider=LlmProviders.OPENAI,
+ )
+
+ if provider_config is None:
+ raise ValueError(
+ f"image variation provider not found: {custom_llm_provider}."
+ )
+
+ max_retries = optional_params.pop("max_retries", 2)
+
+ data = provider_config.transform_request_image_variation(
+ model=model,
+ image=image,
+ optional_params=optional_params,
+ headers=headers or {},
+ )
+ json_data = data.get("data")
+ if not json_data:
+ raise ValueError(
+ f"data field is required, for openai image variations. Got={data}"
+ )
+ ## LOGGING
+ logging_obj.pre_call(
+ input="",
+ api_key=api_key,
+ additional_args={
+ "headers": headers,
+ "api_base": api_base,
+ "complete_input_dict": data,
+ },
+ )
+ if litellm_params.get("async_call", False):
+ return self.async_image_variations(
+ api_base=api_base,
+ data=json_data,
+ headers=headers or {},
+ model_response=model_response,
+ api_key=api_key,
+ logging_obj=logging_obj,
+ model=model,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ provider_config=provider_config,
+ image=image,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ ) # type: ignore
+
+ init_client_params = {
+ "api_key": api_key,
+ "base_url": api_base,
+ "http_client": litellm.client_session,
+ "timeout": timeout,
+ "max_retries": max_retries, # type: ignore
+ "organization": organization,
+ }
+
+ client = self.get_sync_client(
+ client=client, init_client_params=init_client_params
+ )
+
+ raw_response = client.images.with_raw_response.create_variation(**json_data) # type: ignore
+ response = raw_response.parse()
+ response_json = response.model_dump()
+
+ ## LOGGING
+ logging_obj.post_call(
+ api_key=api_key,
+ original_response=response_json,
+ additional_args={
+ "headers": headers,
+ "api_base": api_base,
+ },
+ )
+
+ ## RESPONSE OBJECT
+ return provider_config.transform_response_image_variation(
+ model=model,
+ model_response=ImageResponse(**response_json),
+ raw_response=httpx.Response(
+ status_code=200,
+ request=httpx.Request(
+ method="GET", url="https://litellm.ai"
+ ), # mock request object
+ ),
+ logging_obj=logging_obj,
+ request_data=json_data,
+ image=image,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=None,
+ api_key=api_key,
+ )
+ except Exception as e:
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_text = getattr(e, "text", str(e))
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ raise OpenAIError(
+ status_code=status_code, message=error_text, headers=error_headers
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/transformation.py
new file mode 100644
index 00000000..96d1a302
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/openai/image_variations/transformation.py
@@ -0,0 +1,82 @@
+from typing import Any, List, Optional, Union
+
+from aiohttp import ClientResponse
+from httpx import Headers, Response
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.image_variations.transformation import LiteLLMLoggingObj
+from litellm.types.llms.openai import OpenAIImageVariationOptionalParams
+from litellm.types.utils import FileTypes, HttpHandlerRequestFields, ImageResponse
+
+from ...base_llm.image_variations.transformation import BaseImageVariationConfig
+from ..common_utils import OpenAIError
+
+
+class OpenAIImageVariationConfig(BaseImageVariationConfig):
+ def get_supported_openai_params(
+ self, model: str
+ ) -> List[OpenAIImageVariationOptionalParams]:
+ return ["n", "size", "response_format", "user"]
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
+ optional_params.update(non_default_params)
+ return optional_params
+
+ def transform_request_image_variation(
+ self,
+ model: Optional[str],
+ image: FileTypes,
+ optional_params: dict,
+ headers: dict,
+ ) -> HttpHandlerRequestFields:
+ return {
+ "data": {
+ "image": image,
+ **optional_params,
+ }
+ }
+
+ async def async_transform_response_image_variation(
+ self,
+ model: Optional[str],
+ raw_response: ClientResponse,
+ model_response: ImageResponse,
+ logging_obj: LiteLLMLoggingObj,
+ request_data: dict,
+ image: FileTypes,
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ ) -> ImageResponse:
+ return model_response
+
+ def transform_response_image_variation(
+ self,
+ model: Optional[str],
+ raw_response: Response,
+ model_response: ImageResponse,
+ logging_obj: LiteLLMLoggingObj,
+ request_data: dict,
+ image: FileTypes,
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ ) -> ImageResponse:
+ return model_response
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[dict, Headers]
+ ) -> BaseLLMException:
+ return OpenAIError(
+ status_code=status_code,
+ message=error_message,
+ headers=headers,
+ )