diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py | 204 |
1 files changed, 204 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py new file mode 100644 index 00000000..8b95deed --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py @@ -0,0 +1,204 @@ +import base64 +import time +from io import BytesIO +from typing import Any, List, Mapping, Optional, Tuple, Union + +from aiohttp import ClientResponse +from httpx import Headers, Response + +from litellm.llms.base_llm.chat.transformation import ( + BaseLLMException, + LiteLLMLoggingObj, +) +from litellm.types.llms.openai import ( + AllMessageValues, + OpenAIImageVariationOptionalParams, +) +from litellm.types.utils import ( + FileTypes, + HttpHandlerRequestFields, + ImageObject, + ImageResponse, +) + +from ...base_llm.image_variations.transformation import BaseImageVariationConfig +from ..common_utils import TopazException + + +class TopazImageVariationConfig(BaseImageVariationConfig): + def get_supported_openai_params( + self, model: str + ) -> List[OpenAIImageVariationOptionalParams]: + return ["response_format", "size"] + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + if api_key is None: + raise ValueError( + "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`" + ) + return { + # "Content-Type": "multipart/form-data", + "Accept": "image/jpeg", + "X-API-Key": api_key, + } + + def get_complete_url( + self, + api_base: Optional[str], + model: str, + optional_params: dict, + litellm_params: dict, + stream: Optional[bool] = None, + ) -> str: + api_base = api_base or "https://api.topazlabs.com" + return f"{api_base}/image/v1/enhance" + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + for k, v in non_default_params.items(): + if k == "response_format": + optional_params["output_format"] = v + elif k == "size": + split_v = v.split("x") + assert len(split_v) == 2, "size must be in the format of widthxheight" + optional_params["output_width"] = split_v[0] + optional_params["output_height"] = split_v[1] + return optional_params + + def prepare_file_tuple( + self, + file_data: FileTypes, + ) -> Tuple[str, Optional[FileTypes], str, Mapping[str, str]]: + """ + Convert various file input formats to a consistent tuple format for HTTPX + Returns: (filename, file_content, content_type, headers) + """ + # Default values + filename = "image.png" + content: Optional[FileTypes] = None + content_type = "image/png" + headers: Mapping[str, str] = {} + + if isinstance(file_data, (bytes, BytesIO)): + # Case 1: Just file content + content = file_data + elif isinstance(file_data, tuple): + if len(file_data) == 2: + # Case 2: (filename, content) + filename = file_data[0] or filename + content = file_data[1] + elif len(file_data) == 3: + # Case 3: (filename, content, content_type) + filename = file_data[0] or filename + content = file_data[1] + content_type = file_data[2] or content_type + elif len(file_data) == 4: + # Case 4: (filename, content, content_type, headers) + filename = file_data[0] or filename + content = file_data[1] + content_type = file_data[2] or content_type + headers = file_data[3] + + return (filename, content, content_type, headers) + + def transform_request_image_variation( + self, + model: Optional[str], + image: FileTypes, + optional_params: dict, + headers: dict, + ) -> HttpHandlerRequestFields: + + request_params = HttpHandlerRequestFields( + files={"image": self.prepare_file_tuple(image)}, + data=optional_params, + ) + + return request_params + + def _common_transform_response_image_variation( + self, + image_content: bytes, + response_ms: float, + ) -> ImageResponse: + + # Convert to base64 + base64_image = base64.b64encode(image_content).decode("utf-8") + + return ImageResponse( + created=int(time.time()), + data=[ + ImageObject( + b64_json=base64_image, + url=None, + revised_prompt=None, + ) + ], + response_ms=response_ms, + ) + + async def async_transform_response_image_variation( + self, + model: Optional[str], + raw_response: ClientResponse, + model_response: ImageResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + image: FileTypes, + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + ) -> ImageResponse: + image_content = await raw_response.read() + + response_ms = logging_obj.get_response_ms() + + return self._common_transform_response_image_variation( + image_content, response_ms + ) + + def transform_response_image_variation( + self, + model: Optional[str], + raw_response: Response, + model_response: ImageResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + image: FileTypes, + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + ) -> ImageResponse: + image_content = raw_response.content + + response_ms = ( + raw_response.elapsed.total_seconds() * 1000 + ) # Convert to milliseconds + + return self._common_transform_response_image_variation( + image_content, response_ms + ) + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, Headers] + ) -> BaseLLMException: + return TopazException( + status_code=status_code, + message=error_message, + headers=headers, + ) |