author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py  228
1 file changed, 228 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py
new file mode 100644
index 00000000..3ef40703
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py
@@ -0,0 +1,228 @@
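+"""
+Vertex AI embedding handler for litellm.
+
+Transforms OpenAI-style embedding requests into Vertex AI text-embedding
+calls (sync and async) and maps the responses back into litellm's
+EmbeddingResponse format.
+"""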
+from typing import Literal, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+    get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.vertex_ai_non_gemini import VertexAIError
+from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+from litellm.types.llms.vertex_ai import *
+from litellm.types.utils import EmbeddingResponse
+
+from .types import *
+
+
+class VertexEmbedding(VertexBase):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        print_verbose,
+        model_response: EmbeddingResponse,
+        optional_params: dict,
+        logging_obj: LiteLLMLoggingObject,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # distinguishes Vertex AI from Gemini (Google AI Studio)
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_key: Optional[str] = None,
+        encoding=None,
+        aembedding=False,
+        api_base: Optional[str] = None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
+        gemini_api_key: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
+    ) -> EmbeddingResponse:
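+        """
+        Sync entry point for Vertex AI embeddings.
+
+        If ``aembedding`` is True this returns the ``async_embedding``
+        coroutine for the caller to await; otherwise it issues a blocking
+        HTTP call and returns the populated ``EmbeddingResponse``.
+        """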
+        if aembedding is True:
+            return self.async_embedding(  # type: ignore
+                model=model,
+                input=input,
+                logging_obj=logging_obj,
+                model_response=model_response,
+                optional_params=optional_params,
+                encoding=encoding,
+                custom_llm_provider=custom_llm_provider,
+                timeout=timeout,
+                api_base=api_base,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                vertex_credentials=vertex_credentials,
+                gemini_api_key=gemini_api_key,
+                extra_headers=extra_headers,
+            )
+
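+        # Check whether any requested optional params need the v1beta1 API;
+        # the flag is threaded into endpoint URL construction below.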
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+
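+        # Resolve OAuth credentials (and the project id, if not supplied)
+        # before building the request URL.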
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider=custom_llm_provider,
+        )
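+        # Build the final auth header and endpoint URL for the target
+        # provider (Vertex AI vs. Gemini / Google AI Studio).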
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=gemini_api_key,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=False,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+            mode="embedding",
+        )
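+        # Merge auth and any caller-supplied headers, then convert the
+        # OpenAI-style input into Vertex's embedding request shape.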
+        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
+        vertex_request: VertexEmbeddingRequest = (
+            litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
+                input=input, optional_params=optional_params, model=model
+            )
+        )
+
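+        # Reuse a caller-supplied sync HTTP client if given; otherwise build
+        # one, propagating the request timeout.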
+        _client_params = {}
+        if timeout:
+            _client_params["timeout"] = timeout
+        if client is None or not isinstance(client, HTTPHandler):
+            client = _get_httpx_client(params=_client_params)
+        ## LOGGING
+        logging_obj.pre_call(
+            input=vertex_request,
+            api_key="",
+            additional_args={
+                "complete_input_dict": vertex_request,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
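+        # POST the request; HTTP errors keep their upstream status code and
+        # timeouts are surfaced as a 408 VertexAIError.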
+        try:
+            response = client.post(api_base, headers=headers, json=vertex_request)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        _json_response = response.json()
+        ## LOGGING POST-CALL
+        logging_obj.post_call(
+            input=input, api_key=None, original_response=_json_response
+        )
+
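+        # Map the raw Vertex response into litellm's OpenAI-compatible
+        # EmbeddingResponse.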
+        model_response = (
+            litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
+                response=_json_response, model=model, model_response=model_response
+            )
+        )
+
+        return model_response
+
+    async def async_embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        model_response: litellm.EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObject,
+        optional_params: dict,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # distinguishes Vertex AI from Gemini (Google AI Studio)
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: Optional[str] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
+        gemini_api_key: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
+        encoding=None,
+    ) -> litellm.EmbeddingResponse:
+        """
+        Async embedding implementation
+        """
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+        _auth_header, vertex_project = await self._ensure_access_token_async(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider=custom_llm_provider,
+        )
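+        # Build the final auth header and endpoint URL (same helper as the
+        # sync path).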
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=gemini_api_key,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=False,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+            mode="embedding",
+        )
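+        # Merge headers and convert the OpenAI-style input into Vertex's
+        # embedding request shape.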
+        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
+        vertex_request: VertexEmbeddingRequest = (
+            litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
+                input=input, optional_params=optional_params, model=model
+            )
+        )
+
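+        # Reuse a caller-supplied async client if given; otherwise obtain one
+        # from litellm's client factory, propagating the request timeout.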
+        _async_client_params = {}
+        if timeout:
+            _async_client_params["timeout"] = timeout
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            client = get_async_httpx_client(
+                params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
+            )
+        ## LOGGING
+        logging_obj.pre_call(
+            input=vertex_request,
+            api_key="",
+            additional_args={
+                "complete_input_dict": vertex_request,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
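+        # POST the request asynchronously; error mapping matches the sync
+        # path (upstream status codes, 408 on timeout).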
+        try:
+            response = await client.post(api_base, headers=headers, json=vertex_request)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        _json_response = response.json()
+        ## LOGGING POST-CALL
+        logging_obj.post_call(
+            input=input, api_key=None, original_response=_json_response
+        )
+
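+        # Map the raw Vertex response into litellm's OpenAI-compatible
+        # EmbeddingResponse.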
+        model_response = (
+            litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
+                response=_json_response, model=model, model_response=model_response
+            )
+        )
+
+        return model_response