aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py324
1 files changed, 324 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py
new file mode 100644
index 00000000..2030cb2a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py
@@ -0,0 +1,324 @@
+import asyncio
+import json
+import uuid
+from datetime import datetime, timezone
+from typing import Any, List, Optional
+
+from fastapi import status
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import (
+ GenerateKeyRequest,
+ GenerateKeyResponse,
+ KeyRequest,
+ LiteLLM_AuditLogs,
+ LiteLLM_VerificationToken,
+ LitellmTableNames,
+ ProxyErrorTypes,
+ ProxyException,
+ RegenerateKeyRequest,
+ UpdateKeyRequest,
+ UserAPIKeyAuth,
+ WebhookEvent,
+)
+
+# NOTE: This is the prefix for all virtual keys stored in AWS Secrets Manager
+LITELLM_PREFIX_STORED_VIRTUAL_KEYS = "litellm/"
+
+
+class KeyManagementEventHooks:
+
+ @staticmethod
+ async def async_key_generated_hook(
+ data: GenerateKeyRequest,
+ response: GenerateKeyResponse,
+ user_api_key_dict: UserAPIKeyAuth,
+ litellm_changed_by: Optional[str] = None,
+ ):
+ """
+ Hook that runs after a successful /key/generate request
+
+ Handles the following:
+ - Sending Email with Key Details
+ - Storing Audit Logs for key generation
+ - Storing Generated Key in DB
+ """
+ from litellm.proxy.management_helpers.audit_logs import (
+ create_audit_log_for_update,
+ )
+ from litellm.proxy.proxy_server import litellm_proxy_admin_name
+
+ if data.send_invite_email is True:
+ await KeyManagementEventHooks._send_key_created_email(
+ response.model_dump(exclude_none=True)
+ )
+
+ # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+ if litellm.store_audit_logs is True:
+ _updated_values = response.model_dump_json(exclude_none=True)
+ asyncio.create_task(
+ create_audit_log_for_update(
+ request_data=LiteLLM_AuditLogs(
+ id=str(uuid.uuid4()),
+ updated_at=datetime.now(timezone.utc),
+ changed_by=litellm_changed_by
+ or user_api_key_dict.user_id
+ or litellm_proxy_admin_name,
+ changed_by_api_key=user_api_key_dict.api_key,
+ table_name=LitellmTableNames.KEY_TABLE_NAME,
+ object_id=response.token_id or "",
+ action="created",
+ updated_values=_updated_values,
+ before_value=None,
+ )
+ )
+ )
+ # store the generated key in the secret manager
+ await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
+ secret_name=data.key_alias or f"virtual-key-{response.token_id}",
+ secret_token=response.key,
+ )
+
+ @staticmethod
+ async def async_key_updated_hook(
+ data: UpdateKeyRequest,
+ existing_key_row: Any,
+ response: Any,
+ user_api_key_dict: UserAPIKeyAuth,
+ litellm_changed_by: Optional[str] = None,
+ ):
+ """
+ Post /key/update processing hook
+
+ Handles the following:
+ - Storing Audit Logs for key update
+ """
+ from litellm.proxy.management_helpers.audit_logs import (
+ create_audit_log_for_update,
+ )
+ from litellm.proxy.proxy_server import litellm_proxy_admin_name
+
+ # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+ if litellm.store_audit_logs is True:
+ _updated_values = json.dumps(data.json(exclude_none=True), default=str)
+
+ _before_value = existing_key_row.json(exclude_none=True)
+ _before_value = json.dumps(_before_value, default=str)
+
+ asyncio.create_task(
+ create_audit_log_for_update(
+ request_data=LiteLLM_AuditLogs(
+ id=str(uuid.uuid4()),
+ updated_at=datetime.now(timezone.utc),
+ changed_by=litellm_changed_by
+ or user_api_key_dict.user_id
+ or litellm_proxy_admin_name,
+ changed_by_api_key=user_api_key_dict.api_key,
+ table_name=LitellmTableNames.KEY_TABLE_NAME,
+ object_id=data.key,
+ action="updated",
+ updated_values=_updated_values,
+ before_value=_before_value,
+ )
+ )
+ )
+
+ @staticmethod
+ async def async_key_rotated_hook(
+ data: Optional[RegenerateKeyRequest],
+ existing_key_row: Any,
+ response: GenerateKeyResponse,
+ user_api_key_dict: UserAPIKeyAuth,
+ litellm_changed_by: Optional[str] = None,
+ ):
+ # store the generated key in the secret manager
+ if data is not None and response.token_id is not None:
+ initial_secret_name = (
+ existing_key_row.key_alias or f"virtual-key-{existing_key_row.token}"
+ )
+ await KeyManagementEventHooks._rotate_virtual_key_in_secret_manager(
+ current_secret_name=initial_secret_name,
+ new_secret_name=data.key_alias or f"virtual-key-{response.token_id}",
+ new_secret_value=response.key,
+ )
+
+ @staticmethod
+ async def async_key_deleted_hook(
+ data: KeyRequest,
+ keys_being_deleted: List[LiteLLM_VerificationToken],
+ response: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ litellm_changed_by: Optional[str] = None,
+ ):
+ """
+ Post /key/delete processing hook
+
+ Handles the following:
+ - Storing Audit Logs for key deletion
+ """
+ from litellm.proxy.management_helpers.audit_logs import (
+ create_audit_log_for_update,
+ )
+ from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
+
+ # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+ # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
+ if litellm.store_audit_logs is True and data.keys is not None:
+ # make an audit log for each team deleted
+ for key in data.keys:
+ key_row = await prisma_client.get_data( # type: ignore
+ token=key, table_name="key", query_type="find_unique"
+ )
+
+ if key_row is None:
+ raise ProxyException(
+ message=f"Key {key} not found",
+ type=ProxyErrorTypes.bad_request_error,
+ param="key",
+ code=status.HTTP_404_NOT_FOUND,
+ )
+
+ key_row = key_row.json(exclude_none=True)
+ _key_row = json.dumps(key_row, default=str)
+
+ asyncio.create_task(
+ create_audit_log_for_update(
+ request_data=LiteLLM_AuditLogs(
+ id=str(uuid.uuid4()),
+ updated_at=datetime.now(timezone.utc),
+ changed_by=litellm_changed_by
+ or user_api_key_dict.user_id
+ or litellm_proxy_admin_name,
+ changed_by_api_key=user_api_key_dict.api_key,
+ table_name=LitellmTableNames.KEY_TABLE_NAME,
+ object_id=key,
+ action="deleted",
+ updated_values="{}",
+ before_value=_key_row,
+ )
+ )
+ )
+ # delete the keys from the secret manager
+ await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
+ keys_being_deleted=keys_being_deleted
+ )
+ pass
+
+ @staticmethod
+ async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
+ """
+ Store a virtual key in the secret manager
+
+ Args:
+ secret_name: Name of the virtual key
+ secret_token: Value of the virtual key (example: sk-1234)
+ """
+ if litellm._key_management_settings is not None:
+ if litellm._key_management_settings.store_virtual_keys is True:
+ from litellm.secret_managers.base_secret_manager import (
+ BaseSecretManager,
+ )
+
+ # store the key in the secret manager
+ if isinstance(litellm.secret_manager_client, BaseSecretManager):
+ await litellm.secret_manager_client.async_write_secret(
+ secret_name=KeyManagementEventHooks._get_secret_name(
+ secret_name
+ ),
+ secret_value=secret_token,
+ )
+
+ @staticmethod
+ async def _rotate_virtual_key_in_secret_manager(
+ current_secret_name: str, new_secret_name: str, new_secret_value: str
+ ):
+ """
+ Update a virtual key in the secret manager
+
+ Args:
+ secret_name: Name of the virtual key
+ secret_token: Value of the virtual key (example: sk-1234)
+ """
+ if litellm._key_management_settings is not None:
+ if litellm._key_management_settings.store_virtual_keys is True:
+ from litellm.secret_managers.base_secret_manager import (
+ BaseSecretManager,
+ )
+
+ # store the key in the secret manager
+ if isinstance(litellm.secret_manager_client, BaseSecretManager):
+ await litellm.secret_manager_client.async_rotate_secret(
+ current_secret_name=KeyManagementEventHooks._get_secret_name(
+ current_secret_name
+ ),
+ new_secret_name=KeyManagementEventHooks._get_secret_name(
+ new_secret_name
+ ),
+ new_secret_value=new_secret_value,
+ )
+
+ @staticmethod
+ def _get_secret_name(secret_name: str) -> str:
+ if litellm._key_management_settings.prefix_for_stored_virtual_keys.endswith(
+ "/"
+ ):
+ return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{secret_name}"
+ else:
+ return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}/{secret_name}"
+
+ @staticmethod
+ async def _delete_virtual_keys_from_secret_manager(
+ keys_being_deleted: List[LiteLLM_VerificationToken],
+ ):
+ """
+ Deletes virtual keys from the secret manager
+
+ Args:
+ keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
+ """
+ if litellm._key_management_settings is not None:
+ if litellm._key_management_settings.store_virtual_keys is True:
+ from litellm.secret_managers.base_secret_manager import (
+ BaseSecretManager,
+ )
+
+ if isinstance(litellm.secret_manager_client, BaseSecretManager):
+ for key in keys_being_deleted:
+ if key.key_alias is not None:
+ await litellm.secret_manager_client.async_delete_secret(
+ secret_name=KeyManagementEventHooks._get_secret_name(
+ key.key_alias
+ )
+ )
+ else:
+ verbose_proxy_logger.warning(
+ f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
+ )
+
+ @staticmethod
+ async def _send_key_created_email(response: dict):
+ from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
+
+ if "email" not in general_settings.get("alerting", []):
+ raise ValueError(
+ "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
+ )
+ event = WebhookEvent(
+ event="key_created",
+ event_group="key",
+ event_message="API Key Created",
+ token=response.get("token", ""),
+ spend=response.get("spend", 0.0),
+ max_budget=response.get("max_budget", 0.0),
+ user_id=response.get("user_id", None),
+ team_id=response.get("team_id", "Default Team"),
+ key_alias=response.get("key_alias", None),
+ )
+
+ # If user configured email alerting - send an Email letting their end-user know the key was created
+ asyncio.create_task(
+ proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
+ webhook_event=event,
+ )
+ )