about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py294
1 files changed, 294 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
new file mode 100644
index 00000000..f63d1ce1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
@@ -0,0 +1,294 @@
+import json
+from typing import List, Literal, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+    VertexAIError,
+    VertexLLM,
+)
+from litellm.types.llms.vertex_ai import (
+    Instance,
+    InstanceImage,
+    InstanceVideo,
+    MultimodalPredictions,
+    VertexMultimodalEmbeddingRequest,
+)
+from litellm.types.utils import Embedding, EmbeddingResponse
+from litellm.utils import is_base64_encoded
+
+
+class VertexMultimodalEmbedding(VertexLLM):
+    def __init__(self) -> None:
+        super().__init__()
+        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
+            "multimodalembedding",
+            "multimodalembedding@001",
+        ]
+
+    def multimodal_embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        print_verbose,
+        model_response: EmbeddingResponse,
+        custom_llm_provider: Literal["gemini", "vertex_ai"],
+        optional_params: dict,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        encoding=None,
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        aembedding=False,
+        timeout=300,
+        client=None,
+    ) -> EmbeddingResponse:
+
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider=custom_llm_provider,
+        )
+
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            auth_header=_auth_header,
+            gemini_api_key=api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=None,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="embedding",
+        )
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        optional_params = optional_params or {}
+
+        request_data = VertexMultimodalEmbeddingRequest()
+
+        if "instances" in optional_params:
+            request_data["instances"] = optional_params["instances"]
+        elif isinstance(input, list):
+            vertex_instances: List[Instance] = self.process_openai_embedding_input(
+                _input=input
+            )
+            request_data["instances"] = vertex_instances
+
+        else:
+            # construct instances
+            vertex_request_instance = Instance(**optional_params)
+
+            if isinstance(input, str):
+                vertex_request_instance = self._process_input_element(input)
+
+            request_data["instances"] = [vertex_request_instance]
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {auth_header}",
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key="",
+            additional_args={
+                "complete_input_dict": request_data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        if aembedding is True:
+            return self.async_multimodal_embedding(  # type: ignore
+                model=model,
+                api_base=url,
+                data=request_data,
+                timeout=timeout,
+                headers=headers,
+                client=client,
+                model_response=model_response,
+            )
+
+        response = sync_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+        vertex_predictions = MultimodalPredictions(predictions=_predictions)
+        model_response.data = self.transform_embedding_response_to_openai(
+            predictions=vertex_predictions
+        )
+        model_response.model = model
+
+        return model_response
+
+    async def async_multimodal_embedding(
+        self,
+        model: str,
+        api_base: str,
+        data: VertexMultimodalEmbeddingRequest,
+        model_response: litellm.EmbeddingResponse,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> litellm.EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.VERTEX_AI,
+                params={"timeout": timeout},
+            )
+        else:
+            client = client  # type: ignore
+
+        try:
+            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+
+        vertex_predictions = MultimodalPredictions(predictions=_predictions)
+        model_response.data = self.transform_embedding_response_to_openai(
+            predictions=vertex_predictions
+        )
+        model_response.model = model
+
+        return model_response
+
+    def _process_input_element(self, input_element: str) -> Instance:
+        """
+        Process the input element for multimodal embedding requests. checks if the if the input is gcs uri, base64 encoded image or plain text.
+
+        Args:
+            input_element (str): The input element to process.
+
+        Returns:
+            Dict[str, Any]: A dictionary representing the processed input element.
+        """
+        if len(input_element) == 0:
+            return Instance(text=input_element)
+        elif "gs://" in input_element:
+            if "mp4" in input_element:
+                return Instance(video=InstanceVideo(gcsUri=input_element))
+            else:
+                return Instance(image=InstanceImage(gcsUri=input_element))
+        elif is_base64_encoded(s=input_element):
+            return Instance(image=InstanceImage(bytesBase64Encoded=input_element))
+        else:
+            return Instance(text=input_element)
+
+    def process_openai_embedding_input(
+        self, _input: Union[list, str]
+    ) -> List[Instance]:
+        """
+        Process the input for multimodal embedding requests.
+
+        Args:
+            _input (Union[list, str]): The input data to process.
+
+        Returns:
+            List[Instance]: A list of processed VertexAI Instance objects.
+        """
+
+        _input_list = None
+        if not isinstance(_input, list):
+            _input_list = [_input]
+        else:
+            _input_list = _input
+
+        processed_instances = []
+        for element in _input_list:
+            if isinstance(element, str):
+                instance = Instance(**self._process_input_element(element))
+            elif isinstance(element, dict):
+                instance = Instance(**element)
+            else:
+                raise ValueError(f"Unsupported input type: {type(element)}")
+            processed_instances.append(instance)
+
+        return processed_instances
+
+    def transform_embedding_response_to_openai(
+        self, predictions: MultimodalPredictions
+    ) -> List[Embedding]:
+
+        openai_embeddings: List[Embedding] = []
+        if "predictions" in predictions:
+            for idx, _prediction in enumerate(predictions["predictions"]):
+                if _prediction:
+                    if "textEmbedding" in _prediction:
+                        openai_embedding_object = Embedding(
+                            embedding=_prediction["textEmbedding"],
+                            index=idx,
+                            object="embedding",
+                        )
+                        openai_embeddings.append(openai_embedding_object)
+                    elif "imageEmbedding" in _prediction:
+                        openai_embedding_object = Embedding(
+                            embedding=_prediction["imageEmbedding"],
+                            index=idx,
+                            object="embedding",
+                        )
+                        openai_embeddings.append(openai_embedding_object)
+                    elif "videoEmbeddings" in _prediction:
+                        for video_embedding in _prediction["videoEmbeddings"]:
+                            openai_embedding_object = Embedding(
+                                embedding=video_embedding["embedding"],
+                                index=idx,
+                                object="embedding",
+                            )
+                            openai_embeddings.append(openai_embedding_object)
+        return openai_embeddings