about | summary | refs | log | tree | commit | diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/embed/embedding.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/bedrock/embed/embedding.py')
-rw-r--r-- .venv/lib/python3.12/site-packages/litellm/llms/bedrock/embed/embedding.py 480
1 files changed, 480 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/embed/embedding.py b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/embed/embedding.py
new file mode 100644
index 00000000..9e4e4e22
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/embed/embedding.py
@@ -0,0 +1,480 @@
+"""
+Handles embedding calls to Bedrock's `/invoke` endpoint
+"""
+
+import copy
+import json
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import httpx
+
+import litellm
+from litellm.llms.cohere.embed.handler import embedding as cohere_embedding
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.secret_managers.main import get_secret
+from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
+from litellm.types.utils import EmbeddingResponse
+
+from ..base_aws_llm import BaseAWSLLM
+from ..common_utils import BedrockError
+from .amazon_titan_g1_transformation import AmazonTitanG1Config
+from .amazon_titan_multimodal_transformation import (
+ AmazonTitanMultimodalEmbeddingG1Config,
+)
+from .amazon_titan_v2_transformation import AmazonTitanV2Config
+from .cohere_transformation import BedrockCohereEmbeddingConfig
+
+
class BedrockEmbedding(BaseAWSLLM):
    """
    Embedding handler that SigV4-signs requests and posts them to Bedrock `/invoke`.

    Cohere models are sent as a single request through the shared Cohere
    embedding handler. The Amazon Titan models listed in
    `_TITAN_EMBEDDING_CONFIGS` are sent as one request per input item and the
    collected responses are merged by the matching Titan config class.
    """

    # Titan models served by the per-input path, mapped to the config class
    # that builds their request body and parses their responses.
    _TITAN_EMBEDDING_CONFIGS = {
        "amazon.titan-embed-image-v1": AmazonTitanMultimodalEmbeddingG1Config,
        "amazon.titan-embed-text-v1": AmazonTitanG1Config,
        "amazon.titan-embed-text-v2:0": AmazonTitanV2Config,
    }

    def _load_credentials(
        self,
        optional_params: dict,
    ) -> Tuple[Any, str]:
        """
        Pop the AWS auth settings out of `optional_params` and resolve credentials.

        Args:
            optional_params: Request params; every `aws_*` auth key read here
                is removed from this dict in place.

        Returns:
            A `(credentials, aws_region_name)` tuple, where `credentials` is
            whatever `self.get_credentials` returns.

        Raises:
            ImportError: If botocore cannot be imported.
        """
        try:
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)

        ### SET REGION NAME ###
        if aws_region_name is None:
            # Both secrets are always looked up, in this order; a string value
            # for AWS_REGION therefore overrides one found for AWS_REGION_NAME.
            for secret_name in ("AWS_REGION_NAME", "AWS_REGION"):
                secret_value = get_secret(secret_name, None)
                if isinstance(secret_value, str):
                    aws_region_name = secret_value

            if aws_region_name is None:
                aws_region_name = "us-west-2"

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )
        return credentials, aws_region_name

    async def async_embeddings(self):
        """Placeholder; intentionally does nothing."""
        pass

    @staticmethod
    def _build_client_params(
        timeout: Optional[Union[float, httpx.Timeout]],
    ) -> dict:
        """Return the params dict used to create an HTTP client for `timeout`."""
        params: dict = {}
        if timeout is not None:
            # Plain numbers are wrapped so the client always gets an httpx.Timeout.
            if isinstance(timeout, (float, int)):
                timeout = httpx.Timeout(timeout)
            params["timeout"] = timeout
        return params

    def _sign_request(
        self,
        credentials: Any,
        aws_region_name: str,
        endpoint_url: str,
        data: Any,
        extra_headers: Optional[dict],
    ) -> Any:
        """
        Build a SigV4-signed POST for `data` and return the prepared request.

        The body that gets signed is `json.dumps(data)`. If `extra_headers`
        carries its own "Authorization" value, it is written back after
        signing so that SigV4 does not overwrite it.

        Raises:
            ImportError: If botocore cannot be imported.
        """
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}

        request = AWSRequest(
            method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
        )
        SigV4Auth(credentials, "bedrock", aws_region_name).add_auth(request)
        if (
            extra_headers is not None and "Authorization" in extra_headers
        ):  # prevent sigv4 from overwriting the auth header
            request.headers["Authorization"] = extra_headers["Authorization"]
        return request.prepare()

    def _transform_titan_response(
        self, model: str, responses: List[dict]
    ) -> EmbeddingResponse:
        """
        Merge the per-input `responses` using the Titan config mapped to `model`.

        Raises:
            Exception: If `model` has no mapped config, or the config's
                `_transform_response` returns None.
        """
        returned_response: Optional[EmbeddingResponse] = None
        config_cls = self._TITAN_EMBEDDING_CONFIGS.get(model)
        if config_cls is not None:
            returned_response = config_cls()._transform_response(
                response_list=responses, model=model
            )

        if returned_response is None:
            raise Exception(
                "Unable to map model response to known provider format. model={}".format(
                    model
                )
            )

        return returned_response

    def _make_sync_call(
        self,
        client: Optional[HTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        api_base: str,
        headers: dict,
        data: dict,
    ) -> dict:
        """
        POST `data` as JSON to `api_base` and return the decoded JSON response.

        A new client is created from `timeout` when `client` is not an
        `HTTPHandler`.

        Raises:
            BedrockError: With the response's status code and text on an HTTP
                error status, or with status 408 on a timeout.
        """
        if not isinstance(client, HTTPHandler):
            client = _get_httpx_client(self._build_client_params(timeout))  # type: ignore
        try:
            response = client.post(url=api_base, headers=headers, data=json.dumps(data))  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(
                status_code=error_code, message=err.response.text
            ) from err
        except httpx.TimeoutException as err:
            raise BedrockError(
                status_code=408, message="Timeout error occurred."
            ) from err

        return response.json()

    async def _make_async_call(
        self,
        client: Optional[AsyncHTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        api_base: str,
        headers: dict,
        data: dict,
    ) -> dict:
        """
        Async counterpart of `_make_sync_call`.

        A new async client is created from `timeout` when `client` is not an
        `AsyncHTTPHandler`.

        Raises:
            BedrockError: With the response's status code and text on an HTTP
                error status, or with status 408 on a timeout.
        """
        if not isinstance(client, AsyncHTTPHandler):
            client = get_async_httpx_client(
                params=self._build_client_params(timeout),
                llm_provider=litellm.LlmProviders.BEDROCK,
            )

        try:
            response = await client.post(url=api_base, headers=headers, data=json.dumps(data))  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(
                status_code=error_code, message=err.response.text
            ) from err
        except httpx.TimeoutException as err:
            raise BedrockError(
                status_code=408, message="Timeout error occurred."
            ) from err

        return response.json()

    def _single_func_embeddings(
        self,
        client: Optional[HTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        batch_data: List[dict],
        credentials: Any,
        extra_headers: Optional[dict],
        endpoint_url: str,
        aws_region_name: str,
        model: str,
        logging_obj: Any,
    ):
        """
        Send each item of `batch_data` as its own signed request, synchronously.

        Every request is logged through `logging_obj.pre_call` before it is
        sent and `logging_obj.post_call` after it returns. The responses are
        collected in order and merged by `_transform_titan_response`.

        Raises:
            ImportError: If botocore cannot be imported while signing.
            BedrockError: Propagated from `_make_sync_call`.
            Exception: If `model` cannot be mapped to a response format.
        """
        responses: List[dict] = []
        for data in batch_data:
            prepped = self._sign_request(
                credentials=credentials,
                aws_region_name=aws_region_name,
                endpoint_url=endpoint_url,
                data=data,
                extra_headers=extra_headers,
            )

            ## LOGGING
            logging_obj.pre_call(
                input=data,
                api_key="",
                additional_args={
                    "complete_input_dict": data,
                    "api_base": prepped.url,
                    "headers": prepped.headers,
                },
            )
            response = self._make_sync_call(
                client=client,
                timeout=timeout,
                api_base=prepped.url,
                headers=prepped.headers,  # type: ignore
                data=data,
            )

            ## LOGGING
            logging_obj.post_call(
                input=data,
                api_key="",
                original_response=response,
                additional_args={"complete_input_dict": data},
            )

            responses.append(response)

        ## TRANSFORM RESPONSE ##
        return self._transform_titan_response(model=model, responses=responses)

    async def _async_single_func_embeddings(
        self,
        client: Optional[AsyncHTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        batch_data: List[dict],
        credentials: Any,
        extra_headers: Optional[dict],
        endpoint_url: str,
        aws_region_name: str,
        model: str,
        logging_obj: Any,
    ):
        """
        Async counterpart of `_single_func_embeddings`.

        Requests are awaited one after another, in `batch_data` order, not
        concurrently.

        Raises:
            ImportError: If botocore cannot be imported while signing.
            BedrockError: Propagated from `_make_async_call`.
            Exception: If `model` cannot be mapped to a response format.
        """
        responses: List[dict] = []
        for data in batch_data:
            prepped = self._sign_request(
                credentials=credentials,
                aws_region_name=aws_region_name,
                endpoint_url=endpoint_url,
                data=data,
                extra_headers=extra_headers,
            )

            ## LOGGING
            logging_obj.pre_call(
                input=data,
                api_key="",
                additional_args={
                    "complete_input_dict": data,
                    "api_base": prepped.url,
                    "headers": prepped.headers,
                },
            )
            response = await self._make_async_call(
                client=client,
                timeout=timeout,
                api_base=prepped.url,
                headers=prepped.headers,  # type: ignore
                data=data,
            )

            ## LOGGING
            logging_obj.post_call(
                input=data,
                api_key="",
                original_response=response,
                additional_args={"complete_input_dict": data},
            )

            responses.append(response)

        ## TRANSFORM RESPONSE ##
        return self._transform_titan_response(model=model, responses=responses)

    def embeddings(
        self,
        model: str,
        input: List[str],
        api_base: Optional[str],
        model_response: EmbeddingResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]],
        timeout: Optional[Union[float, httpx.Timeout]],
        aembedding: Optional[bool],
        extra_headers: Optional[dict],
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        """
        Entry point: build, sign and dispatch a Bedrock embedding request.

        The provider is the part of `model` before the first ".". "cohere"
        models are sent as one request via the Cohere embedding handler. The
        Amazon Titan models in `_TITAN_EMBEDDING_CONFIGS` are sent as one
        request per element of `input`.

        `optional_params` is mutated in place: the AWS auth keys, `model_id`
        and `aws_bedrock_runtime_endpoint` are popped from it.

        Returns:
            The embedding response. When `aembedding` is truthy on the Titan
            path this is the un-awaited coroutine from
            `_async_single_func_embeddings`.

        Raises:
            ImportError: If botocore cannot be imported.
            Exception: If the model maps to neither supported path.
        """
        credentials, aws_region_name = self._load_credentials(optional_params)

        ### TRANSFORMATION ###
        # `model_id` only selects the invoke URL. Pop it before copying the
        # params so it is not forwarded as an inference parameter.
        modelId = (
            optional_params.pop("model_id", None) or model
        )  # default to model if not passed
        provider = model.split(".")[0]
        inference_params = {
            k: v
            for k, v in copy.deepcopy(optional_params).items()
            if k.lower() not in self.aws_authentication_params
        }
        inference_params.pop(
            "user", None
        )  # make sure user is not passed in for bedrock call

        data: Optional[CohereEmbeddingRequest] = None
        batch_data: Optional[List] = None
        titan_config_cls = self._TITAN_EMBEDDING_CONFIGS.get(model)
        if provider == "cohere":
            data = BedrockCohereEmbeddingConfig()._transform_request(
                model=model, input=input, inference_params=inference_params
            )
        elif provider == "amazon" and titan_config_cls is not None:
            batch_data = []
            for item in input:
                transformed_request: AmazonEmbeddingRequest = (
                    titan_config_cls()._transform_request(
                        input=item, inference_params=inference_params
                    )
                )
                batch_data.append(transformed_request)

        ### SET RUNTIME ENDPOINT ###
        endpoint_url, _ = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=optional_params.pop(
                "aws_bedrock_runtime_endpoint", None
            ),
            aws_region_name=aws_region_name,
        )
        endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"

        if batch_data is not None:
            if aembedding:
                return self._async_single_func_embeddings(  # type: ignore
                    client=client if isinstance(client, AsyncHTTPHandler) else None,
                    timeout=timeout,
                    batch_data=batch_data,
                    credentials=credentials,
                    extra_headers=extra_headers,
                    endpoint_url=endpoint_url,
                    aws_region_name=aws_region_name,
                    model=model,
                    logging_obj=logging_obj,
                )
            return self._single_func_embeddings(
                client=client if isinstance(client, HTTPHandler) else None,
                timeout=timeout,
                batch_data=batch_data,
                credentials=credentials,
                extra_headers=extra_headers,
                endpoint_url=endpoint_url,
                aws_region_name=aws_region_name,
                model=model,
                logging_obj=logging_obj,
            )

        if data is None:
            raise Exception("Unable to map Bedrock request to provider")

        prepped = self._sign_request(
            credentials=credentials,
            aws_region_name=aws_region_name,
            endpoint_url=endpoint_url,
            data=data,
            extra_headers=extra_headers,
        )

        ## ROUTING ##
        return cohere_embedding(
            model=model,
            input=input,
            model_response=model_response,
            logging_obj=logging_obj,
            optional_params=optional_params,
            encoding=encoding,
            data=data,  # type: ignore
            complete_api_base=prepped.url,
            api_key=None,
            aembedding=aembedding,
            timeout=timeout,
            client=client,
            headers=prepped.headers,  # type: ignore
        )