diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/bedrock/common_utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/bedrock/common_utils.py | 407 |
1 files changed, 407 insertions, 0 deletions
"""
Common utilities used across bedrock chat/embedding/image generation
"""

import os
from typing import List, Literal, Optional, Union

import httpx

import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret


class BedrockError(BaseLLMException):
    """Exception raised by the Bedrock helpers in this module."""


class AmazonBedrockGlobalConfig:
    """Static lookup tables for Bedrock auth-parameter names and AWS regions."""

    def __init__(self):
        pass

    def get_mapped_special_auth_params(self) -> dict:
        """
        Mapping of common auth params across bedrock/vertex/azure/watsonx
        """
        return {"region_name": "aws_region_name"}

    def map_special_auth_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        """
        Copy recognised auth params into ``optional_params`` under their
        Bedrock-specific names.

        ``optional_params`` is mutated in place and also returned.
        """
        mapped_params = self.get_mapped_special_auth_params()
        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params

    def get_all_regions(self) -> List[str]:
        """Return the US, EU, AP, CA and SA region lists concatenated in that order."""
        return (
            self.get_us_regions()
            + self.get_eu_regions()
            + self.get_ap_regions()
            + self.get_ca_regions()
            + self.get_sa_regions()
        )

    def get_ap_regions(self) -> List[str]:
        return ["ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-south-1"]

    def get_sa_regions(self) -> List[str]:
        return ["sa-east-1"]

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        return [
            "eu-west-1",
            "eu-west-2",
            "eu-west-3",
            "eu-central-1",
        ]

    def get_ca_regions(self) -> List[str]:
        return ["ca-central-1"]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        return [
            "us-east-2",
            "us-east-1",
            "us-west-1",
            "us-west-2",
            "us-gov-west-1",
        ]


def add_custom_header(headers):
    """Closure to capture the headers and add them."""

    def callback(request, **kwargs):
        """Actual callback function that Boto3 will call."""
        for header_name, header_value in headers.items():
            request.headers.add_header(header_name, header_value)

    return callback


def _resolve_secret_reference(value):
    """Resolve an ``os.environ/<NAME>`` reference via ``get_secret``; return anything else unchanged."""
    if value and value.startswith("os.environ/"):
        return get_secret(value)
    return value


def _normalize_ssl_verify(ssl_verify):
    """
    Convert the strings "true"/"false" (any case) into booleans.

    Environment variables are always strings, and boto3 interprets a string
    ``verify`` value as a path to a CA bundle, so ``SSL_VERIFY=False`` must be
    turned into the boolean ``False`` before it is handed to boto3. Any other
    value (a real boolean, or a CA bundle path) is returned unchanged.
    """
    if isinstance(ssl_verify, str):
        lowered = ssl_verify.strip().lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
    return ssl_verify


def _credentials_from_sts_response(sts_response) -> dict:
    """Map the ``Credentials`` entry of an STS response to boto3 client keyword arguments."""
    credentials = sts_response["Credentials"]
    return {
        "aws_access_key_id": credentials["AccessKeyId"],
        "aws_secret_access_key": credentials["SecretAccessKey"],
        "aws_session_token": credentials["SessionToken"],
    }


def init_bedrock_client(
    region_name=None,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_bedrock_runtime_endpoint: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):
    """
    Build a boto3 ``bedrock-runtime`` client.

    Region precedence: ``region_name``, ``aws_region_name``, the
    ``AWS_REGION_NAME`` secret, then the ``AWS_REGION`` secret.

    Endpoint precedence: ``aws_bedrock_runtime_endpoint``, the
    ``AWS_BEDROCK_RUNTIME_ENDPOINT`` secret, then
    ``https://bedrock-runtime.<region>.amazonaws.com``.

    Credential precedence:
      1. web identity token + role name + session name
         (STS ``assume_role_with_web_identity``)
      2. role name + session name (STS ``assume_role``)
      3. explicit access key id / secret access key
      4. named AWS profile
      5. no explicit credentials passed to boto3

    Any ``aws_*`` string argument of the form ``os.environ/<NAME>`` is
    resolved through ``get_secret`` first.

    Args:
        extra_headers: if truthy, each header is added to every request via a
            ``before-sign.bedrock-runtime.*`` event handler.
        timeout: a number (used as both connect and read timeout) or an
            ``httpx.Timeout`` (its ``connect`` / ``read`` values are used).

    Returns:
        The configured boto3 client.

    Raises:
        BedrockError: if no region can be determined, or if the OIDC token
            lookup returns ``None``.
    """
    # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)

    ## CHECK IS 'os.environ/' passed in
    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = [
        _resolve_secret_reference(param)  # type: ignore
        for param in (
            aws_access_key_id,
            aws_secret_access_key,
            aws_region_name,
            aws_bedrock_runtime_endpoint,
            aws_session_name,
            aws_profile_name,
            aws_role_name,
            aws_web_identity_token,
        )
    ]

    # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
    # Normalised so that SSL_VERIFY="False" disables verification instead of
    # being treated as a CA bundle path.
    ssl_verify = _normalize_ssl_verify(os.getenv("SSL_VERIFY", litellm.ssl_verify))

    ### SET REGION NAME
    region_name = (
        region_name
        or aws_region_name
        or litellm_aws_region_name
        or standard_aws_region_name
    )
    if not region_name:
        raise BedrockError(
            message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
            status_code=401,
        )

    # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
    env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
    endpoint_url = (
        aws_bedrock_runtime_endpoint
        or env_aws_bedrock_runtime_endpoint
        or f"https://bedrock-runtime.{region_name}.amazonaws.com"
    )

    import boto3

    if isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(  # type: ignore
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    elif isinstance(timeout, (int, float)) and not isinstance(timeout, bool):
        # ints are accepted too: previously `timeout=600` was silently ignored
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)  # type: ignore
    else:
        config = boto3.session.Config()  # type: ignore

    # Keyword arguments shared by every way of constructing the runtime client.
    client_kwargs = {
        "service_name": "bedrock-runtime",
        "region_name": region_name,
        "endpoint_url": endpoint_url,
        "config": config,
        "verify": ssl_verify,
    }

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise BedrockError(
                message="OIDC token could not be retrieved from secret manager.",
                status_code=401,
            )

        sts_client = boto3.client("sts")

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )

        client = boto3.client(
            **_credentials_from_sts_response(sts_response), **client_kwargs
        )
    elif aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )

        client = boto3.client(
            **_credentials_from_sts_response(sts_response), **client_kwargs
        )
    elif aws_access_key_id is not None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
        client = boto3.client(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            **client_kwargs,
        )
    elif aws_profile_name is not None:
        # uses auth values from AWS profile usually stored in ~/.aws/credentials
        client = boto3.Session(profile_name=aws_profile_name).client(**client_kwargs)
    else:
        # aws_access_key_id is None, assume user is trying to auth using env variables
        # boto3 automatically reads env variables
        client = boto3.client(**client_kwargs)

    if extra_headers:
        client.meta.events.register(
            "before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
        )

    return client


