Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py  254
1 file changed, 254 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py
new file mode 100644
index 00000000..3ef40703
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py
@@ -0,0 +1,254 @@
+from typing import Literal, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.vertex_ai_non_gemini import VertexAIError
+from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+from litellm.types.llms.vertex_ai import *
+from litellm.types.utils import EmbeddingResponse
+
+from .types import *
+
+
+class VertexEmbedding(VertexBase):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def embedding(
+ self,
+ model: str,
+ input: Union[list, str],
+ print_verbose,
+ model_response: EmbeddingResponse,
+ optional_params: dict,
+ logging_obj: LiteLLMLoggingObject,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # "gemini" targets Google AI Studio; the vertex_* values target Vertex AI
+ timeout: Optional[Union[float, httpx.Timeout]],
+ api_key: Optional[str] = None,
+ encoding=None,
+ aembedding=False,
+ api_base: Optional[str] = None,
+ client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+ vertex_project: Optional[str] = None,
+ vertex_location: Optional[str] = None,
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
+ gemini_api_key: Optional[str] = None,
+ extra_headers: Optional[dict] = None,
+ ) -> EmbeddingResponse:
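+        """
+        Sync Vertex AI embedding entrypoint; dispatches to async_embedding when aembedding=True.
+        """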
+ if aembedding is True:
+ return self.async_embedding( # type: ignore
+ model=model,
+ input=input,
+ logging_obj=logging_obj,
+ model_response=model_response,
+ optional_params=optional_params,
+ encoding=encoding,
+ custom_llm_provider=custom_llm_provider,
+ timeout=timeout,
+ api_base=api_base,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ gemini_api_key=gemini_api_key,
+ extra_headers=extra_headers,
+ )
+
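+        # Decide whether the optional params require the v1beta1 API surface.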
+ should_use_v1beta1_features = self.is_using_v1beta1_features(
+ optional_params=optional_params
+ )
+
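+        # Resolve the GCP project and an access token, then build the request URL.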
+ _auth_header, vertex_project = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+ auth_header, api_base = self._get_token_and_url(
+ model=model,
+ gemini_api_key=gemini_api_key,
+ auth_header=_auth_header,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=False,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=should_use_v1beta1_features,
+ mode="embedding",
+ )
+ headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
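+        # Translate the OpenAI-style embedding request into Vertex AI's request body.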
+ vertex_request: VertexEmbeddingRequest = (
+ litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
+ input=input, optional_params=optional_params, model=model
+ )
+ )
+
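+        # Reuse a caller-supplied sync HTTP client when possible; otherwise create one honoring the timeout.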
+ _client_params = {}
+ if timeout:
+ _client_params["timeout"] = timeout
+        if client is None or not isinstance(client, HTTPHandler):
+            client = _get_httpx_client(params=_client_params)
+ ## LOGGING
+ logging_obj.pre_call(
+ input=vertex_request,
+ api_key="",
+ additional_args={
+ "complete_input_dict": vertex_request,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
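+        # Normalize HTTP and timeout failures into VertexAIError.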
+ try:
+ response = client.post(api_base, headers=headers, json=vertex_request) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+ _json_response = response.json()
+ ## LOGGING POST-CALL
+ logging_obj.post_call(
+ input=input, api_key=None, original_response=_json_response
+ )
+
+ model_response = (
+ litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
+ response=_json_response, model=model, model_response=model_response
+ )
+ )
+
+ return model_response
+
+ async def async_embedding(
+ self,
+ model: str,
+ input: Union[list, str],
+ model_response: litellm.EmbeddingResponse,
+ logging_obj: LiteLLMLoggingObject,
+ optional_params: dict,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # "gemini" targets Google AI Studio; the vertex_* values target Vertex AI
+ timeout: Optional[Union[float, httpx.Timeout]],
+ api_base: Optional[str] = None,
+ client: Optional[AsyncHTTPHandler] = None,
+ vertex_project: Optional[str] = None,
+ vertex_location: Optional[str] = None,
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
+ gemini_api_key: Optional[str] = None,
+ extra_headers: Optional[dict] = None,
+ encoding=None,
+ ) -> litellm.EmbeddingResponse:
+ """
+        Async Vertex AI embedding; mirrors embedding() but uses an async HTTP client and async token refresh.
+ """
+ should_use_v1beta1_features = self.is_using_v1beta1_features(
+ optional_params=optional_params
+ )
+ _auth_header, vertex_project = await self._ensure_access_token_async(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+ auth_header, api_base = self._get_token_and_url(
+ model=model,
+ gemini_api_key=gemini_api_key,
+ auth_header=_auth_header,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=False,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=should_use_v1beta1_features,
+ mode="embedding",
+ )
+ headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
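+        # Translate the OpenAI-style embedding request into Vertex AI's request body.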
+ vertex_request: VertexEmbeddingRequest = (
+ litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
+ input=input, optional_params=optional_params, model=model
+ )
+ )
+
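+        # Reuse a caller-supplied async HTTP client when possible; otherwise create one honoring the timeout.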
+ _async_client_params = {}
+ if timeout:
+ _async_client_params["timeout"] = timeout
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            client = get_async_httpx_client(
+                params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
+            )
+ ## LOGGING
+ logging_obj.pre_call(
+ input=vertex_request,
+ api_key="",
+ additional_args={
+ "complete_input_dict": vertex_request,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
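+        # Normalize HTTP and timeout failures into VertexAIError.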
+ try:
+ response = await client.post(api_base, headers=headers, json=vertex_request) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+ _json_response = response.json()
+ ## LOGGING POST-CALL
+ logging_obj.post_call(
+ input=input, api_key=None, original_response=_json_response
+ )
+
+ model_response = (
+ litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
+ response=_json_response, model=model, model_response=model_response
+ )
+ )
+
+ return model_response
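+
+
+# Illustrative usage sketch (not part of the library's public API; callers normally
+# go through litellm.embedding(), which constructs the logging object and the
+# EmbeddingResponse shown here; the project id and model are placeholders):
+#
+#   handler = VertexEmbedding()
+#   response = handler.embedding(
+#       model="text-embedding-004",
+#       input=["hello world"],
+#       print_verbose=print,
+#       model_response=EmbeddingResponse(),
+#       optional_params={},
+#       logging_obj=logging_obj,
+#       custom_llm_provider="vertex_ai",
+#       timeout=600.0,
+#       vertex_project="my-gcp-project",
+#       vertex_location="us-central1",
+#   )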