Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py        | 182 ++++++++++++++++++++++++++++++++++++
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py |  74 ++++++++++++++
2 files changed, 256 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py
new file mode 100644
index 00000000..0fe5145a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py
@@ -0,0 +1,182 @@
+"""
+Google AI Studio /batchEmbedContents Embeddings Endpoint
+"""
+
+import json
+from typing import Any, Literal, Optional, Union
+
+import httpx
+
+import litellm
+from litellm import EmbeddingResponse
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
+from litellm.types.llms.openai import EmbeddingInput
+from litellm.types.llms.vertex_ai import (
+    VertexAIBatchEmbeddingsRequestBody,
+    VertexAIBatchEmbeddingsResponseObject,
+)
+
+from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from .batch_embed_content_transformation import (
+    process_response,
+    transform_openai_input_gemini_content,
+)
+
+
+class GoogleBatchEmbeddings(VertexLLM):
+    def batch_embeddings(
+        self,
+        model: str,
+        input: EmbeddingInput,
+        print_verbose,
+        model_response: EmbeddingResponse,
+        custom_llm_provider: Literal["gemini", "vertex_ai"],
+        optional_params: dict,
+        logging_obj: Any,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        encoding=None,
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        aembedding=False,
+        timeout=300,
+        client=None,
+    ) -> EmbeddingResponse:
+
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider=custom_llm_provider,
+        )
+
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            auth_header=_auth_header,
+            gemini_api_key=api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=None,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="batch_embedding",
+        )
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        optional_params = optional_params or {}
+
+        ### TRANSFORMATION ###
+        request_data = transform_openai_input_gemini_content(
+            input=input, model=model, optional_params=optional_params
+        )
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key="",
+            additional_args={
+                "complete_input_dict": request_data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        if aembedding is True:
+            return self.async_batch_embeddings(  # type: ignore
+                model=model,
+                api_base=api_base,
+                url=url,
+                data=request_data,
+                model_response=model_response,
+                timeout=timeout,
+                headers=headers,
+                input=input,
+            )
+
+        response = sync_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore
+
+        return process_response(
+            model=model,
+            model_response=model_response,
+            _predictions=_predictions,
+            input=input,
+        )
+
+    async def async_batch_embeddings(
+        self,
+        model: str,
+        api_base: Optional[str],
+        url: str,
+        data: VertexAIBatchEmbeddingsRequestBody,
+        model_response: EmbeddingResponse,
+        input: EmbeddingInput,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            async_handler: AsyncHTTPHandler = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.VERTEX_AI,
+                params={"timeout": timeout},
+            )
+        else:
+            async_handler = client  # type: ignore
+
+        response = await async_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore
+
+        return process_response(
+            model=model,
+            model_response=model_response,
+            _predictions=_predictions,
+            input=input,
+        )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py
new file mode 100644
index 00000000..592dac58
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py
@@ -0,0 +1,74 @@
+"""
+Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+from typing import List
+
+from litellm import EmbeddingResponse
+from litellm.types.llms.openai import EmbeddingInput
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    EmbedContentRequest,
+    PartType,
+    VertexAIBatchEmbeddingsRequestBody,
+    VertexAIBatchEmbeddingsResponseObject,
+)
+from litellm.types.utils import Embedding, Usage
+from litellm.utils import get_formatted_prompt, token_counter
+
+
+def transform_openai_input_gemini_content(
+    input: EmbeddingInput, model: str, optional_params: dict
+) -> VertexAIBatchEmbeddingsRequestBody:
+    """
+    The content to embed. Only the parts.text fields will be counted.
+    """
+    gemini_model_name = "models/{}".format(model)
+    requests: List[EmbedContentRequest] = []
+    if isinstance(input, str):
+        request = EmbedContentRequest(
+            model=gemini_model_name,
+            content=ContentType(parts=[PartType(text=input)]),
+            **optional_params
+        )
+        requests.append(request)
+    else:
+        for i in input:
+            request = EmbedContentRequest(
+                model=gemini_model_name,
+                content=ContentType(parts=[PartType(text=i)]),
+                **optional_params
+            )
+            requests.append(request)
+
+    return VertexAIBatchEmbeddingsRequestBody(requests=requests)
+
+
+def process_response(
+    input: EmbeddingInput,
+    model_response: EmbeddingResponse,
+    model: str,
+    _predictions: VertexAIBatchEmbeddingsResponseObject,
+) -> EmbeddingResponse:
+
+    openai_embeddings: List[Embedding] = []
+    for embedding in _predictions["embeddings"]:
+        openai_embedding = Embedding(
+            embedding=embedding["values"],
+            index=0,
+            object="embedding",
+        )
+        openai_embeddings.append(openai_embedding)
+
+    model_response.data = openai_embeddings
+    model_response.model = model
+
+    input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
+    prompt_tokens = token_counter(model=model, text=input_text)
+    model_response.usage = Usage(
+        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
+    )
+
+    return model_response