class ModelResponseIterator:
    """Sync and async iterator that yields a single ``model_response`` exactly once."""

    def __init__(self, model_response):
        self.model_response = model_response
        self.is_done = False

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self.model_response

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_done:
            raise StopAsyncIteration
        self.is_done = True
        return self.model_response


def get_bedrock_tool_name(response_tool_name: str) -> str:
    """
    If litellm formatted the input tool name, we need to convert it back to the original name.

    Args:
        response_tool_name (str): The name of the tool as received from the response.

    Returns:
        str: The mapped name from ``litellm.bedrock_tool_name_mappings`` if one
        exists, otherwise ``response_tool_name`` unchanged.
    """
    return litellm.bedrock_tool_name_mappings.cache_dict.get(
        response_tool_name, response_tool_name
    )


class BedrockModelInfo(BaseLLMModelInfo):
    """Static helpers for normalising Bedrock model names and choosing a route."""

    global_config = AmazonBedrockGlobalConfig()
    # Computed once at class-creation time.
    all_global_regions = global_config.get_all_regions()

    @staticmethod
    def extract_model_name_from_arn(model: str) -> str:
        """
        Extract the model name from an AWS Bedrock ARN.
        Returns the string after the last '/' if 'arn' is in the input string.

        Args:
            model (str): The model string (possibly an ARN) to parse

        Returns:
            str: The extracted model name if 'arn' is in the string,
                 otherwise returns the original string
        """
        # NOTE: this is a case-insensitive substring check, not a prefix check.
        if "arn" in model.lower():
            return model.split("/")[-1]
        return model

    @staticmethod
    def get_non_litellm_routing_model_name(model: str) -> str:
        """
        Strip a leading ``bedrock/``, then a leading ``converse/``, then a
        leading ``invoke/`` from ``model`` (each at most once, in that order).
        """
        for prefix in ("bedrock/", "converse/", "invoke/"):
            if model.startswith(prefix):
                model = model[len(prefix) :]
        return model

    @staticmethod
    def get_base_model(model: str) -> str:
        """
        Get the base model from the given model name.

        Routing prefixes and ARN paths are removed first, then a leading
        cross-region abbreviation or region path segment is dropped. Any
        version suffix is kept as-is.

        Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1:0"
        AND "us-west-2/meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1:0"
        """
        model = BedrockModelInfo.get_non_litellm_routing_model_name(model=model)
        model = BedrockModelInfo.extract_model_name_from_arn(model)

        # `partition` never raises, so a bare "us" (no dot) falls through
        # instead of failing with IndexError.
        potential_region, dot, after_dot = model.partition(".")
        if (
            dot
            and potential_region
            in BedrockModelInfo._supported_cross_region_inference_region()
        ):
            return after_dot

        # in model cost map we store regional information like `/us-west-2/bedrock-model`
        alt_potential_region, slash, after_slash = model.partition("/")
        if slash and alt_potential_region in BedrockModelInfo.all_global_regions:
            return after_slash

        return model

    @staticmethod
    def _supported_cross_region_inference_region() -> List[str]:
        """
        Abbreviations of regions AWS Bedrock supports for cross region inference
        """
        return ["us", "eu", "apac"]

    @staticmethod
    def get_bedrock_route(model: str) -> Literal["converse", "invoke", "converse_like"]:
        """
        Get the bedrock route for the given model.

        Explicit markers in the model string win, checked in the order
        ``invoke/``, ``converse_like``, ``converse/``. Otherwise the route is
        "converse" when the base model or the prefix-stripped model is listed
        in ``litellm.bedrock_converse_models``, and "invoke" when it is not.
        """
        if "invoke/" in model:
            return "invoke"
        if "converse_like" in model:
            return "converse_like"
        if "converse/" in model:
            return "converse"

        base_model = BedrockModelInfo.get_base_model(model)
        alt_model = BedrockModelInfo.get_non_litellm_routing_model_name(model=model)
        if (
            base_model in litellm.bedrock_converse_models
            or alt_model in litellm.bedrock_converse_models
        ):
            return "converse"
        return "invoke"