about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py182
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py74
2 files changed, 256 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py
new file mode 100644
index 00000000..0fe5145a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py
@@ -0,0 +1,182 @@
+"""
+Google AI Studio /batchEmbedContents Embeddings Endpoint
+"""
+
+import json
+from typing import Any, Literal, Optional, Union
+
+import httpx
+
+import litellm
+from litellm import EmbeddingResponse
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
+from litellm.types.llms.openai import EmbeddingInput
+from litellm.types.llms.vertex_ai import (
+    VertexAIBatchEmbeddingsRequestBody,
+    VertexAIBatchEmbeddingsResponseObject,
+)
+
+from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from .batch_embed_content_transformation import (
+    process_response,
+    transform_openai_input_gemini_content,
+)
+
+
+class GoogleBatchEmbeddings(VertexLLM):
+    def batch_embeddings(
+        self,
+        model: str,
+        input: EmbeddingInput,
+        print_verbose,
+        model_response: EmbeddingResponse,
+        custom_llm_provider: Literal["gemini", "vertex_ai"],
+        optional_params: dict,
+        logging_obj: Any,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        encoding=None,
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        aembedding=False,
+        timeout=300,
+        client=None,
+    ) -> EmbeddingResponse:
+
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider=custom_llm_provider,
+        )
+
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            auth_header=_auth_header,
+            gemini_api_key=api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=None,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="batch_embedding",
+        )
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        optional_params = optional_params or {}
+
+        ### TRANSFORMATION ###
+        request_data = transform_openai_input_gemini_content(
+            input=input, model=model, optional_params=optional_params
+        )
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key="",
+            additional_args={
+                "complete_input_dict": request_data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        if aembedding is True:
+            return self.async_batch_embeddings(  # type: ignore
+                model=model,
+                api_base=api_base,
+                url=url,
+                data=request_data,
+                model_response=model_response,
+                timeout=timeout,
+                headers=headers,
+                input=input,
+            )
+
+        response = sync_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore
+
+        return process_response(
+            model=model,
+            model_response=model_response,
+            _predictions=_predictions,
+            input=input,
+        )
+
+    async def async_batch_embeddings(
+        self,
+        model: str,
+        api_base: Optional[str],
+        url: str,
+        data: VertexAIBatchEmbeddingsRequestBody,
+        model_response: EmbeddingResponse,
+        input: EmbeddingInput,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            async_handler: AsyncHTTPHandler = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.VERTEX_AI,
+                params={"timeout": timeout},
+            )
+        else:
+            async_handler = client  # type: ignore
+
+        response = await async_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore
+
+        return process_response(
+            model=model,
+            model_response=model_response,
+            _predictions=_predictions,
+            input=input,
+        )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py
new file mode 100644
index 00000000..592dac58
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py
@@ -0,0 +1,74 @@
+"""
+Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format. 
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+from typing import List
+
+from litellm import EmbeddingResponse
+from litellm.types.llms.openai import EmbeddingInput
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    EmbedContentRequest,
+    PartType,
+    VertexAIBatchEmbeddingsRequestBody,
+    VertexAIBatchEmbeddingsResponseObject,
+)
+from litellm.types.utils import Embedding, Usage
+from litellm.utils import get_formatted_prompt, token_counter
+
+
+def transform_openai_input_gemini_content(
+    input: EmbeddingInput, model: str, optional_params: dict
+) -> VertexAIBatchEmbeddingsRequestBody:
+    """
+    The content to embed. Only the parts.text fields will be counted.
+    """
+    gemini_model_name = "models/{}".format(model)
+    requests: List[EmbedContentRequest] = []
+    if isinstance(input, str):
+        request = EmbedContentRequest(
+            model=gemini_model_name,
+            content=ContentType(parts=[PartType(text=input)]),
+            **optional_params
+        )
+        requests.append(request)
+    else:
+        for i in input:
+            request = EmbedContentRequest(
+                model=gemini_model_name,
+                content=ContentType(parts=[PartType(text=i)]),
+                **optional_params
+            )
+            requests.append(request)
+
+    return VertexAIBatchEmbeddingsRequestBody(requests=requests)
+
+
+def process_response(
+    input: EmbeddingInput,
+    model_response: EmbeddingResponse,
+    model: str,
+    _predictions: VertexAIBatchEmbeddingsResponseObject,
+) -> EmbeddingResponse:
+
+    openai_embeddings: List[Embedding] = []
+    for embedding in _predictions["embeddings"]:
+        openai_embedding = Embedding(
+            embedding=embedding["values"],
+            index=0,
+            object="embedding",
+        )
+        openai_embeddings.append(openai_embedding)
+
+    model_response.data = openai_embeddings
+    model_response.model = model
+
+    input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
+    prompt_tokens = token_counter(model=model, text=input_text)
+    model_response.usage = Usage(
+        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
+    )
+
+    return model_response