aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py204
1 files changed, 204 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py
new file mode 100644
index 00000000..8b95deed
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py
@@ -0,0 +1,204 @@
+import base64
+import time
+from io import BytesIO
+from typing import Any, List, Mapping, Optional, Tuple, Union
+
+from aiohttp import ClientResponse
+from httpx import Headers, Response
+
+from litellm.llms.base_llm.chat.transformation import (
+ BaseLLMException,
+ LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import (
+ AllMessageValues,
+ OpenAIImageVariationOptionalParams,
+)
+from litellm.types.utils import (
+ FileTypes,
+ HttpHandlerRequestFields,
+ ImageObject,
+ ImageResponse,
+)
+
+from ...base_llm.image_variations.transformation import BaseImageVariationConfig
+from ..common_utils import TopazException
+
+
+class TopazImageVariationConfig(BaseImageVariationConfig):
+ def get_supported_openai_params(
+ self, model: str
+ ) -> List[OpenAIImageVariationOptionalParams]:
+ return ["response_format", "size"]
+
+ def validate_environment(
+ self,
+ headers: dict,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ ) -> dict:
+ if api_key is None:
+ raise ValueError(
+ "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
+ )
+ return {
+ # "Content-Type": "multipart/form-data",
+ "Accept": "image/jpeg",
+ "X-API-Key": api_key,
+ }
+
+ def get_complete_url(
+ self,
+ api_base: Optional[str],
+ model: str,
+ optional_params: dict,
+ litellm_params: dict,
+ stream: Optional[bool] = None,
+ ) -> str:
+ api_base = api_base or "https://api.topazlabs.com"
+ return f"{api_base}/image/v1/enhance"
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
+ for k, v in non_default_params.items():
+ if k == "response_format":
+ optional_params["output_format"] = v
+ elif k == "size":
+ split_v = v.split("x")
+ assert len(split_v) == 2, "size must be in the format of widthxheight"
+ optional_params["output_width"] = split_v[0]
+ optional_params["output_height"] = split_v[1]
+ return optional_params
+
+ def prepare_file_tuple(
+ self,
+ file_data: FileTypes,
+ ) -> Tuple[str, Optional[FileTypes], str, Mapping[str, str]]:
+ """
+ Convert various file input formats to a consistent tuple format for HTTPX
+ Returns: (filename, file_content, content_type, headers)
+ """
+ # Default values
+ filename = "image.png"
+ content: Optional[FileTypes] = None
+ content_type = "image/png"
+ headers: Mapping[str, str] = {}
+
+ if isinstance(file_data, (bytes, BytesIO)):
+ # Case 1: Just file content
+ content = file_data
+ elif isinstance(file_data, tuple):
+ if len(file_data) == 2:
+ # Case 2: (filename, content)
+ filename = file_data[0] or filename
+ content = file_data[1]
+ elif len(file_data) == 3:
+ # Case 3: (filename, content, content_type)
+ filename = file_data[0] or filename
+ content = file_data[1]
+ content_type = file_data[2] or content_type
+ elif len(file_data) == 4:
+ # Case 4: (filename, content, content_type, headers)
+ filename = file_data[0] or filename
+ content = file_data[1]
+ content_type = file_data[2] or content_type
+ headers = file_data[3]
+
+ return (filename, content, content_type, headers)
+
+ def transform_request_image_variation(
+ self,
+ model: Optional[str],
+ image: FileTypes,
+ optional_params: dict,
+ headers: dict,
+ ) -> HttpHandlerRequestFields:
+
+ request_params = HttpHandlerRequestFields(
+ files={"image": self.prepare_file_tuple(image)},
+ data=optional_params,
+ )
+
+ return request_params
+
+ def _common_transform_response_image_variation(
+ self,
+ image_content: bytes,
+ response_ms: float,
+ ) -> ImageResponse:
+
+ # Convert to base64
+ base64_image = base64.b64encode(image_content).decode("utf-8")
+
+ return ImageResponse(
+ created=int(time.time()),
+ data=[
+ ImageObject(
+ b64_json=base64_image,
+ url=None,
+ revised_prompt=None,
+ )
+ ],
+ response_ms=response_ms,
+ )
+
+ async def async_transform_response_image_variation(
+ self,
+ model: Optional[str],
+ raw_response: ClientResponse,
+ model_response: ImageResponse,
+ logging_obj: LiteLLMLoggingObj,
+ request_data: dict,
+ image: FileTypes,
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ ) -> ImageResponse:
+ image_content = await raw_response.read()
+
+ response_ms = logging_obj.get_response_ms()
+
+ return self._common_transform_response_image_variation(
+ image_content, response_ms
+ )
+
+ def transform_response_image_variation(
+ self,
+ model: Optional[str],
+ raw_response: Response,
+ model_response: ImageResponse,
+ logging_obj: LiteLLMLoggingObj,
+ request_data: dict,
+ image: FileTypes,
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ ) -> ImageResponse:
+ image_content = raw_response.content
+
+ response_ms = (
+ raw_response.elapsed.total_seconds() * 1000
+ ) # Convert to milliseconds
+
+ return self._common_transform_response_image_variation(
+ image_content, response_ms
+ )
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[dict, Headers]
+ ) -> BaseLLMException:
+ return TopazException(
+ status_code=status_code,
+ message=error_message,
+ headers=headers,
+ )