aboutsummaryrefslogtreecommitdiff
import base64
import time
from io import BytesIO
from typing import Any, List, Mapping, Optional, Tuple, Union

from aiohttp import ClientResponse
from httpx import Headers, Response

from litellm.llms.base_llm.chat.transformation import (
    BaseLLMException,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    OpenAIImageVariationOptionalParams,
)
from litellm.types.utils import (
    FileTypes,
    HttpHandlerRequestFields,
    ImageObject,
    ImageResponse,
)

from ...base_llm.image_variations.transformation import BaseImageVariationConfig
from ..common_utils import TopazException


class TopazImageVariationConfig(BaseImageVariationConfig):
    """
    Request/response transformation for Topaz Labs image "variations".

    Maps OpenAI-style image-variation parameters onto the Topaz
    ``/image/v1/enhance`` endpoint, uploads the source image as multipart
    form data, and converts the raw image bytes returned by Topaz into an
    OpenAI-style ``ImageResponse`` carrying a base64-encoded image.
    """

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageVariationOptionalParams]:
        """Return the OpenAI image-variation params this provider can map."""
        return ["response_format", "size"]

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Build the HTTP headers for a Topaz request.

        Args:
            headers: Incoming headers. NOTE(review): these are not merged into
                the result; confirm callers don't expect extra headers to be
                forwarded.
            api_key: Topaz API key, sent as ``X-API-Key``.

        Returns:
            A new dict with ``Accept`` and ``X-API-Key``.

        Raises:
            ValueError: if ``api_key`` is None.
        """
        if api_key is None:
            raise ValueError(
                "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
            )
        # Content-Type is intentionally not set here, presumably so the HTTP
        # client can generate the multipart/form-data header with its boundary.
        return {
            "Accept": "image/jpeg",
            "X-API-Key": api_key,
        }

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Return the Topaz enhance endpoint URL.

        Falls back to ``https://api.topazlabs.com`` when ``api_base`` is empty.
        Trailing slashes on the base are stripped so the result never
        contains ``//image``.
        """
        base = (api_base or "https://api.topazlabs.com").rstrip("/")
        return f"{base}/image/v1/enhance"

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Translate OpenAI params into Topaz form fields.

        ``optional_params`` is updated in place and also returned.

        - ``response_format`` -> ``output_format``
        - ``size`` ("WIDTHxHEIGHT") -> ``output_width`` / ``output_height``
          (kept as strings).

        Raises:
            ValueError: if ``size`` is not of the form ``WIDTHxHEIGHT``.
        """
        for key, value in non_default_params.items():
            if key == "response_format":
                # NOTE(review): forwarded verbatim; OpenAI's values are
                # "url"/"b64_json" — confirm Topaz accepts these as output_format.
                optional_params["output_format"] = value
            elif key == "size":
                parts = value.split("x")
                # Explicit check rather than `assert`, which is stripped under -O.
                if len(parts) != 2:
                    raise ValueError("size must be in the format of widthxheight")
                optional_params["output_width"] = parts[0]
                optional_params["output_height"] = parts[1]
        return optional_params

    def prepare_file_tuple(
        self,
        file_data: FileTypes,
    ) -> Tuple[str, Optional[FileTypes], str, Mapping[str, str]]:
        """
        Normalize the supported image inputs to an HTTPX multipart file tuple.

        Accepted inputs:
            - raw ``bytes`` or any binary file-like object (has ``.read``);
            - an ``os.PathLike``, which is read from disk; its basename
              becomes the filename;
            - ``(filename, content)``
            - ``(filename, content, content_type)``
            - ``(filename, content, content_type, headers)``

        Missing filename and content type default to ``"image.png"`` and
        ``"image/png"``.

        Returns:
            ``(filename, content, content_type, headers)``

        Raises:
            ValueError: for a tuple whose length is not 2, 3 or 4.
            TypeError: for any other unsupported input type.
        """
        import os

        filename = "image.png"
        content: Optional[FileTypes] = None
        content_type = "image/png"
        headers: Mapping[str, str] = {}

        if isinstance(file_data, tuple):
            if not 2 <= len(file_data) <= 4:
                raise ValueError(
                    "image tuple must be (filename, content[, content_type[, headers]])"
                )
            filename = file_data[0] or filename
            content = file_data[1]
            if len(file_data) >= 3:
                content_type = file_data[2] or content_type
            if len(file_data) == 4:
                headers = file_data[3]
        elif isinstance(file_data, os.PathLike):
            # Read path inputs eagerly so the file handle is not left open.
            with open(file_data, "rb") as fh:
                content = fh.read()
            filename = os.path.basename(os.fsdecode(file_data)) or filename
        elif isinstance(file_data, (bytes, BytesIO)) or hasattr(file_data, "read"):
            # Raw bytes or any binary file-like object (BytesIO, open file, ...).
            content = file_data
        else:
            raise TypeError(
                f"Unsupported image input type: {type(file_data).__name__}"
            )

        return (filename, content, content_type, headers)

    def transform_request_image_variation(
        self,
        model: Optional[str],
        image: FileTypes,
        optional_params: dict,
        headers: dict,
    ) -> HttpHandlerRequestFields:
        """
        Build the request payload.

        The image is sent as the multipart file field ``"image"``, and the
        mapped ``optional_params`` are sent as form data.
        """
        return HttpHandlerRequestFields(
            files={"image": self.prepare_file_tuple(image)},
            data=optional_params,
        )

    def _common_transform_response_image_variation(
        self,
        image_content: bytes,
        response_ms: float,
    ) -> ImageResponse:
        """
        Wrap raw image bytes in an ``ImageResponse``.

        The result contains one item whose ``b64_json`` holds the
        base64-encoded image.
        """
        base64_image = base64.b64encode(image_content).decode("utf-8")

        return ImageResponse(
            created=int(time.time()),
            data=[
                ImageObject(
                    b64_json=base64_image,
                    url=None,
                    revised_prompt=None,
                )
            ],
            response_ms=response_ms,
        )

    async def async_transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: ClientResponse,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """
        Async variant: read the aiohttp response body and convert it.

        Latency is taken from ``logging_obj.get_response_ms()``.
        """
        image_content = await raw_response.read()
        response_ms = logging_obj.get_response_ms()
        return self._common_transform_response_image_variation(
            image_content, response_ms
        )

    def transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """
        Sync variant: convert an httpx response body into an ``ImageResponse``.

        Latency is taken from the response's ``elapsed`` time, in
        milliseconds.
        """
        image_content = raw_response.content
        response_ms = raw_response.elapsed.total_seconds() * 1000
        return self._common_transform_response_image_variation(
            image_content, response_ms
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Wrap an HTTP error in the provider-specific ``TopazException``."""
        return TopazException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )