Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py | 294 |
1 file changed, 294 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
new file mode 100644
index 00000000..f63d1ce1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
@@ -0,0 +1,294 @@
+import json
+from typing import List, Literal, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+    VertexAIError,
+    VertexLLM,
+)
+from litellm.types.llms.vertex_ai import (
+    Instance,
+    InstanceImage,
+    InstanceVideo,
+    MultimodalPredictions,
+    VertexMultimodalEmbeddingRequest,
+)
+from litellm.types.utils import Embedding, EmbeddingResponse
+from litellm.utils import is_base64_encoded
+
+
+class VertexMultimodalEmbedding(VertexLLM):
+    def __init__(self) -> None:
+        super().__init__()
+        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
+            "multimodalembedding",
+            "multimodalembedding@001",
+        ]
+
+    def multimodal_embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        print_verbose,
+        model_response: EmbeddingResponse,
+        custom_llm_provider: Literal["gemini", "vertex_ai"],
+        optional_params: dict,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        encoding=None,
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        aembedding=False,
+        timeout=300,
+        client=None,
+    ) -> EmbeddingResponse:
+
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider=custom_llm_provider,
+        )
+
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            auth_header=_auth_header,
+            gemini_api_key=api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=None,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="embedding",
+        )
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        optional_params = optional_params or {}
+
+        request_data = VertexMultimodalEmbeddingRequest()
+
+        if "instances" in optional_params:
+            request_data["instances"] = optional_params["instances"]
+        elif isinstance(input, list):
+            vertex_instances: List[Instance] = self.process_openai_embedding_input(
+                _input=input
+            )
+            request_data["instances"] = vertex_instances
+
+        else:
+            # construct instances
+            vertex_request_instance = Instance(**optional_params)
+
+            if isinstance(input, str):
+                vertex_request_instance = self._process_input_element(input)
+
+            request_data["instances"] = [vertex_request_instance]
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {auth_header}",
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key="",
+            additional_args={
+                "complete_input_dict": request_data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        if aembedding is True:
+            return self.async_multimodal_embedding(  # type: ignore
+                model=model,
+                api_base=url,
+                data=request_data,
+                timeout=timeout,
+                headers=headers,
+                client=client,
+                model_response=model_response,
+            )
+
+        response = sync_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+        vertex_predictions = MultimodalPredictions(predictions=_predictions)
+        model_response.data = self.transform_embedding_response_to_openai(
+            predictions=vertex_predictions
+        )
+        model_response.model = model
+
+        return model_response
+
+    async def async_multimodal_embedding(
+        self,
+        model: str,
+        api_base: str,
+        data: VertexMultimodalEmbeddingRequest,
+        model_response: litellm.EmbeddingResponse,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> litellm.EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.VERTEX_AI,
+                params={"timeout": timeout},
+            )
+        else:
+            client = client  # type: ignore
+
+        try:
+            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+
+        vertex_predictions = MultimodalPredictions(predictions=_predictions)
+        model_response.data = self.transform_embedding_response_to_openai(
+            predictions=vertex_predictions
+        )
+        model_response.model = model
+
+        return model_response
+
+    def _process_input_element(self, input_element: str) -> Instance:
+        """
+        Process an input element for multimodal embedding requests. Checks whether the input is a gcs uri, a base64 encoded image, or plain text.
+
+        Args:
+            input_element (str): The input element to process.
+
+        Returns:
+            Instance: The processed input element.
+        """
+        if len(input_element) == 0:
+            return Instance(text=input_element)
+        elif "gs://" in input_element:
+            if "mp4" in input_element:
+                return Instance(video=InstanceVideo(gcsUri=input_element))
+            else:
+                return Instance(image=InstanceImage(gcsUri=input_element))
+        elif is_base64_encoded(s=input_element):
+            return Instance(image=InstanceImage(bytesBase64Encoded=input_element))
+        else:
+            return Instance(text=input_element)
+
+    def process_openai_embedding_input(
+        self, _input: Union[list, str]
+    ) -> List[Instance]:
+        """
+        Process the input for multimodal embedding requests.
+
+        Args:
+            _input (Union[list, str]): The input data to process.
+
+        Returns:
+            List[Instance]: A list of processed VertexAI Instance objects.
+        """
+
+        _input_list = None
+        if not isinstance(_input, list):
+            _input_list = [_input]
+        else:
+            _input_list = _input
+
+        processed_instances = []
+        for element in _input_list:
+            if isinstance(element, str):
+                instance = Instance(**self._process_input_element(element))
+            elif isinstance(element, dict):
+                instance = Instance(**element)
+            else:
+                raise ValueError(f"Unsupported input type: {type(element)}")
+            processed_instances.append(instance)
+
+        return processed_instances
+
+    def transform_embedding_response_to_openai(
+        self, predictions: MultimodalPredictions
+    ) -> List[Embedding]:
+
+        openai_embeddings: List[Embedding] = []
+        if "predictions" in predictions:
+            for idx, _prediction in enumerate(predictions["predictions"]):
+                if _prediction:
+                    if "textEmbedding" in _prediction:
+                        openai_embedding_object = Embedding(
+                            embedding=_prediction["textEmbedding"],
+                            index=idx,
+                            object="embedding",
+                        )
+                        openai_embeddings.append(openai_embedding_object)
+                    elif "imageEmbedding" in _prediction:
+                        openai_embedding_object = Embedding(
+                            embedding=_prediction["imageEmbedding"],
+                            index=idx,
+                            object="embedding",
+                        )
+                        openai_embeddings.append(openai_embedding_object)
+                    elif "videoEmbeddings" in _prediction:
+                        for video_embedding in _prediction["videoEmbeddings"]:
+                            openai_embedding_object = Embedding(
+                                embedding=video_embedding["embedding"],
+                                index=idx,
+                                object="embedding",
+                            )
+                            openai_embeddings.append(openai_embedding_object)
+        return openai_embeddings
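
For context, a minimal usage sketch of the input-classification helpers added in this file. It is illustrative only: the gs:// bucket paths are made up, and no Vertex AI request is issued, so no credentials are needed.

    # Illustrative sketch, not part of the diff above. Exercises the helper
    # methods shown in embedding_handler.py without calling the Vertex AI API.
    from litellm.llms.vertex_ai.multimodal_embeddings.embedding_handler import (
        VertexMultimodalEmbedding,
    )

    handler = VertexMultimodalEmbedding()

    # Mixed OpenAI-style input: plain text, a GCS image URI, and a GCS video URI
    # (the bucket paths below are hypothetical).
    instances = handler.process_openai_embedding_input(
        _input=[
            "a photo of a cat",         # -> Instance(text=...)
            "gs://my-bucket/cat.png",   # -> Instance(image=InstanceImage(gcsUri=...))
            "gs://my-bucket/clip.mp4",  # -> Instance(video=InstanceVideo(gcsUri=...))
        ]
    )
    print(instances)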