Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py  294
1 file changed, 294 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
new file mode 100644
index 00000000..f63d1ce1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
@@ -0,0 +1,294 @@
+import json
+from typing import List, Literal, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+ VertexAIError,
+ VertexLLM,
+)
+from litellm.types.llms.vertex_ai import (
+ Instance,
+ InstanceImage,
+ InstanceVideo,
+ MultimodalPredictions,
+ VertexMultimodalEmbeddingRequest,
+)
+from litellm.types.utils import Embedding, EmbeddingResponse
+from litellm.utils import is_base64_encoded
+
+
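+# Handler for Vertex AI multimodal embeddings (text, image, and video inputs):
+# it converts OpenAI-style embedding requests into Vertex predict-API instances
+# and maps the returned predictions back to OpenAI Embedding objects.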
+class VertexMultimodalEmbedding(VertexLLM):
+ def __init__(self) -> None:
+ super().__init__()
+ self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
+ "multimodalembedding",
+ "multimodalembedding@001",
+ ]
+
+ def multimodal_embedding(
+ self,
+ model: str,
+ input: Union[list, str],
+ print_verbose,
+ model_response: EmbeddingResponse,
+ custom_llm_provider: Literal["gemini", "vertex_ai"],
+ optional_params: dict,
+ logging_obj: LiteLLMLoggingObj,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ encoding=None,
+ vertex_project=None,
+ vertex_location=None,
+ vertex_credentials=None,
+ aembedding=False,
+ timeout=300,
+ client=None,
+ ) -> EmbeddingResponse:
+
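+        # Flow: resolve credentials and the predict endpoint URL, build the
+        # instance payload from the input, POST it (sync or async), and convert
+        # the returned predictions into OpenAI-style Embedding objects.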
+ _auth_header, vertex_project = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+
+ auth_header, url = self._get_token_and_url(
+ model=model,
+ auth_header=_auth_header,
+ gemini_api_key=api_key,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=None,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=False,
+ mode="embedding",
+ )
+
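+        # Reuse a caller-supplied HTTP client when given; otherwise build a sync
+        # handler, wrapping a numeric timeout in httpx.Timeout and falling back
+        # to a 600s default with a 5s connect timeout for non-numeric values.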
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
+ else:
+ sync_handler = client # type: ignore
+
+ optional_params = optional_params or {}
+
+ request_data = VertexMultimodalEmbeddingRequest()
+
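+        # Build the request instances: pre-built "instances" in optional_params
+        # take precedence, list inputs are converted element by element, and a
+        # single string is wrapped in one Instance.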
+ if "instances" in optional_params:
+ request_data["instances"] = optional_params["instances"]
+ elif isinstance(input, list):
+ vertex_instances: List[Instance] = self.process_openai_embedding_input(
+ _input=input
+ )
+ request_data["instances"] = vertex_instances
+
+ else:
+ # construct instances
+ vertex_request_instance = Instance(**optional_params)
+
+ if isinstance(input, str):
+ vertex_request_instance = self._process_input_element(input)
+
+ request_data["instances"] = [vertex_request_instance]
+
+ headers = {
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {auth_header}",
+ }
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input,
+ api_key="",
+ additional_args={
+ "complete_input_dict": request_data,
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
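+        # Async path: return the coroutine for the caller to await (hence the
+        # type: ignore against the sync return annotation).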
+ if aembedding is True:
+ return self.async_multimodal_embedding( # type: ignore
+ model=model,
+ api_base=url,
+ data=request_data,
+ timeout=timeout,
+ headers=headers,
+ client=client,
+ model_response=model_response,
+ )
+
+ response = sync_handler.post(
+ url=url,
+ headers=headers,
+ data=json.dumps(request_data),
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ _json_response = response.json()
+ if "predictions" not in _json_response:
+ raise litellm.InternalServerError(
+ message=f"embedding response does not contain 'predictions', got {_json_response}",
+ llm_provider="vertex_ai",
+ model=model,
+ )
+ _predictions = _json_response["predictions"]
+ vertex_predictions = MultimodalPredictions(predictions=_predictions)
+ model_response.data = self.transform_embedding_response_to_openai(
+ predictions=vertex_predictions
+ )
+ model_response.model = model
+
+ return model_response
+
+ async def async_multimodal_embedding(
+ self,
+ model: str,
+ api_base: str,
+ data: VertexMultimodalEmbeddingRequest,
+ model_response: litellm.EmbeddingResponse,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ headers={},
+ client: Optional[AsyncHTTPHandler] = None,
+ ) -> litellm.EmbeddingResponse:
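+        # Async variant of the sync path above: reuse or create an async client,
+        # POST the request, and translate the predictions the same way.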
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ timeout = httpx.Timeout(timeout)
+ _params["timeout"] = timeout
+ client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.VERTEX_AI,
+ params={"timeout": timeout},
+ )
+ else:
+ client = client # type: ignore
+
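+        # Surface HTTP and timeout failures as VertexAIError so callers see a
+        # consistent provider-specific exception.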
+ try:
+ response = await client.post(api_base, headers=headers, json=data) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+ _json_response = response.json()
+ if "predictions" not in _json_response:
+ raise litellm.InternalServerError(
+ message=f"embedding response does not contain 'predictions', got {_json_response}",
+ llm_provider="vertex_ai",
+ model=model,
+ )
+ _predictions = _json_response["predictions"]
+
+ vertex_predictions = MultimodalPredictions(predictions=_predictions)
+ model_response.data = self.transform_embedding_response_to_openai(
+ predictions=vertex_predictions
+ )
+ model_response.model = model
+
+ return model_response
+
+ def _process_input_element(self, input_element: str) -> Instance:
+ """
+        Process the input element for multimodal embedding requests. Checks whether the input is a GCS URI, a base64-encoded image, or plain text.
+
+ Args:
+ input_element (str): The input element to process.
+
+ Returns:
+            Instance: The processed input element as a Vertex AI Instance.
+ """
+ if len(input_element) == 0:
+ return Instance(text=input_element)
+ elif "gs://" in input_element:
+ if "mp4" in input_element:
+ return Instance(video=InstanceVideo(gcsUri=input_element))
+ else:
+ return Instance(image=InstanceImage(gcsUri=input_element))
+ elif is_base64_encoded(s=input_element):
+ return Instance(image=InstanceImage(bytesBase64Encoded=input_element))
+ else:
+ return Instance(text=input_element)
+
+ def process_openai_embedding_input(
+ self, _input: Union[list, str]
+ ) -> List[Instance]:
+ """
+ Process the input for multimodal embedding requests.
+
+ Args:
+ _input (Union[list, str]): The input data to process.
+
+ Returns:
+ List[Instance]: A list of processed VertexAI Instance objects.
+ """
+
+ _input_list = None
+ if not isinstance(_input, list):
+ _input_list = [_input]
+ else:
+ _input_list = _input
+
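+        # Strings are classified into text/image/video instances; dicts are
+        # assumed to already be Vertex-style instance payloads.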
+ processed_instances = []
+ for element in _input_list:
+ if isinstance(element, str):
+ instance = Instance(**self._process_input_element(element))
+ elif isinstance(element, dict):
+ instance = Instance(**element)
+ else:
+ raise ValueError(f"Unsupported input type: {type(element)}")
+ processed_instances.append(instance)
+
+ return processed_instances
+
+ def transform_embedding_response_to_openai(
+ self, predictions: MultimodalPredictions
+ ) -> List[Embedding]:
+
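+        # Flatten Vertex predictions into OpenAI-style Embedding objects; a video
+        # prediction may contribute several embeddings, all sharing the index of
+        # the originating prediction.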
+ openai_embeddings: List[Embedding] = []
+ if "predictions" in predictions:
+ for idx, _prediction in enumerate(predictions["predictions"]):
+ if _prediction:
+ if "textEmbedding" in _prediction:
+ openai_embedding_object = Embedding(
+ embedding=_prediction["textEmbedding"],
+ index=idx,
+ object="embedding",
+ )
+ openai_embeddings.append(openai_embedding_object)
+ elif "imageEmbedding" in _prediction:
+ openai_embedding_object = Embedding(
+ embedding=_prediction["imageEmbedding"],
+ index=idx,
+ object="embedding",
+ )
+ openai_embeddings.append(openai_embedding_object)
+ elif "videoEmbeddings" in _prediction:
+ for video_embedding in _prediction["videoEmbeddings"]:
+ openai_embedding_object = Embedding(
+ embedding=video_embedding["embedding"],
+ index=idx,
+ object="embedding",
+ )
+ openai_embeddings.append(openai_embedding_object)
+ return openai_embeddings