aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/auth/rds_iam_token.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/auth/rds_iam_token.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/auth/rds_iam_token.py187
1 files changed, 187 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/rds_iam_token.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/rds_iam_token.py
new file mode 100644
index 00000000..053cdb91
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/rds_iam_token.py
@@ -0,0 +1,187 @@
+import os
+from typing import Any, Optional, Union
+
+import httpx
+
+
+def init_rds_client(
+ aws_access_key_id: Optional[str] = None,
+ aws_secret_access_key: Optional[str] = None,
+ aws_region_name: Optional[str] = None,
+ aws_session_name: Optional[str] = None,
+ aws_profile_name: Optional[str] = None,
+ aws_role_name: Optional[str] = None,
+ aws_web_identity_token: Optional[str] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+):
+ from litellm.secret_managers.main import get_secret
+
+ # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
+ litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
+ standard_aws_region_name = get_secret("AWS_REGION", None)
+ ## CHECK IS 'os.environ/' passed in
+ # Define the list of parameters to check
+ params_to_check = [
+ aws_access_key_id,
+ aws_secret_access_key,
+ aws_region_name,
+ aws_session_name,
+ aws_profile_name,
+ aws_role_name,
+ aws_web_identity_token,
+ ]
+
+ # Iterate over parameters and update if needed
+ for i, param in enumerate(params_to_check):
+ if param and param.startswith("os.environ/"):
+ params_to_check[i] = get_secret(param) # type: ignore
+ # Assign updated values back to parameters
+ (
+ aws_access_key_id,
+ aws_secret_access_key,
+ aws_region_name,
+ aws_session_name,
+ aws_profile_name,
+ aws_role_name,
+ aws_web_identity_token,
+ ) = params_to_check
+
+ ### SET REGION NAME
+ region_name = aws_region_name
+ if aws_region_name:
+ region_name = aws_region_name
+ elif litellm_aws_region_name:
+ region_name = litellm_aws_region_name
+ elif standard_aws_region_name:
+ region_name = standard_aws_region_name
+ else:
+ raise Exception(
+ "AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
+ )
+
+ import boto3
+
+ if isinstance(timeout, float):
+ config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout) # type: ignore
+ elif isinstance(timeout, httpx.Timeout):
+ config = boto3.session.Config( # type: ignore
+ connect_timeout=timeout.connect, read_timeout=timeout.read
+ )
+ else:
+ config = boto3.session.Config() # type: ignore
+
+ ### CHECK STS ###
+ if (
+ aws_web_identity_token is not None
+ and aws_role_name is not None
+ and aws_session_name is not None
+ ):
+ try:
+ oidc_token = open(aws_web_identity_token).read() # check if filepath
+ except Exception:
+ oidc_token = get_secret(aws_web_identity_token)
+
+ if oidc_token is None:
+ raise Exception(
+ "OIDC token could not be retrieved from secret manager.",
+ )
+
+ sts_client = boto3.client("sts")
+
+ # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
+ sts_response = sts_client.assume_role_with_web_identity(
+ RoleArn=aws_role_name,
+ RoleSessionName=aws_session_name,
+ WebIdentityToken=oidc_token,
+ DurationSeconds=3600,
+ )
+
+ client = boto3.client(
+ service_name="rds",
+ aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
+ aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
+ aws_session_token=sts_response["Credentials"]["SessionToken"],
+ region_name=region_name,
+ config=config,
+ )
+
+ elif aws_role_name is not None and aws_session_name is not None:
+ # use sts if role name passed in
+ sts_client = boto3.client(
+ "sts",
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ )
+
+ sts_response = sts_client.assume_role(
+ RoleArn=aws_role_name, RoleSessionName=aws_session_name
+ )
+
+ client = boto3.client(
+ service_name="rds",
+ aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
+ aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
+ aws_session_token=sts_response["Credentials"]["SessionToken"],
+ region_name=region_name,
+ config=config,
+ )
+ elif aws_access_key_id is not None:
+ # uses auth params passed to completion
+ # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
+
+ client = boto3.client(
+ service_name="rds",
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name=region_name,
+ config=config,
+ )
+ elif aws_profile_name is not None:
+ # uses auth values from AWS profile usually stored in ~/.aws/credentials
+
+ client = boto3.Session(profile_name=aws_profile_name).client(
+ service_name="rds",
+ region_name=region_name,
+ config=config,
+ )
+
+ else:
+ # aws_access_key_id is None, assume user is trying to auth using env variables
+ # boto3 automatically reads env variables
+
+ client = boto3.client(
+ service_name="rds",
+ region_name=region_name,
+ config=config,
+ )
+
+ return client
+
+
+def generate_iam_auth_token(
+ db_host, db_port, db_user, client: Optional[Any] = None
+) -> str:
+ from urllib.parse import quote
+
+ if client is None:
+ boto_client = init_rds_client(
+ aws_region_name=os.getenv("AWS_REGION_NAME"),
+ aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
+ aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
+ aws_session_name=os.getenv("AWS_SESSION_NAME"),
+ aws_profile_name=os.getenv("AWS_PROFILE_NAME"),
+ aws_role_name=os.getenv("AWS_ROLE_NAME", os.getenv("AWS_ROLE_ARN")),
+ aws_web_identity_token=os.getenv(
+ "AWS_WEB_IDENTITY_TOKEN", os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
+ ),
+ )
+ else:
+ boto_client = client
+
+ token = boto_client.generate_db_auth_token(
+ DBHostname=db_host, Port=db_port, DBUsername=db_user
+ )
+ cleaned_token = quote(token, safe="")
+
+ return cleaned_token