diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/bedrock/base_aws_llm.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/bedrock/base_aws_llm.py | 627 |
1 files changed, 627 insertions, 0 deletions
# diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/base_aws_llm.py
#            b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/base_aws_llm.py
# new file mode 100644, index 00000000..5482d806
import hashlib
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast, get_args

import httpx
from pydantic import BaseModel

from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache
from litellm.constants import BEDROCK_INVOKE_PROVIDERS_LITERAL
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.secret_managers.main import get_secret

if TYPE_CHECKING:
    from botocore.awsrequest import AWSPreparedRequest
    from botocore.credentials import Credentials
else:
    # botocore is an optional dependency: fall back to `Any` at runtime so this
    # module can be imported without it.
    Credentials = Any
    AWSPreparedRequest = Any


class Boto3CredentialsInfo(BaseModel):
    """Bundle returned by `_get_boto_credentials_from_optional_params`."""

    # Credentials object produced by `BaseAWSLLM.get_credentials`.
    credentials: Credentials
    # Region resolved by `BaseAWSLLM._get_aws_region_name`.
    aws_region_name: str
    # Caller-supplied runtime endpoint override, if any.
    aws_bedrock_runtime_endpoint: Optional[str]


class AwsAuthError(Exception):
    """Authentication failure carrying an HTTP-style status code.

    A synthetic `httpx.Request` / `httpx.Response` pair is attached so the
    exception exposes `.request` and `.response` attributes.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class BaseAWSLLM:
    """Shared AWS helpers: credential resolution, region/endpoint lookup, SigV4 signing."""

    def __init__(self) -> None:
        self.iam_cache = DualCache()
        super().__init__()
        # NOTE: the first nine entries must stay in the same order as
        # `params_to_check` in `get_credentials`, which indexes into this list.
        self.aws_authentication_params = [
            "aws_access_key_id",
            "aws_secret_access_key",
            "aws_session_token",
            "aws_region_name",
            "aws_session_name",
            "aws_profile_name",
            "aws_role_name",
            "aws_web_identity_token",
            "aws_sts_endpoint",
            "aws_bedrock_runtime_endpoint",
        ]

    def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:
        """
        Generate a unique cache key based on the credential arguments.

        The arguments are serialised to JSON with sorted keys and hashed with
        SHA-256, so the key does not depend on dict ordering.
        """
        credential_str = json.dumps(credential_args, sort_keys=True)
        return hashlib.sha256(credential_str.encode()).hexdigest()

    @staticmethod
    def _redact_secret(value: Optional[str]) -> Optional[str]:
        """Return a placeholder for a set secret (or None) so it can be logged safely."""
        return None if value is None else "***REDACTED***"

    @tracer.wrap()
    def get_credentials(
        self,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_session_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        aws_role_name: Optional[str] = None,
        aws_web_identity_token: Optional[str] = None,
        aws_sts_endpoint: Optional[str] = None,
    ) -> Credentials:
        """
        Return a boto3.Credentials object

        Each argument may be a literal value, an ``os.environ/<NAME>`` reference
        (resolved through `get_secret`), or None, in which case the upper-cased
        parameter name is looked up in the process environment.

        Results are cached in `self.iam_cache`, keyed on the resolved arguments.
        """
        ## CHECK IS 'os.environ/' passed in
        params_to_check: List[Optional[str]] = [
            aws_access_key_id,
            aws_secret_access_key,
            aws_session_token,
            aws_region_name,
            aws_session_name,
            aws_profile_name,
            aws_role_name,
            aws_web_identity_token,
            aws_sts_endpoint,
        ]

        # Iterate over parameters and update if needed
        for i, param in enumerate(params_to_check):
            if param and param.startswith("os.environ/"):
                _v = get_secret(param)
                if _v is not None and isinstance(_v, str):
                    params_to_check[i] = _v
            elif param is None:  # check if uppercase value in env
                # Read the same (upper-case) name that is being tested; reading
                # the lower-case name returns None on case-sensitive platforms.
                env_value = os.environ.get(self.aws_authentication_params[i].upper())
                if env_value is not None:
                    params_to_check[i] = env_value

        # Assign updated values back to parameters
        (
            aws_access_key_id,
            aws_secret_access_key,
            aws_session_token,
            aws_region_name,
            aws_session_name,
            aws_profile_name,
            aws_role_name,
            aws_web_identity_token,
            aws_sts_endpoint,
        ) = params_to_check

        # Secrets are redacted so they never reach the logs in plaintext.
        verbose_logger.debug(
            "in get credentials\n"
            "aws_access_key_id=%s\n"
            "aws_secret_access_key=%s\n"
            "aws_session_token=%s\n"
            "aws_region_name=%s\n"
            "aws_session_name=%s\n"
            "aws_profile_name=%s\n"
            "aws_role_name=%s\n"
            "aws_web_identity_token=%s\n"
            "aws_sts_endpoint=%s",
            aws_access_key_id,
            self._redact_secret(aws_secret_access_key),
            self._redact_secret(aws_session_token),
            aws_region_name,
            aws_session_name,
            aws_profile_name,
            aws_role_name,
            aws_web_identity_token,
            aws_sts_endpoint,
        )

        # Build the cache key from an explicit mapping rather than scraping
        # `locals()`, so unrelated locals can never leak into the key.
        args: Dict[str, Optional[str]] = {
            "aws_access_key_id": aws_access_key_id,
            "aws_secret_access_key": aws_secret_access_key,
            "aws_session_token": aws_session_token,
            "aws_region_name": aws_region_name,
            "aws_session_name": aws_session_name,
            "aws_profile_name": aws_profile_name,
            "aws_role_name": aws_role_name,
            "aws_web_identity_token": aws_web_identity_token,
            "aws_sts_endpoint": aws_sts_endpoint,
        }

        cache_key = self.get_cache_key(args)
        _cached_credentials = self.iam_cache.get_cache(cache_key)
        if _cached_credentials:
            return _cached_credentials

        #########################################################
        # Handle diff boto3 auth flows
        # for each helper
        # Return:
        #   Credentials - boto3.Credentials
        #   cache ttl - Optional[int]. If None, no ttl is passed to the cache.
        #               Some auth flows have no expiry time.
        #########################################################
        if (
            aws_web_identity_token is not None
            and aws_role_name is not None
            and aws_session_name is not None
        ):
            credentials, _cache_ttl = self._auth_with_web_identity_token(
                aws_web_identity_token=aws_web_identity_token,
                aws_role_name=aws_role_name,
                aws_session_name=aws_session_name,
                aws_region_name=aws_region_name,
                aws_sts_endpoint=aws_sts_endpoint,
            )
        elif aws_role_name is not None and aws_session_name is not None:
            credentials, _cache_ttl = self._auth_with_aws_role(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_role_name=aws_role_name,
                aws_session_name=aws_session_name,
            )
        elif aws_profile_name is not None:  ### CHECK SESSION ###
            credentials, _cache_ttl = self._auth_with_aws_profile(aws_profile_name)
        elif (
            aws_access_key_id is not None
            and aws_secret_access_key is not None
            and aws_session_token is not None
        ):
            credentials, _cache_ttl = self._auth_with_aws_session_token(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
            )
        elif (
            aws_access_key_id is not None
            and aws_secret_access_key is not None
            and aws_region_name is not None
        ):
            credentials, _cache_ttl = self._auth_with_access_key_and_secret_key(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_region_name=aws_region_name,
            )
        else:
            credentials, _cache_ttl = self._auth_with_env_vars()

        self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
        return credentials

    def _get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
        """
        Extract the region from a Bedrock ARN embedded in `model`.

        Returns None when `model` is not a string, does not contain
        ``arn:aws:bedrock``, has fewer than four ``:``-separated parts, or has
        an empty region field.
        """
        if not isinstance(model, str) or "arn:aws:bedrock" not in model:
            return None

        parts = model.split(":")
        if len(parts) < 4:
            return None

        # The region is the fourth ":"-separated field.
        region = parts[3]
        return region or None

    @staticmethod
    def _get_provider_from_model_path(
        model_path: str,
    ) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
        """
        Helper function to get the provider from a model path with format: provider/model-name

        Args:
            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')

        Returns:
            Optional[str]: The provider name, or None if no valid provider found
        """
        # `str.split` always yields at least one element, so index 0 is safe.
        provider = model_path.split("/")[0]
        if provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
            return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
        return None

    @staticmethod
    def get_bedrock_invoke_provider(
        model: str,
    ) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
        """
        Helper function to get the bedrock provider from the model

        handles 4 scenarios:
        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
        """
        if model.startswith("invoke/"):
            model = model.replace("invoke/", "", 1)

        known_providers = get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL)

        # `<provider>.<model-name>` form
        _split_model = model.split(".")[0]
        if _split_model in known_providers:
            return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)

        # `<provider>/<model-path>` form
        provider = BaseAWSLLM._get_provider_from_model_path(model)
        if provider is not None:
            return provider

        # Fall back to substring matching; "nova" is checked first.
        if "nova" in model:
            return "nova"
        for known_provider in known_providers:
            if known_provider in model:
                return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, known_provider)
        return None

    def _get_aws_region_name(
        self,
        optional_params: dict,
        model: Optional[str] = None,
        model_id: Optional[str] = None,
    ) -> str:
        """
        Resolve the AWS region name.

        Resolution order when ``optional_params["aws_region_name"]`` is unset:
        the region in the model ARN (`model_id` if given, else `model`), then
        the `AWS_REGION_NAME` secret, then the `AWS_REGION` secret, and finally
        the default ``"us-west-2"``.

        Parameters:
            optional_params (dict): Optional parameters for the model call
            model (str): The model name
            model_id (str): The model ID. This is the ARN of the model, if passed in as a separate param.

        Returns:
            str: The AWS region name
        """
        aws_region_name = optional_params.get("aws_region_name", None)
        ### SET REGION NAME ###
        if aws_region_name is None:
            # check model arn #
            if model_id is not None:
                aws_region_name = self._get_aws_region_from_model_arn(model_id)
            else:
                aws_region_name = self._get_aws_region_from_model_arn(model)

            # check env #
            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
            if (
                aws_region_name is None
                and litellm_aws_region_name is not None
                and isinstance(litellm_aws_region_name, str)
            ):
                aws_region_name = litellm_aws_region_name

            standard_aws_region_name = get_secret("AWS_REGION", None)
            if (
                aws_region_name is None
                and standard_aws_region_name is not None
                and isinstance(standard_aws_region_name, str)
            ):
                aws_region_name = standard_aws_region_name

        if aws_region_name is None:
            aws_region_name = "us-west-2"

        return aws_region_name

    @tracer.wrap()
    def _auth_with_web_identity_token(
        self,
        aws_web_identity_token: str,
        aws_role_name: str,
        aws_session_name: str,
        aws_region_name: Optional[str],
        aws_sts_endpoint: Optional[str],
    ) -> Tuple[Credentials, Optional[int]]:
        """
        Authenticate with AWS Web Identity Token

        `aws_web_identity_token` is resolved through `get_secret` to obtain the
        OIDC token, which is exchanged via STS `AssumeRoleWithWebIdentity`.

        Raises:
            AwsAuthError: (401) if the OIDC token cannot be resolved.
        """
        import boto3

        verbose_logger.debug(
            "IN Web Identity Token: %s | Role Name: %s | Session Name: %s",
            aws_web_identity_token,
            aws_role_name,
            aws_session_name,
        )

        if aws_sts_endpoint is not None:
            sts_endpoint = aws_sts_endpoint
        elif aws_region_name:
            sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
        else:
            # No region known: use the global STS endpoint instead of building
            # the invalid host "sts.None.amazonaws.com".
            sts_endpoint = "https://sts.amazonaws.com"

        oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise AwsAuthError(
                message="OIDC token could not be retrieved from secret manager.",
                status_code=401,
            )

        with tracer.trace("boto3.client(sts)"):
            sts_client = boto3.client(
                "sts",
                region_name=aws_region_name,
                endpoint_url=sts_endpoint,
            )

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
            Policy='{"Version":"2012-10-17","Statement":[{"Sid":"BedrockLiteLLM","Effect":"Allow","Action":["bedrock:InvokeModel","bedrock:InvokeModelWithResponseStream"],"Resource":"*","Condition":{"Bool":{"aws:SecureTransport":"true"},"StringLike":{"aws:UserAgent":"litellm/*"}}}]}',
        )

        iam_creds_dict = {
            "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
            "aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
            "aws_session_token": sts_response["Credentials"]["SessionToken"],
            "region_name": aws_region_name,
        }

        if sts_response["PackedPolicySize"] > 75:
            verbose_logger.warning(
                f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
            )

        with tracer.trace("boto3.Session(**iam_creds_dict)"):
            session = boto3.Session(**iam_creds_dict)

        iam_creds = session.get_credentials()
        return iam_creds, self._get_default_ttl_for_boto3_credentials()

    @tracer.wrap()
    def _auth_with_aws_role(
        self,
        aws_access_key_id: Optional[str],
        aws_secret_access_key: Optional[str],
        aws_role_name: str,
        aws_session_name: str,
    ) -> Tuple[Credentials, Optional[int]]:
        """
        Authenticate with AWS Role

        Calls STS `AssumeRole` and returns the temporary credentials together
        with a TTL of "seconds until expiry, minus 60".
        """
        import boto3
        from botocore.credentials import Credentials

        with tracer.trace("boto3.client(sts)"):
            sts_client = boto3.client(
                "sts",
                aws_access_key_id=aws_access_key_id,  # [OPTIONAL]
                aws_secret_access_key=aws_secret_access_key,  # [OPTIONAL]
            )

        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )

        # Extract the credentials from the response and convert to Session Credentials
        sts_credentials = sts_response["Credentials"]
        credentials = Credentials(
            access_key=sts_credentials["AccessKeyId"],
            secret_key=sts_credentials["SecretAccessKey"],
            token=sts_credentials["SessionToken"],
        )

        sts_expiry = sts_credentials["Expiration"]
        # Use the expiry's own tzinfo so both datetimes are comparable.
        current_time = datetime.now(sts_expiry.tzinfo)
        # Truncate to whole seconds to honour the declared `Optional[int]` TTL.
        sts_ttl = int((sts_expiry - current_time).total_seconds()) - 60
        return credentials, sts_ttl

    @tracer.wrap()
    def _auth_with_aws_profile(
        self, aws_profile_name: str
    ) -> Tuple[Credentials, Optional[int]]:
        """
        Authenticate with AWS profile
        """
        import boto3

        # uses auth values from AWS profile usually stored in ~/.aws/credentials
        with tracer.trace("boto3.Session(profile_name=aws_profile_name)"):
            client = boto3.Session(profile_name=aws_profile_name)
            return client.get_credentials(), None

    @tracer.wrap()
    def _auth_with_aws_session_token(
        self,
        aws_access_key_id: str,
        aws_secret_access_key: str,
        aws_session_token: str,
    ) -> Tuple[Credentials, Optional[int]]:
        """
        Authenticate with AWS Session Token
        """
        from botocore.credentials import Credentials

        credentials = Credentials(
            access_key=aws_access_key_id,
            secret_key=aws_secret_access_key,
            token=aws_session_token,
        )

        return credentials, None

    @tracer.wrap()
    def _auth_with_access_key_and_secret_key(
        self,
        aws_access_key_id: str,
        aws_secret_access_key: str,
        aws_region_name: Optional[str],
    ) -> Tuple[Credentials, Optional[int]]:
        """
        Authenticate with AWS Access Key and Secret Key
        """
        import boto3

        # Caching is handled by the caller (`get_credentials`); this helper only
        # builds the session and reports the default TTL.
        with tracer.trace(
            "boto3.Session(aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, region_name=aws_region_name)"
        ):
            session = boto3.Session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                region_name=aws_region_name,
            )

        credentials = session.get_credentials()
        return credentials, self._get_default_ttl_for_boto3_credentials()

    @tracer.wrap()
    def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]:
        """
        Authenticate with AWS Environment Variables

        Builds a default `boto3.Session()` and returns whatever credentials it
        resolves.
        """
        import boto3

        with tracer.trace("boto3.Session()"):
            session = boto3.Session()
            credentials = session.get_credentials()
            return credentials, None

    @tracer.wrap()
    def _get_default_ttl_for_boto3_credentials(self) -> int:
        """
        Get the default TTL for boto3 credentials

        Returns `3600-60` which is 59 minutes
        """
        return 3600 - 60

    def get_runtime_endpoint(
        self,
        api_base: Optional[str],
        aws_bedrock_runtime_endpoint: Optional[str],
        aws_region_name: str,
    ) -> Tuple[str, str]:
        """
        Resolve the Bedrock runtime endpoint.

        Returns:
            (endpoint_url, proxy_endpoint_url)

            endpoint_url precedence: `api_base`, then the
            `aws_bedrock_runtime_endpoint` argument, then the
            `AWS_BEDROCK_RUNTIME_ENDPOINT` secret, then the regional default.

            proxy_endpoint_url precedence: the `AWS_BEDROCK_RUNTIME_ENDPOINT`
            secret, then the `aws_bedrock_runtime_endpoint` argument, then
            `endpoint_url`.
        """
        env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
        if api_base is not None:
            endpoint_url = api_base
        elif aws_bedrock_runtime_endpoint is not None and isinstance(
            aws_bedrock_runtime_endpoint, str
        ):
            endpoint_url = aws_bedrock_runtime_endpoint
        elif env_aws_bedrock_runtime_endpoint and isinstance(
            env_aws_bedrock_runtime_endpoint, str
        ):
            endpoint_url = env_aws_bedrock_runtime_endpoint
        else:
            endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

        # Determine proxy_endpoint_url
        if env_aws_bedrock_runtime_endpoint and isinstance(
            env_aws_bedrock_runtime_endpoint, str
        ):
            proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
        elif aws_bedrock_runtime_endpoint is not None and isinstance(
            aws_bedrock_runtime_endpoint, str
        ):
            proxy_endpoint_url = aws_bedrock_runtime_endpoint
        else:
            proxy_endpoint_url = endpoint_url

        return endpoint_url, proxy_endpoint_url

    def _get_boto_credentials_from_optional_params(
        self, optional_params: dict, model: Optional[str] = None
    ) -> Boto3CredentialsInfo:
        """
        Get boto3 credentials from optional params

        NOTE: all `aws_*` authentication keys are popped from `optional_params`
        (the dict is mutated in place).

        Args:
            optional_params (dict): Optional parameters for the model call
            model (Optional[str]): Model name, used to infer the region from an ARN

        Returns:
            Boto3CredentialsInfo: credentials, resolved region and endpoint override

        Raises:
            ImportError: if botocore cannot be imported.
        """
        try:
            from botocore.credentials import Credentials
        except ImportError as e:
            raise ImportError(
                "Missing boto3 to call bedrock. Run 'pip install boto3'."
            ) from e
        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        # The region must be resolved before `aws_region_name` is popped.
        aws_region_name = self._get_aws_region_name(optional_params, model)
        optional_params.pop("aws_region_name", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
        aws_bedrock_runtime_endpoint = optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )

        return Boto3CredentialsInfo(
            credentials=credentials,
            aws_region_name=aws_region_name,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        )

    @tracer.wrap()
    def get_request_headers(
        self,
        credentials: Credentials,
        aws_region_name: str,
        extra_headers: Optional[dict],
        endpoint_url: str,
        data: str,
        headers: dict,
    ) -> AWSPreparedRequest:
        """
        Build and SigV4-sign a POST request for the "bedrock" service.

        If `extra_headers` contains an ``Authorization`` entry, it replaces the
        SigV4-generated header on the signed request.

        Returns:
            AWSPreparedRequest: the prepared, signed request.

        Raises:
            ImportError: if botocore cannot be imported.
        """
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError as e:
            raise ImportError(
                "Missing boto3 to call bedrock. Run 'pip install boto3'."
            ) from e

        sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)

        request = AWSRequest(
            method="POST", url=endpoint_url, data=data, headers=headers
        )
        sigv4.add_auth(request)
        if (
            extra_headers is not None and "Authorization" in extra_headers
        ):  # prevent sigv4 from overwriting the auth header
            request.headers["Authorization"] = extra_headers["Authorization"]
        prepped = request.prepare()

        return prepped