about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/audit_logs.py98
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py374
2 files changed, 472 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/audit_logs.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/audit_logs.py
new file mode 100644
index 00000000..d6c83c38
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/audit_logs.py
@@ -0,0 +1,98 @@
+"""
+Functions to create audit logs for LiteLLM Proxy
+"""
+
+import json
+import uuid
+from datetime import datetime, timezone
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import (
+    AUDIT_ACTIONS,
+    LiteLLM_AuditLogs,
+    LitellmTableNames,
+    Optional,
+    UserAPIKeyAuth,
+)
+
+
+async def create_object_audit_log(
+    object_id: str,
+    action: AUDIT_ACTIONS,
+    litellm_changed_by: Optional[str],
+    user_api_key_dict: UserAPIKeyAuth,
+    litellm_proxy_admin_name: Optional[str],
+    table_name: LitellmTableNames,
+    before_value: Optional[str] = None,
+    after_value: Optional[str] = None,
+):
+    """
+    Create an audit log for an internal user.
+
+    Parameters:
+    - user_id: str - The id of the user to create the audit log for.
+    - action: AUDIT_ACTIONS - The action to create the audit log for.
+    - user_row: LiteLLM_UserTable - The user row to create the audit log for.
+    - litellm_changed_by: Optional[str] - The user id of the user who is changing the user.
+    - user_api_key_dict: UserAPIKeyAuth - The user api key dictionary.
+    - litellm_proxy_admin_name: Optional[str] - The name of the proxy admin.
+    """
+    if not litellm.store_audit_logs:
+        return
+
+    await create_audit_log_for_update(
+        request_data=LiteLLM_AuditLogs(
+            id=str(uuid.uuid4()),
+            updated_at=datetime.now(timezone.utc),
+            changed_by=litellm_changed_by
+            or user_api_key_dict.user_id
+            or litellm_proxy_admin_name,
+            changed_by_api_key=user_api_key_dict.api_key,
+            table_name=table_name,
+            object_id=object_id,
+            action=action,
+            updated_values=after_value,
+            before_value=before_value,
+        )
+    )
+
+
+async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs):
+    """
+    Create an audit log for an object.
+    """
+    if not litellm.store_audit_logs:
+        return
+
+    from litellm.proxy.proxy_server import premium_user, prisma_client
+
+    if premium_user is not True:
+        return
+
+    if litellm.store_audit_logs is not True:
+        return
+    if prisma_client is None:
+        raise Exception("prisma_client is None, no DB connected")
+
+    verbose_proxy_logger.debug("creating audit log for %s", request_data)
+
+    if isinstance(request_data.updated_values, dict):
+        request_data.updated_values = json.dumps(request_data.updated_values)
+
+    if isinstance(request_data.before_value, dict):
+        request_data.before_value = json.dumps(request_data.before_value)
+
+    _request_data = request_data.model_dump(exclude_none=True)
+
+    try:
+        await prisma_client.db.litellm_auditlog.create(
+            data={
+                **_request_data,  # type: ignore
+            }
+        )
+    except Exception as e:
+        # [Non-Blocking Exception. Do not allow blocking LLM API call]
+        verbose_proxy_logger.error(f"Failed Creating audit log {e}")
+
+    return
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py
new file mode 100644
index 00000000..69a5cf91
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py
@@ -0,0 +1,374 @@
+# What is this?
+## Helper utils for the management endpoints (keys/users/teams)
+import uuid
+from datetime import datetime
+from functools import wraps
+from typing import Optional, Tuple
+
+from fastapi import HTTPException, Request
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.proxy._types import (  # key request types; user request types; team request types; customer request types
+    DeleteCustomerRequest,
+    DeleteTeamRequest,
+    DeleteUserRequest,
+    KeyRequest,
+    LiteLLM_TeamMembership,
+    LiteLLM_UserTable,
+    ManagementEndpointLoggingPayload,
+    Member,
+    SSOUserDefinedValues,
+    UpdateCustomerRequest,
+    UpdateKeyRequest,
+    UpdateTeamRequest,
+    UpdateUserRequest,
+    UserAPIKeyAuth,
+    VirtualKeyEvent,
+)
+from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
+from litellm.proxy.utils import PrismaClient
+
+
+def get_new_internal_user_defaults(
+    user_id: str, user_email: Optional[str] = None
+) -> dict:
+    user_info = litellm.default_internal_user_params or {}
+
+    returned_dict: SSOUserDefinedValues = {
+        "models": user_info.get("models", None),
+        "max_budget": user_info.get("max_budget", litellm.max_internal_user_budget),
+        "budget_duration": user_info.get(
+            "budget_duration", litellm.internal_user_budget_duration
+        ),
+        "user_email": user_email or user_info.get("user_email", None),
+        "user_id": user_id,
+        "user_role": "internal_user",
+    }
+
+    non_null_dict = {}
+    for k, v in returned_dict.items():
+        if v is not None:
+            non_null_dict[k] = v
+    return non_null_dict
+
+
+async def add_new_member(
+    new_member: Member,
+    max_budget_in_team: Optional[float],
+    prisma_client: PrismaClient,
+    team_id: str,
+    user_api_key_dict: UserAPIKeyAuth,
+    litellm_proxy_admin_name: str,
+) -> Tuple[LiteLLM_UserTable, Optional[LiteLLM_TeamMembership]]:
+    """
+    Add a new member to a team
+
+    - add team id to user table
+    - add team member w/ budget to team member table
+
+    Returns created/existing user + team membership w/ budget id
+    """
+    returned_user: Optional[LiteLLM_UserTable] = None
+    returned_team_membership: Optional[LiteLLM_TeamMembership] = None
+    ## ADD TEAM ID, to USER TABLE IF NEW ##
+    if new_member.user_id is not None:
+        new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id)
+        _returned_user = await prisma_client.db.litellm_usertable.upsert(
+            where={"user_id": new_member.user_id},
+            data={
+                "update": {"teams": {"push": [team_id]}},
+                "create": {"teams": [team_id], **new_user_defaults},  # type: ignore
+            },
+        )
+        if _returned_user is not None:
+            returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
+    elif new_member.user_email is not None:
+        new_user_defaults = get_new_internal_user_defaults(
+            user_id=str(uuid.uuid4()), user_email=new_member.user_email
+        )
+        ## user email is not unique acc. to prisma schema -> future improvement
+        ### for now: check if it exists in db, if not - insert it
+        existing_user_row: Optional[list] = await prisma_client.get_data(
+            key_val={"user_email": new_member.user_email},
+            table_name="user",
+            query_type="find_all",
+        )
+        if existing_user_row is None or (
+            isinstance(existing_user_row, list) and len(existing_user_row) == 0
+        ):
+            new_user_defaults["teams"] = [team_id]
+            _returned_user = await prisma_client.insert_data(data=new_user_defaults, table_name="user")  # type: ignore
+
+            if _returned_user is not None:
+                returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
+        elif len(existing_user_row) == 1:
+            user_info = existing_user_row[0]
+            _returned_user = await prisma_client.db.litellm_usertable.update(
+                where={"user_id": user_info.user_id},  # type: ignore
+                data={"teams": {"push": [team_id]}},
+            )
+            if _returned_user is not None:
+                returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
+        elif len(existing_user_row) > 1:
+            raise HTTPException(
+                status_code=400,
+                detail={
+                    "error": "Multiple users with this email found in db. Please use 'user_id' instead."
+                },
+            )
+
+    # Check if trying to set a budget for team member
+    if (
+        max_budget_in_team is not None
+        and returned_user is not None
+        and returned_user.user_id is not None
+    ):
+        # create a new budget item for this member
+        response = await prisma_client.db.litellm_budgettable.create(
+            data={
+                "max_budget": max_budget_in_team,
+                "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+                "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+            }
+        )
+
+        _budget_id = response.budget_id
+        _returned_team_membership = (
+            await prisma_client.db.litellm_teammembership.create(
+                data={
+                    "team_id": team_id,
+                    "user_id": returned_user.user_id,
+                    "budget_id": _budget_id,
+                },
+                include={"litellm_budget_table": True},
+            )
+        )
+
+        returned_team_membership = LiteLLM_TeamMembership(
+            **_returned_team_membership.model_dump()
+        )
+
+    if returned_user is None:
+        raise Exception("Unable to update user table with membership information!")
+
+    return returned_user, returned_team_membership
+
+
+def _delete_user_id_from_cache(kwargs):
+    from litellm.proxy.proxy_server import user_api_key_cache
+
+    if kwargs.get("data") is not None:
+        update_user_request = kwargs.get("data")
+        if isinstance(update_user_request, UpdateUserRequest):
+            user_api_key_cache.delete_cache(key=update_user_request.user_id)
+
+        # delete user request
+        if isinstance(update_user_request, DeleteUserRequest):
+            for user_id in update_user_request.user_ids:
+                user_api_key_cache.delete_cache(key=user_id)
+    pass
+
+
+def _delete_api_key_from_cache(kwargs):
+    from litellm.proxy.proxy_server import user_api_key_cache
+
+    if kwargs.get("data") is not None:
+        update_request = kwargs.get("data")
+        if isinstance(update_request, UpdateKeyRequest):
+            user_api_key_cache.delete_cache(key=update_request.key)
+
+        # delete key request
+        if isinstance(update_request, KeyRequest):
+            for key in update_request.keys:
+                user_api_key_cache.delete_cache(key=key)
+    pass
+
+
+def _delete_team_id_from_cache(kwargs):
+    from litellm.proxy.proxy_server import user_api_key_cache
+
+    if kwargs.get("data") is not None:
+        update_request = kwargs.get("data")
+        if isinstance(update_request, UpdateTeamRequest):
+            user_api_key_cache.delete_cache(key=update_request.team_id)
+
+        # delete team request
+        if isinstance(update_request, DeleteTeamRequest):
+            for team_id in update_request.team_ids:
+                user_api_key_cache.delete_cache(key=team_id)
+    pass
+
+
+def _delete_customer_id_from_cache(kwargs):
+    from litellm.proxy.proxy_server import user_api_key_cache
+
+    if kwargs.get("data") is not None:
+        update_request = kwargs.get("data")
+        if isinstance(update_request, UpdateCustomerRequest):
+            user_api_key_cache.delete_cache(key=update_request.user_id)
+
+        # delete customer request
+        if isinstance(update_request, DeleteCustomerRequest):
+            for user_id in update_request.user_ids:
+                user_api_key_cache.delete_cache(key=user_id)
+    pass
+
+
+async def send_management_endpoint_alert(
+    request_kwargs: dict,
+    user_api_key_dict: UserAPIKeyAuth,
+    function_name: str,
+):
+    """
+    Sends a slack alert when:
+    - A virtual key is created, updated, or deleted
+    - An internal user is created, updated, or deleted
+    - A team is created, updated, or deleted
+    """
+    from litellm.proxy.proxy_server import premium_user, proxy_logging_obj
+    from litellm.types.integrations.slack_alerting import AlertType
+
+    if premium_user is not True:
+        return
+
+    management_function_to_event_name = {
+        "generate_key_fn": AlertType.new_virtual_key_created,
+        "update_key_fn": AlertType.virtual_key_updated,
+        "delete_key_fn": AlertType.virtual_key_deleted,
+        # Team events
+        "new_team": AlertType.new_team_created,
+        "update_team": AlertType.team_updated,
+        "delete_team": AlertType.team_deleted,
+        # Internal User events
+        "new_user": AlertType.new_internal_user_created,
+        "user_update": AlertType.internal_user_updated,
+        "delete_user": AlertType.internal_user_deleted,
+    }
+
+    # Check if alerting is enabled
+    if (
+        proxy_logging_obj is not None
+        and proxy_logging_obj.slack_alerting_instance is not None
+    ):
+
+        # Virtual Key Events
+        if function_name in management_function_to_event_name:
+            _event_name: AlertType = management_function_to_event_name[function_name]
+
+            key_event = VirtualKeyEvent(
+                created_by_user_id=user_api_key_dict.user_id or "Unknown",
+                created_by_user_role=user_api_key_dict.user_role or "Unknown",
+                created_by_key_alias=user_api_key_dict.key_alias,
+                request_kwargs=request_kwargs,
+            )
+
+            # replace all "_" with " " and capitalize
+            event_name = _event_name.replace("_", " ").title()
+            await proxy_logging_obj.slack_alerting_instance.send_virtual_key_event_slack(
+                key_event=key_event,
+                event_name=event_name,
+                alert_type=_event_name,
+            )
+
+
+def management_endpoint_wrapper(func):
+    """
+    This wrapper does the following:
+
+    1. Log I/O, Exceptions to OTEL
+    2. Create an Audit log for success calls
+    """
+
+    @wraps(func)
+    async def wrapper(*args, **kwargs):
+        start_time = datetime.now()
+        _http_request: Optional[Request] = None
+        try:
+            result = await func(*args, **kwargs)
+            end_time = datetime.now()
+            try:
+                if kwargs is None:
+                    kwargs = {}
+                user_api_key_dict: UserAPIKeyAuth = (
+                    kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
+                )
+
+                await send_management_endpoint_alert(
+                    request_kwargs=kwargs,
+                    user_api_key_dict=user_api_key_dict,
+                    function_name=func.__name__,
+                )
+                _http_request = kwargs.get("http_request", None)
+                parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
+                if parent_otel_span is not None:
+                    from litellm.proxy.proxy_server import open_telemetry_logger
+
+                    if open_telemetry_logger is not None:
+                        if _http_request:
+                            _route = _http_request.url.path
+                            _request_body: dict = await _read_request_body(
+                                request=_http_request
+                            )
+                            _response = dict(result) if result is not None else None
+
+                            logging_payload = ManagementEndpointLoggingPayload(
+                                route=_route,
+                                request_data=_request_body,
+                                response=_response,
+                                start_time=start_time,
+                                end_time=end_time,
+                            )
+
+                            await open_telemetry_logger.async_management_endpoint_success_hook(  # type: ignore
+                                logging_payload=logging_payload,
+                                parent_otel_span=parent_otel_span,
+                            )
+
+                # Delete updated/deleted info from cache
+                _delete_api_key_from_cache(kwargs=kwargs)
+                _delete_user_id_from_cache(kwargs=kwargs)
+                _delete_team_id_from_cache(kwargs=kwargs)
+                _delete_customer_id_from_cache(kwargs=kwargs)
+            except Exception as e:
+                # Non-Blocking Exception
+                verbose_logger.debug("Error in management endpoint wrapper: %s", str(e))
+                pass
+
+            return result
+        except Exception as e:
+            end_time = datetime.now()
+
+            if kwargs is None:
+                kwargs = {}
+            user_api_key_dict: UserAPIKeyAuth = (
+                kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
+            )
+            parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
+            if parent_otel_span is not None:
+                from litellm.proxy.proxy_server import open_telemetry_logger
+
+                if open_telemetry_logger is not None:
+                    _http_request = kwargs.get("http_request")
+                    if _http_request:
+                        _route = _http_request.url.path
+                        _request_body: dict = await _read_request_body(
+                            request=_http_request
+                        )
+                        logging_payload = ManagementEndpointLoggingPayload(
+                            route=_route,
+                            request_data=_request_body,
+                            response=None,
+                            start_time=start_time,
+                            end_time=end_time,
+                            exception=e,
+                        )
+
+                        await open_telemetry_logger.async_management_endpoint_failure_hook(  # type: ignore
+                            logging_payload=logging_payload,
+                            parent_otel_span=parent_otel_span,
+                        )
+
+            raise e
+
+    return wrapper