diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/secret_managers/main.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/secret_managers/main.py | 354 |
1 files changed, 354 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/secret_managers/main.py b/.venv/lib/python3.12/site-packages/litellm/secret_managers/main.py new file mode 100644 index 00000000..e505484b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/secret_managers/main.py @@ -0,0 +1,354 @@ +import ast +import base64 +import binascii +import os +import traceback +from typing import Any, Optional, Union + +import httpx + +import litellm +from litellm._logging import print_verbose, verbose_logger +from litellm.caching.caching import DualCache +from litellm.llms.custom_httpx.http_handler import HTTPHandler +from litellm.proxy._types import KeyManagementSystem + +oidc_cache = DualCache() + + +######### Secret Manager ############################ +# checks if user has passed in a secret manager client +# if passed in then checks the secret there +def _is_base64(s): + try: + return base64.b64encode(base64.b64decode(s)).decode() == s + except binascii.Error: + return False + + +def str_to_bool(value: Optional[str]) -> Optional[bool]: + """ + Converts a string to a boolean if it's a recognized boolean string. + Returns None if the string is not a recognized boolean value. + + :param value: The string to be checked. + :return: True or False if the string is a recognized boolean, otherwise None. + """ + if value is None: + return None + + true_values = {"true"} + false_values = {"false"} + + value_lower = value.strip().lower() + + if value_lower in true_values: + return True + elif value_lower in false_values: + return False + else: + return None + + +def get_secret_str( + secret_name: str, + default_value: Optional[Union[str, bool]] = None, +) -> Optional[str]: + """ + Guarantees response from 'get_secret' is either string or none. Used for fixing linting errors. + """ + value = get_secret(secret_name=secret_name, default_value=default_value) + if value is not None and not isinstance(value, str): + return None + + return value + + +def get_secret_bool( + secret_name: str, + default_value: Optional[bool] = None, +) -> Optional[bool]: + """ + Guarantees response from 'get_secret' is either boolean or none. Used for fixing linting errors. + + Args: + secret_name: The name of the secret to get. + default_value: The default value to return if the secret is not found. + + Returns: + The secret value as a boolean or None if the secret is not found. + """ + _secret_value = get_secret(secret_name, default_value) + if _secret_value is None: + return None + elif isinstance(_secret_value, bool): + return _secret_value + else: + return str_to_bool(_secret_value) + + +def get_secret( # noqa: PLR0915 + secret_name: str, + default_value: Optional[Union[str, bool]] = None, +): + key_management_system = litellm._key_management_system + key_management_settings = litellm._key_management_settings + secret = None + + if secret_name.startswith("os.environ/"): + secret_name = secret_name.replace("os.environ/", "") + + # Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke + if secret_name.startswith("oidc/"): + secret_name_split = secret_name.replace("oidc/", "") + oidc_provider, oidc_aud = secret_name_split.split("/", 1) + # TODO: Add caching for HTTP requests + if oidc_provider == "google": + oidc_token = oidc_cache.get_cache(key=secret_name) + if oidc_token is not None: + return oidc_token + + oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature + response = oidc_client.get( + "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity", + params={"audience": oidc_aud}, + headers={"Metadata-Flavor": "Google"}, + ) + if response.status_code == 200: + oidc_token = response.text + oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60) + return oidc_token + else: + raise ValueError("Google OIDC provider failed") + elif oidc_provider == "circleci": + # https://circleci.com/docs/openid-connect-tokens/ + env_secret = os.getenv("CIRCLE_OIDC_TOKEN") + if env_secret is None: + raise ValueError("CIRCLE_OIDC_TOKEN not found in environment") + return env_secret + elif oidc_provider == "circleci_v2": + # https://circleci.com/docs/openid-connect-tokens/ + env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2") + if env_secret is None: + raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment") + return env_secret + elif oidc_provider == "github": + # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions + actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL") + actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") + if ( + actions_id_token_request_url is None + or actions_id_token_request_token is None + ): + raise ValueError( + "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment" + ) + + oidc_token = oidc_cache.get_cache(key=secret_name) + if oidc_token is not None: + return oidc_token + + oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + response = oidc_client.get( + actions_id_token_request_url, + params={"audience": oidc_aud}, + headers={ + "Authorization": f"Bearer {actions_id_token_request_token}", + "Accept": "application/json; api-version=2.0", + }, + ) + if response.status_code == 200: + oidc_token = response.json().get("value", None) + oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5) + return oidc_token + else: + raise ValueError("Github OIDC provider failed") + elif oidc_provider == "azure": + # https://azure.github.io/azure-workload-identity/docs/quick-start.html + azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE") + if azure_federated_token_file is None: + raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment") + with open(azure_federated_token_file, "r") as f: + oidc_token = f.read() + return oidc_token + elif oidc_provider == "file": + # Load token from a file + with open(oidc_aud, "r") as f: + oidc_token = f.read() + return oidc_token + elif oidc_provider == "env": + # Load token directly from an environment variable + oidc_token = os.getenv(oidc_aud) + if oidc_token is None: + raise ValueError(f"Environment variable {oidc_aud} not found") + return oidc_token + elif oidc_provider == "env_path": + # Load token from a file path specified in an environment variable + token_file_path = os.getenv(oidc_aud) + if token_file_path is None: + raise ValueError(f"Environment variable {oidc_aud} not found") + with open(token_file_path, "r") as f: + oidc_token = f.read() + return oidc_token + else: + raise ValueError("Unsupported OIDC provider") + + try: + if ( + _should_read_secret_from_secret_manager() + and litellm.secret_manager_client is not None + ): + try: + client = litellm.secret_manager_client + key_manager = "local" + if key_management_system is not None: + key_manager = key_management_system.value + + if key_management_settings is not None: + if ( + key_management_settings.hosted_keys is not None + and secret_name not in key_management_settings.hosted_keys + ): # allow user to specify which keys to check in hosted key manager + key_manager = "local" + + if ( + key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value + or type(client).__module__ + "." + type(client).__name__ + == "azure.keyvault.secrets._client.SecretClient" + ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient + secret = client.get_secret(secret_name).value + elif ( + key_manager == KeyManagementSystem.GOOGLE_KMS.value + or client.__class__.__name__ == "KeyManagementServiceClient" + ): + encrypted_secret: Any = os.getenv(secret_name) + if encrypted_secret is None: + raise ValueError( + "Google KMS requires the encrypted secret to be in the environment!" + ) + b64_flag = _is_base64(encrypted_secret) + if b64_flag is True: # if passed in as encoded b64 string + encrypted_secret = base64.b64decode(encrypted_secret) + ciphertext = encrypted_secret + else: + raise ValueError( + "Google KMS requires the encrypted secret to be encoded in base64" + ) # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce + response = client.decrypt( + request={ + "name": litellm._google_kms_resource_name, + "ciphertext": ciphertext, + } + ) + secret = response.plaintext.decode( + "utf-8" + ) # assumes the original value was encoded with utf-8 + elif key_manager == KeyManagementSystem.AWS_KMS.value: + """ + Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys. + """ + encrypted_value = os.getenv(secret_name, None) + if encrypted_value is None: + raise Exception( + "AWS KMS - Encrypted Value of Key={} is None".format( + secret_name + ) + ) + # Decode the base64 encoded ciphertext + ciphertext_blob = base64.b64decode(encrypted_value) + + # Set up the parameters for the decrypt call + params = {"CiphertextBlob": ciphertext_blob} + # Perform the decryption + response = client.decrypt(**params) + + # Extract and decode the plaintext + plaintext = response["Plaintext"] + secret = plaintext.decode("utf-8") + if isinstance(secret, str): + secret = secret.strip() + elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value: + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) + + if isinstance(client, AWSSecretsManagerV2): + secret = client.sync_read_secret( + secret_name=secret_name, + primary_secret_name=key_management_settings.primary_secret_name, + ) + print_verbose(f"get_secret_value_response: {secret}") + elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value: + try: + secret = client.get_secret_from_google_secret_manager( + secret_name + ) + print_verbose(f"secret from google secret manager: {secret}") + if secret is None: + raise ValueError( + f"No secret found in Google Secret Manager for {secret_name}" + ) + except Exception as e: + print_verbose(f"An error occurred - {str(e)}") + raise e + elif key_manager == KeyManagementSystem.HASHICORP_VAULT.value: + try: + secret = client.sync_read_secret(secret_name=secret_name) + if secret is None: + raise ValueError( + f"No secret found in Hashicorp Secret Manager for {secret_name}" + ) + except Exception as e: + print_verbose(f"An error occurred - {str(e)}") + raise e + elif key_manager == "local": + secret = os.getenv(secret_name) + else: # assume the default is infisicial client + secret = client.get_secret(secret_name).secret_value + except Exception as e: # check if it's in os.environ + verbose_logger.error( + f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}" + ) + secret = os.getenv(secret_name) + try: + if isinstance(secret, str): + secret_value_as_bool = ast.literal_eval(secret) + if isinstance(secret_value_as_bool, bool): + return secret_value_as_bool + else: + return secret + except Exception: + return secret + else: + secret = os.environ.get(secret_name) + secret_value_as_bool = str_to_bool(secret) if secret is not None else None + if secret_value_as_bool is not None and isinstance( + secret_value_as_bool, bool + ): + return secret_value_as_bool + else: + return secret + except Exception as e: + if default_value is not None: + return default_value + else: + raise e + + +def _should_read_secret_from_secret_manager() -> bool: + """ + Returns True if the secret manager should be used to read the secret, False otherwise + + - If the secret manager client is not set, return False + - If the `_key_management_settings` access mode is "read_only" or "read_and_write", return True + - Otherwise, return False + """ + if litellm.secret_manager_client is not None: + if litellm._key_management_settings is not None: + if ( + litellm._key_management_settings.access_mode == "read_only" + or litellm._key_management_settings.access_mode == "read_and_write" + ): + return True + return False |