aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py374
1 files changed, 374 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py
new file mode 100644
index 00000000..69a5cf91
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py
@@ -0,0 +1,374 @@
+# What is this?
+## Helper utils for the management endpoints (keys/users/teams)
+import uuid
+from datetime import datetime
+from functools import wraps
+from typing import Optional, Tuple
+
+from fastapi import HTTPException, Request
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types
+ DeleteCustomerRequest,
+ DeleteTeamRequest,
+ DeleteUserRequest,
+ KeyRequest,
+ LiteLLM_TeamMembership,
+ LiteLLM_UserTable,
+ ManagementEndpointLoggingPayload,
+ Member,
+ SSOUserDefinedValues,
+ UpdateCustomerRequest,
+ UpdateKeyRequest,
+ UpdateTeamRequest,
+ UpdateUserRequest,
+ UserAPIKeyAuth,
+ VirtualKeyEvent,
+)
+from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
+from litellm.proxy.utils import PrismaClient
+
+
+def get_new_internal_user_defaults(
+ user_id: str, user_email: Optional[str] = None
+) -> dict:
+ user_info = litellm.default_internal_user_params or {}
+
+ returned_dict: SSOUserDefinedValues = {
+ "models": user_info.get("models", None),
+ "max_budget": user_info.get("max_budget", litellm.max_internal_user_budget),
+ "budget_duration": user_info.get(
+ "budget_duration", litellm.internal_user_budget_duration
+ ),
+ "user_email": user_email or user_info.get("user_email", None),
+ "user_id": user_id,
+ "user_role": "internal_user",
+ }
+
+ non_null_dict = {}
+ for k, v in returned_dict.items():
+ if v is not None:
+ non_null_dict[k] = v
+ return non_null_dict
+
+
+async def add_new_member(
+ new_member: Member,
+ max_budget_in_team: Optional[float],
+ prisma_client: PrismaClient,
+ team_id: str,
+ user_api_key_dict: UserAPIKeyAuth,
+ litellm_proxy_admin_name: str,
+) -> Tuple[LiteLLM_UserTable, Optional[LiteLLM_TeamMembership]]:
+ """
+ Add a new member to a team
+
+ - add team id to user table
+ - add team member w/ budget to team member table
+
+ Returns created/existing user + team membership w/ budget id
+ """
+ returned_user: Optional[LiteLLM_UserTable] = None
+ returned_team_membership: Optional[LiteLLM_TeamMembership] = None
+ ## ADD TEAM ID, to USER TABLE IF NEW ##
+ if new_member.user_id is not None:
+ new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id)
+ _returned_user = await prisma_client.db.litellm_usertable.upsert(
+ where={"user_id": new_member.user_id},
+ data={
+ "update": {"teams": {"push": [team_id]}},
+ "create": {"teams": [team_id], **new_user_defaults}, # type: ignore
+ },
+ )
+ if _returned_user is not None:
+ returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
+ elif new_member.user_email is not None:
+ new_user_defaults = get_new_internal_user_defaults(
+ user_id=str(uuid.uuid4()), user_email=new_member.user_email
+ )
+ ## user email is not unique acc. to prisma schema -> future improvement
+ ### for now: check if it exists in db, if not - insert it
+ existing_user_row: Optional[list] = await prisma_client.get_data(
+ key_val={"user_email": new_member.user_email},
+ table_name="user",
+ query_type="find_all",
+ )
+ if existing_user_row is None or (
+ isinstance(existing_user_row, list) and len(existing_user_row) == 0
+ ):
+ new_user_defaults["teams"] = [team_id]
+ _returned_user = await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore
+
+ if _returned_user is not None:
+ returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
+ elif len(existing_user_row) == 1:
+ user_info = existing_user_row[0]
+ _returned_user = await prisma_client.db.litellm_usertable.update(
+ where={"user_id": user_info.user_id}, # type: ignore
+ data={"teams": {"push": [team_id]}},
+ )
+ if _returned_user is not None:
+ returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
+ elif len(existing_user_row) > 1:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Multiple users with this email found in db. Please use 'user_id' instead."
+ },
+ )
+
+ # Check if trying to set a budget for team member
+ if (
+ max_budget_in_team is not None
+ and returned_user is not None
+ and returned_user.user_id is not None
+ ):
+ # create a new budget item for this member
+ response = await prisma_client.db.litellm_budgettable.create(
+ data={
+ "max_budget": max_budget_in_team,
+ "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+ "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+ }
+ )
+
+ _budget_id = response.budget_id
+ _returned_team_membership = (
+ await prisma_client.db.litellm_teammembership.create(
+ data={
+ "team_id": team_id,
+ "user_id": returned_user.user_id,
+ "budget_id": _budget_id,
+ },
+ include={"litellm_budget_table": True},
+ )
+ )
+
+ returned_team_membership = LiteLLM_TeamMembership(
+ **_returned_team_membership.model_dump()
+ )
+
+ if returned_user is None:
+ raise Exception("Unable to update user table with membership information!")
+
+ return returned_user, returned_team_membership
+
+
+def _delete_user_id_from_cache(kwargs):
+ from litellm.proxy.proxy_server import user_api_key_cache
+
+ if kwargs.get("data") is not None:
+ update_user_request = kwargs.get("data")
+ if isinstance(update_user_request, UpdateUserRequest):
+ user_api_key_cache.delete_cache(key=update_user_request.user_id)
+
+ # delete user request
+ if isinstance(update_user_request, DeleteUserRequest):
+ for user_id in update_user_request.user_ids:
+ user_api_key_cache.delete_cache(key=user_id)
+ pass
+
+
+def _delete_api_key_from_cache(kwargs):
+ from litellm.proxy.proxy_server import user_api_key_cache
+
+ if kwargs.get("data") is not None:
+ update_request = kwargs.get("data")
+ if isinstance(update_request, UpdateKeyRequest):
+ user_api_key_cache.delete_cache(key=update_request.key)
+
+ # delete key request
+ if isinstance(update_request, KeyRequest):
+ for key in update_request.keys:
+ user_api_key_cache.delete_cache(key=key)
+ pass
+
+
+def _delete_team_id_from_cache(kwargs):
+ from litellm.proxy.proxy_server import user_api_key_cache
+
+ if kwargs.get("data") is not None:
+ update_request = kwargs.get("data")
+ if isinstance(update_request, UpdateTeamRequest):
+ user_api_key_cache.delete_cache(key=update_request.team_id)
+
+ # delete team request
+ if isinstance(update_request, DeleteTeamRequest):
+ for team_id in update_request.team_ids:
+ user_api_key_cache.delete_cache(key=team_id)
+ pass
+
+
+def _delete_customer_id_from_cache(kwargs):
+ from litellm.proxy.proxy_server import user_api_key_cache
+
+ if kwargs.get("data") is not None:
+ update_request = kwargs.get("data")
+ if isinstance(update_request, UpdateCustomerRequest):
+ user_api_key_cache.delete_cache(key=update_request.user_id)
+
+ # delete customer request
+ if isinstance(update_request, DeleteCustomerRequest):
+ for user_id in update_request.user_ids:
+ user_api_key_cache.delete_cache(key=user_id)
+ pass
+
+
+async def send_management_endpoint_alert(
+ request_kwargs: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ function_name: str,
+):
+ """
+ Sends a slack alert when:
+ - A virtual key is created, updated, or deleted
+ - An internal user is created, updated, or deleted
+ - A team is created, updated, or deleted
+ """
+ from litellm.proxy.proxy_server import premium_user, proxy_logging_obj
+ from litellm.types.integrations.slack_alerting import AlertType
+
+ if premium_user is not True:
+ return
+
+ management_function_to_event_name = {
+ "generate_key_fn": AlertType.new_virtual_key_created,
+ "update_key_fn": AlertType.virtual_key_updated,
+ "delete_key_fn": AlertType.virtual_key_deleted,
+ # Team events
+ "new_team": AlertType.new_team_created,
+ "update_team": AlertType.team_updated,
+ "delete_team": AlertType.team_deleted,
+ # Internal User events
+ "new_user": AlertType.new_internal_user_created,
+ "user_update": AlertType.internal_user_updated,
+ "delete_user": AlertType.internal_user_deleted,
+ }
+
+ # Check if alerting is enabled
+ if (
+ proxy_logging_obj is not None
+ and proxy_logging_obj.slack_alerting_instance is not None
+ ):
+
+ # Virtual Key Events
+ if function_name in management_function_to_event_name:
+ _event_name: AlertType = management_function_to_event_name[function_name]
+
+ key_event = VirtualKeyEvent(
+ created_by_user_id=user_api_key_dict.user_id or "Unknown",
+ created_by_user_role=user_api_key_dict.user_role or "Unknown",
+ created_by_key_alias=user_api_key_dict.key_alias,
+ request_kwargs=request_kwargs,
+ )
+
+ # replace all "_" with " " and capitalize
+ event_name = _event_name.replace("_", " ").title()
+ await proxy_logging_obj.slack_alerting_instance.send_virtual_key_event_slack(
+ key_event=key_event,
+ event_name=event_name,
+ alert_type=_event_name,
+ )
+
+
+def management_endpoint_wrapper(func):
+ """
+ This wrapper does the following:
+
+ 1. Log I/O, Exceptions to OTEL
+ 2. Create an Audit log for success calls
+ """
+
+ @wraps(func)
+ async def wrapper(*args, **kwargs):
+ start_time = datetime.now()
+ _http_request: Optional[Request] = None
+ try:
+ result = await func(*args, **kwargs)
+ end_time = datetime.now()
+ try:
+ if kwargs is None:
+ kwargs = {}
+ user_api_key_dict: UserAPIKeyAuth = (
+ kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
+ )
+
+ await send_management_endpoint_alert(
+ request_kwargs=kwargs,
+ user_api_key_dict=user_api_key_dict,
+ function_name=func.__name__,
+ )
+ _http_request = kwargs.get("http_request", None)
+ parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
+ if parent_otel_span is not None:
+ from litellm.proxy.proxy_server import open_telemetry_logger
+
+ if open_telemetry_logger is not None:
+ if _http_request:
+ _route = _http_request.url.path
+ _request_body: dict = await _read_request_body(
+ request=_http_request
+ )
+ _response = dict(result) if result is not None else None
+
+ logging_payload = ManagementEndpointLoggingPayload(
+ route=_route,
+ request_data=_request_body,
+ response=_response,
+ start_time=start_time,
+ end_time=end_time,
+ )
+
+ await open_telemetry_logger.async_management_endpoint_success_hook( # type: ignore
+ logging_payload=logging_payload,
+ parent_otel_span=parent_otel_span,
+ )
+
+ # Delete updated/deleted info from cache
+ _delete_api_key_from_cache(kwargs=kwargs)
+ _delete_user_id_from_cache(kwargs=kwargs)
+ _delete_team_id_from_cache(kwargs=kwargs)
+ _delete_customer_id_from_cache(kwargs=kwargs)
+ except Exception as e:
+ # Non-Blocking Exception
+ verbose_logger.debug("Error in management endpoint wrapper: %s", str(e))
+ pass
+
+ return result
+ except Exception as e:
+ end_time = datetime.now()
+
+ if kwargs is None:
+ kwargs = {}
+ user_api_key_dict: UserAPIKeyAuth = (
+ kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
+ )
+ parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
+ if parent_otel_span is not None:
+ from litellm.proxy.proxy_server import open_telemetry_logger
+
+ if open_telemetry_logger is not None:
+ _http_request = kwargs.get("http_request")
+ if _http_request:
+ _route = _http_request.url.path
+ _request_body: dict = await _read_request_body(
+ request=_http_request
+ )
+ logging_payload = ManagementEndpointLoggingPayload(
+ route=_route,
+ request_data=_request_body,
+ response=None,
+ start_time=start_time,
+ end_time=end_time,
+ exception=e,
+ )
+
+ await open_telemetry_logger.async_management_endpoint_failure_hook( # type: ignore
+ logging_payload=logging_payload,
+ parent_otel_span=parent_otel_span,
+ )
+
+ raise e
+
+ return wrapper