about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py324
1 files changed, 324 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py
new file mode 100644
index 00000000..2030cb2a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py
@@ -0,0 +1,324 @@
+import asyncio
+import json
+import uuid
+from datetime import datetime, timezone
+from typing import Any, List, Optional
+
+from fastapi import status
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import (
+    GenerateKeyRequest,
+    GenerateKeyResponse,
+    KeyRequest,
+    LiteLLM_AuditLogs,
+    LiteLLM_VerificationToken,
+    LitellmTableNames,
+    ProxyErrorTypes,
+    ProxyException,
+    RegenerateKeyRequest,
+    UpdateKeyRequest,
+    UserAPIKeyAuth,
+    WebhookEvent,
+)
+
+# NOTE: This is the prefix for all virtual keys stored in AWS Secrets Manager
+LITELLM_PREFIX_STORED_VIRTUAL_KEYS = "litellm/"
+
+
+class KeyManagementEventHooks:
+
+    @staticmethod
+    async def async_key_generated_hook(
+        data: GenerateKeyRequest,
+        response: GenerateKeyResponse,
+        user_api_key_dict: UserAPIKeyAuth,
+        litellm_changed_by: Optional[str] = None,
+    ):
+        """
+        Hook that runs after a successful /key/generate request
+
+        Handles the following:
+        - Sending Email with Key Details
+        - Storing Audit Logs for key generation
+        - Storing Generated Key in DB
+        """
+        from litellm.proxy.management_helpers.audit_logs import (
+            create_audit_log_for_update,
+        )
+        from litellm.proxy.proxy_server import litellm_proxy_admin_name
+
+        if data.send_invite_email is True:
+            await KeyManagementEventHooks._send_key_created_email(
+                response.model_dump(exclude_none=True)
+            )
+
+        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+        if litellm.store_audit_logs is True:
+            _updated_values = response.model_dump_json(exclude_none=True)
+            asyncio.create_task(
+                create_audit_log_for_update(
+                    request_data=LiteLLM_AuditLogs(
+                        id=str(uuid.uuid4()),
+                        updated_at=datetime.now(timezone.utc),
+                        changed_by=litellm_changed_by
+                        or user_api_key_dict.user_id
+                        or litellm_proxy_admin_name,
+                        changed_by_api_key=user_api_key_dict.api_key,
+                        table_name=LitellmTableNames.KEY_TABLE_NAME,
+                        object_id=response.token_id or "",
+                        action="created",
+                        updated_values=_updated_values,
+                        before_value=None,
+                    )
+                )
+            )
+        # store the generated key in the secret manager
+        await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
+            secret_name=data.key_alias or f"virtual-key-{response.token_id}",
+            secret_token=response.key,
+        )
+
+    @staticmethod
+    async def async_key_updated_hook(
+        data: UpdateKeyRequest,
+        existing_key_row: Any,
+        response: Any,
+        user_api_key_dict: UserAPIKeyAuth,
+        litellm_changed_by: Optional[str] = None,
+    ):
+        """
+        Post /key/update processing hook
+
+        Handles the following:
+        - Storing Audit Logs for key update
+        """
+        from litellm.proxy.management_helpers.audit_logs import (
+            create_audit_log_for_update,
+        )
+        from litellm.proxy.proxy_server import litellm_proxy_admin_name
+
+        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+        if litellm.store_audit_logs is True:
+            _updated_values = json.dumps(data.json(exclude_none=True), default=str)
+
+            _before_value = existing_key_row.json(exclude_none=True)
+            _before_value = json.dumps(_before_value, default=str)
+
+            asyncio.create_task(
+                create_audit_log_for_update(
+                    request_data=LiteLLM_AuditLogs(
+                        id=str(uuid.uuid4()),
+                        updated_at=datetime.now(timezone.utc),
+                        changed_by=litellm_changed_by
+                        or user_api_key_dict.user_id
+                        or litellm_proxy_admin_name,
+                        changed_by_api_key=user_api_key_dict.api_key,
+                        table_name=LitellmTableNames.KEY_TABLE_NAME,
+                        object_id=data.key,
+                        action="updated",
+                        updated_values=_updated_values,
+                        before_value=_before_value,
+                    )
+                )
+            )
+
+    @staticmethod
+    async def async_key_rotated_hook(
+        data: Optional[RegenerateKeyRequest],
+        existing_key_row: Any,
+        response: GenerateKeyResponse,
+        user_api_key_dict: UserAPIKeyAuth,
+        litellm_changed_by: Optional[str] = None,
+    ):
+        # store the generated key in the secret manager
+        if data is not None and response.token_id is not None:
+            initial_secret_name = (
+                existing_key_row.key_alias or f"virtual-key-{existing_key_row.token}"
+            )
+            await KeyManagementEventHooks._rotate_virtual_key_in_secret_manager(
+                current_secret_name=initial_secret_name,
+                new_secret_name=data.key_alias or f"virtual-key-{response.token_id}",
+                new_secret_value=response.key,
+            )
+
+    @staticmethod
+    async def async_key_deleted_hook(
+        data: KeyRequest,
+        keys_being_deleted: List[LiteLLM_VerificationToken],
+        response: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        litellm_changed_by: Optional[str] = None,
+    ):
+        """
+        Post /key/delete processing hook
+
+        Handles the following:
+        - Storing Audit Logs for key deletion
+        """
+        from litellm.proxy.management_helpers.audit_logs import (
+            create_audit_log_for_update,
+        )
+        from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
+
+        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+        # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
+        if litellm.store_audit_logs is True and data.keys is not None:
+            # make an audit log for each team deleted
+            for key in data.keys:
+                key_row = await prisma_client.get_data(  # type: ignore
+                    token=key, table_name="key", query_type="find_unique"
+                )
+
+                if key_row is None:
+                    raise ProxyException(
+                        message=f"Key {key} not found",
+                        type=ProxyErrorTypes.bad_request_error,
+                        param="key",
+                        code=status.HTTP_404_NOT_FOUND,
+                    )
+
+                key_row = key_row.json(exclude_none=True)
+                _key_row = json.dumps(key_row, default=str)
+
+                asyncio.create_task(
+                    create_audit_log_for_update(
+                        request_data=LiteLLM_AuditLogs(
+                            id=str(uuid.uuid4()),
+                            updated_at=datetime.now(timezone.utc),
+                            changed_by=litellm_changed_by
+                            or user_api_key_dict.user_id
+                            or litellm_proxy_admin_name,
+                            changed_by_api_key=user_api_key_dict.api_key,
+                            table_name=LitellmTableNames.KEY_TABLE_NAME,
+                            object_id=key,
+                            action="deleted",
+                            updated_values="{}",
+                            before_value=_key_row,
+                        )
+                    )
+                )
+        # delete the keys from the secret manager
+        await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
+            keys_being_deleted=keys_being_deleted
+        )
+        pass
+
+    @staticmethod
+    async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
+        """
+        Store a virtual key in the secret manager
+
+        Args:
+            secret_name: Name of the virtual key
+            secret_token: Value of the virtual key (example: sk-1234)
+        """
+        if litellm._key_management_settings is not None:
+            if litellm._key_management_settings.store_virtual_keys is True:
+                from litellm.secret_managers.base_secret_manager import (
+                    BaseSecretManager,
+                )
+
+                # store the key in the secret manager
+                if isinstance(litellm.secret_manager_client, BaseSecretManager):
+                    await litellm.secret_manager_client.async_write_secret(
+                        secret_name=KeyManagementEventHooks._get_secret_name(
+                            secret_name
+                        ),
+                        secret_value=secret_token,
+                    )
+
+    @staticmethod
+    async def _rotate_virtual_key_in_secret_manager(
+        current_secret_name: str, new_secret_name: str, new_secret_value: str
+    ):
+        """
+        Update a virtual key in the secret manager
+
+        Args:
+            secret_name: Name of the virtual key
+            secret_token: Value of the virtual key (example: sk-1234)
+        """
+        if litellm._key_management_settings is not None:
+            if litellm._key_management_settings.store_virtual_keys is True:
+                from litellm.secret_managers.base_secret_manager import (
+                    BaseSecretManager,
+                )
+
+                # store the key in the secret manager
+                if isinstance(litellm.secret_manager_client, BaseSecretManager):
+                    await litellm.secret_manager_client.async_rotate_secret(
+                        current_secret_name=KeyManagementEventHooks._get_secret_name(
+                            current_secret_name
+                        ),
+                        new_secret_name=KeyManagementEventHooks._get_secret_name(
+                            new_secret_name
+                        ),
+                        new_secret_value=new_secret_value,
+                    )
+
+    @staticmethod
+    def _get_secret_name(secret_name: str) -> str:
+        if litellm._key_management_settings.prefix_for_stored_virtual_keys.endswith(
+            "/"
+        ):
+            return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{secret_name}"
+        else:
+            return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}/{secret_name}"
+
+    @staticmethod
+    async def _delete_virtual_keys_from_secret_manager(
+        keys_being_deleted: List[LiteLLM_VerificationToken],
+    ):
+        """
+        Deletes virtual keys from the secret manager
+
+        Args:
+            keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
+        """
+        if litellm._key_management_settings is not None:
+            if litellm._key_management_settings.store_virtual_keys is True:
+                from litellm.secret_managers.base_secret_manager import (
+                    BaseSecretManager,
+                )
+
+                if isinstance(litellm.secret_manager_client, BaseSecretManager):
+                    for key in keys_being_deleted:
+                        if key.key_alias is not None:
+                            await litellm.secret_manager_client.async_delete_secret(
+                                secret_name=KeyManagementEventHooks._get_secret_name(
+                                    key.key_alias
+                                )
+                            )
+                        else:
+                            verbose_proxy_logger.warning(
+                                f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
+                            )
+
+    @staticmethod
+    async def _send_key_created_email(response: dict):
+        from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
+
+        if "email" not in general_settings.get("alerting", []):
+            raise ValueError(
+                "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
+            )
+        event = WebhookEvent(
+            event="key_created",
+            event_group="key",
+            event_message="API Key Created",
+            token=response.get("token", ""),
+            spend=response.get("spend", 0.0),
+            max_budget=response.get("max_budget", 0.0),
+            user_id=response.get("user_id", None),
+            team_id=response.get("team_id", "Default Team"),
+            key_alias=response.get("key_alias", None),
+        )
+
+        # If user configured email alerting - send an Email letting their end-user know the key was created
+        asyncio.create_task(
+            proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
+                webhook_event=event,
+            )
+        )