author    S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/topaz
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/topaz')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/topaz/common_utils.py                     |  35
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py  | 204
2 files changed, 239 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/topaz/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/topaz/common_utils.py
new file mode 100644
index 00000000..4ef2315d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/topaz/common_utils.py
@@ -0,0 +1,35 @@
+from typing import List, Optional
+
+from litellm.secret_managers.main import get_secret_str
+
+from ..base_llm.base_utils import BaseLLMModelInfo
+from ..base_llm.chat.transformation import BaseLLMException
+
+
+class TopazException(BaseLLMException):
+    pass
+
+
+class TopazModelInfo(BaseLLMModelInfo):
+    def get_models(self) -> List[str]:
+        return [
+            "topaz/Standard V2",
+            "topaz/Low Resolution V2",
+            "topaz/CGI",
+            "topaz/High Resolution V2",
+            "topaz/Text Refine",
+        ]
+
+    @staticmethod
+    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
+        return api_key or get_secret_str("TOPAZ_API_KEY")
+
+    @staticmethod
+    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
+        return (
+            api_base or get_secret_str("TOPAZ_API_BASE") or "https://api.topazlabs.com"
+        )
+
+    @staticmethod
+    def get_base_model(model: str) -> str:
+        return model
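
A minimal usage sketch for TopazModelInfo (illustrative only, not part of the committed file). It assumes litellm is installed and that no TOPAZ_API_BASE secret is configured, so the hard-coded fallback applies:

# Illustrative sketch, not part of the diff above.
from litellm.llms.topaz.common_utils import TopazModelInfo

# An explicit key wins; otherwise get_secret_str("TOPAZ_API_KEY") is consulted.
print(TopazModelInfo.get_api_key("sk-demo"))  # -> "sk-demo"

# With no argument and no TOPAZ_API_BASE secret set, the default base URL is returned.
print(TopazModelInfo.get_api_base(None))      # -> "https://api.topazlabs.com"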
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py
new file mode 100644
index 00000000..8b95deed
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/topaz/image_variations/transformation.py
@@ -0,0 +1,204 @@
+import base64
+import time
+from io import BytesIO
+from typing import Any, List, Mapping, Optional, Tuple, Union
+
+from aiohttp import ClientResponse
+from httpx import Headers, Response
+
+from litellm.llms.base_llm.chat.transformation import (
+    BaseLLMException,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import (
+    AllMessageValues,
+    OpenAIImageVariationOptionalParams,
+)
+from litellm.types.utils import (
+    FileTypes,
+    HttpHandlerRequestFields,
+    ImageObject,
+    ImageResponse,
+)
+
+from ...base_llm.image_variations.transformation import BaseImageVariationConfig
+from ..common_utils import TopazException
+
+
+class TopazImageVariationConfig(BaseImageVariationConfig):
+    def get_supported_openai_params(
+        self, model: str
+    ) -> List[OpenAIImageVariationOptionalParams]:
+        return ["response_format", "size"]
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            raise ValueError(
+                "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
+            )
+        return {
+            # "Content-Type": "multipart/form-data",
+            "Accept": "image/jpeg",
+            "X-API-Key": api_key,
+        }
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        api_base = api_base or "https://api.topazlabs.com"
+        return f"{api_base}/image/v1/enhance"
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for k, v in non_default_params.items():
+            if k == "response_format":
+                optional_params["output_format"] = v
+            elif k == "size":
+                split_v = v.split("x")
+                assert len(split_v) == 2, "size must be in the format of widthxheight"
+                optional_params["output_width"] = split_v[0]
+                optional_params["output_height"] = split_v[1]
+        return optional_params
+
+    def prepare_file_tuple(
+        self,
+        file_data: FileTypes,
+    ) -> Tuple[str, Optional[FileTypes], str, Mapping[str, str]]:
+        """
+        Convert various file input formats to a consistent tuple format for HTTPX
+        Returns: (filename, file_content, content_type, headers)
+        """
+        # Default values
+        filename = "image.png"
+        content: Optional[FileTypes] = None
+        content_type = "image/png"
+        headers: Mapping[str, str] = {}
+
+        if isinstance(file_data, (bytes, BytesIO)):
+            # Case 1: Just file content
+            content = file_data
+        elif isinstance(file_data, tuple):
+            if len(file_data) == 2:
+                # Case 2: (filename, content)
+                filename = file_data[0] or filename
+                content = file_data[1]
+            elif len(file_data) == 3:
+                # Case 3: (filename, content, content_type)
+                filename = file_data[0] or filename
+                content = file_data[1]
+                content_type = file_data[2] or content_type
+            elif len(file_data) == 4:
+                # Case 4: (filename, content, content_type, headers)
+                filename = file_data[0] or filename
+                content = file_data[1]
+                content_type = file_data[2] or content_type
+                headers = file_data[3]
+
+        return (filename, content, content_type, headers)
+
+    def transform_request_image_variation(
+        self,
+        model: Optional[str],
+        image: FileTypes,
+        optional_params: dict,
+        headers: dict,
+    ) -> HttpHandlerRequestFields:
+
+        request_params = HttpHandlerRequestFields(
+            files={"image": self.prepare_file_tuple(image)},
+            data=optional_params,
+        )
+
+        return request_params
+
+    def _common_transform_response_image_variation(
+        self,
+        image_content: bytes,
+        response_ms: float,
+    ) -> ImageResponse:
+
+        # Convert to base64
+        base64_image = base64.b64encode(image_content).decode("utf-8")
+
+        return ImageResponse(
+            created=int(time.time()),
+            data=[
+                ImageObject(
+                    b64_json=base64_image,
+                    url=None,
+                    revised_prompt=None,
+                )
+            ],
+            response_ms=response_ms,
+        )
+
+    async def async_transform_response_image_variation(
+        self,
+        model: Optional[str],
+        raw_response: ClientResponse,
+        model_response: ImageResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        image: FileTypes,
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+    ) -> ImageResponse:
+        image_content = await raw_response.read()
+
+        response_ms = logging_obj.get_response_ms()
+
+        return self._common_transform_response_image_variation(
+            image_content, response_ms
+        )
+
+    def transform_response_image_variation(
+        self,
+        model: Optional[str],
+        raw_response: Response,
+        model_response: ImageResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        image: FileTypes,
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+    ) -> ImageResponse:
+        image_content = raw_response.content
+
+        response_ms = (
+            raw_response.elapsed.total_seconds() * 1000
+        )  # Convert to milliseconds
+
+        return self._common_transform_response_image_variation(
+            image_content, response_ms
+        )
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return TopazException(
+            status_code=status_code,
+            message=error_message,
+            headers=headers,
+        )
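
A short sketch of how the config above maps OpenAI-style image-variation params onto Topaz's /image/v1/enhance endpoint (illustrative only, not part of the committed file; it assumes TopazImageVariationConfig can be instantiated directly):

# Illustrative sketch, not part of the diff above.
from litellm.llms.topaz.image_variations.transformation import (
    TopazImageVariationConfig,
)

config = TopazImageVariationConfig()  # assumption: no further abstract methods need overriding

params = config.map_openai_params(
    non_default_params={"response_format": "b64_json", "size": "1024x1024"},
    optional_params={},
    model="topaz/Standard V2",
    drop_params=False,
)
# "size" is split on "x" (width/height stay strings); "response_format" becomes "output_format".
print(params)  # {'output_format': 'b64_json', 'output_width': '1024', 'output_height': '1024'}

url = config.get_complete_url(
    api_base=None,
    model="topaz/Standard V2",
    optional_params=params,
    litellm_params={},
)
print(url)  # https://api.topazlabs.com/image/v1/enhance

# prepare_file_tuple normalizes bytes/BytesIO/tuple inputs into
# (filename, content, content_type, headers) for the multipart upload.
print(config.prepare_file_tuple(b"raw image bytes"))
# ('image.png', b'raw image bytes', 'image/png', {})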