diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py | 374 |
1 files changed, 374 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py new file mode 100644 index 00000000..69a5cf91 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_helpers/utils.py @@ -0,0 +1,374 @@ +# What is this? +## Helper utils for the management endpoints (keys/users/teams) +import uuid +from datetime import datetime +from functools import wraps +from typing import Optional, Tuple + +from fastapi import HTTPException, Request + +import litellm +from litellm._logging import verbose_logger +from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types + DeleteCustomerRequest, + DeleteTeamRequest, + DeleteUserRequest, + KeyRequest, + LiteLLM_TeamMembership, + LiteLLM_UserTable, + ManagementEndpointLoggingPayload, + Member, + SSOUserDefinedValues, + UpdateCustomerRequest, + UpdateKeyRequest, + UpdateTeamRequest, + UpdateUserRequest, + UserAPIKeyAuth, + VirtualKeyEvent, +) +from litellm.proxy.common_utils.http_parsing_utils import _read_request_body +from litellm.proxy.utils import PrismaClient + + +def get_new_internal_user_defaults( + user_id: str, user_email: Optional[str] = None +) -> dict: + user_info = litellm.default_internal_user_params or {} + + returned_dict: SSOUserDefinedValues = { + "models": user_info.get("models", None), + "max_budget": user_info.get("max_budget", litellm.max_internal_user_budget), + "budget_duration": user_info.get( + "budget_duration", litellm.internal_user_budget_duration + ), + "user_email": user_email or user_info.get("user_email", None), + "user_id": user_id, + "user_role": "internal_user", + } + + non_null_dict = {} + for k, v in returned_dict.items(): + if v is not None: + non_null_dict[k] = v + return non_null_dict + + +async def add_new_member( + new_member: Member, + max_budget_in_team: Optional[float], + prisma_client: PrismaClient, + team_id: str, + user_api_key_dict: UserAPIKeyAuth, + litellm_proxy_admin_name: str, +) -> Tuple[LiteLLM_UserTable, Optional[LiteLLM_TeamMembership]]: + """ + Add a new member to a team + + - add team id to user table + - add team member w/ budget to team member table + + Returns created/existing user + team membership w/ budget id + """ + returned_user: Optional[LiteLLM_UserTable] = None + returned_team_membership: Optional[LiteLLM_TeamMembership] = None + ## ADD TEAM ID, to USER TABLE IF NEW ## + if new_member.user_id is not None: + new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id) + _returned_user = await prisma_client.db.litellm_usertable.upsert( + where={"user_id": new_member.user_id}, + data={ + "update": {"teams": {"push": [team_id]}}, + "create": {"teams": [team_id], **new_user_defaults}, # type: ignore + }, + ) + if _returned_user is not None: + returned_user = LiteLLM_UserTable(**_returned_user.model_dump()) + elif new_member.user_email is not None: + new_user_defaults = get_new_internal_user_defaults( + user_id=str(uuid.uuid4()), user_email=new_member.user_email + ) + ## user email is not unique acc. to prisma schema -> future improvement + ### for now: check if it exists in db, if not - insert it + existing_user_row: Optional[list] = await prisma_client.get_data( + key_val={"user_email": new_member.user_email}, + table_name="user", + query_type="find_all", + ) + if existing_user_row is None or ( + isinstance(existing_user_row, list) and len(existing_user_row) == 0 + ): + new_user_defaults["teams"] = [team_id] + _returned_user = await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore + + if _returned_user is not None: + returned_user = LiteLLM_UserTable(**_returned_user.model_dump()) + elif len(existing_user_row) == 1: + user_info = existing_user_row[0] + _returned_user = await prisma_client.db.litellm_usertable.update( + where={"user_id": user_info.user_id}, # type: ignore + data={"teams": {"push": [team_id]}}, + ) + if _returned_user is not None: + returned_user = LiteLLM_UserTable(**_returned_user.model_dump()) + elif len(existing_user_row) > 1: + raise HTTPException( + status_code=400, + detail={ + "error": "Multiple users with this email found in db. Please use 'user_id' instead." + }, + ) + + # Check if trying to set a budget for team member + if ( + max_budget_in_team is not None + and returned_user is not None + and returned_user.user_id is not None + ): + # create a new budget item for this member + response = await prisma_client.db.litellm_budgettable.create( + data={ + "max_budget": max_budget_in_team, + "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + } + ) + + _budget_id = response.budget_id + _returned_team_membership = ( + await prisma_client.db.litellm_teammembership.create( + data={ + "team_id": team_id, + "user_id": returned_user.user_id, + "budget_id": _budget_id, + }, + include={"litellm_budget_table": True}, + ) + ) + + returned_team_membership = LiteLLM_TeamMembership( + **_returned_team_membership.model_dump() + ) + + if returned_user is None: + raise Exception("Unable to update user table with membership information!") + + return returned_user, returned_team_membership + + +def _delete_user_id_from_cache(kwargs): + from litellm.proxy.proxy_server import user_api_key_cache + + if kwargs.get("data") is not None: + update_user_request = kwargs.get("data") + if isinstance(update_user_request, UpdateUserRequest): + user_api_key_cache.delete_cache(key=update_user_request.user_id) + + # delete user request + if isinstance(update_user_request, DeleteUserRequest): + for user_id in update_user_request.user_ids: + user_api_key_cache.delete_cache(key=user_id) + pass + + +def _delete_api_key_from_cache(kwargs): + from litellm.proxy.proxy_server import user_api_key_cache + + if kwargs.get("data") is not None: + update_request = kwargs.get("data") + if isinstance(update_request, UpdateKeyRequest): + user_api_key_cache.delete_cache(key=update_request.key) + + # delete key request + if isinstance(update_request, KeyRequest): + for key in update_request.keys: + user_api_key_cache.delete_cache(key=key) + pass + + +def _delete_team_id_from_cache(kwargs): + from litellm.proxy.proxy_server import user_api_key_cache + + if kwargs.get("data") is not None: + update_request = kwargs.get("data") + if isinstance(update_request, UpdateTeamRequest): + user_api_key_cache.delete_cache(key=update_request.team_id) + + # delete team request + if isinstance(update_request, DeleteTeamRequest): + for team_id in update_request.team_ids: + user_api_key_cache.delete_cache(key=team_id) + pass + + +def _delete_customer_id_from_cache(kwargs): + from litellm.proxy.proxy_server import user_api_key_cache + + if kwargs.get("data") is not None: + update_request = kwargs.get("data") + if isinstance(update_request, UpdateCustomerRequest): + user_api_key_cache.delete_cache(key=update_request.user_id) + + # delete customer request + if isinstance(update_request, DeleteCustomerRequest): + for user_id in update_request.user_ids: + user_api_key_cache.delete_cache(key=user_id) + pass + + +async def send_management_endpoint_alert( + request_kwargs: dict, + user_api_key_dict: UserAPIKeyAuth, + function_name: str, +): + """ + Sends a slack alert when: + - A virtual key is created, updated, or deleted + - An internal user is created, updated, or deleted + - A team is created, updated, or deleted + """ + from litellm.proxy.proxy_server import premium_user, proxy_logging_obj + from litellm.types.integrations.slack_alerting import AlertType + + if premium_user is not True: + return + + management_function_to_event_name = { + "generate_key_fn": AlertType.new_virtual_key_created, + "update_key_fn": AlertType.virtual_key_updated, + "delete_key_fn": AlertType.virtual_key_deleted, + # Team events + "new_team": AlertType.new_team_created, + "update_team": AlertType.team_updated, + "delete_team": AlertType.team_deleted, + # Internal User events + "new_user": AlertType.new_internal_user_created, + "user_update": AlertType.internal_user_updated, + "delete_user": AlertType.internal_user_deleted, + } + + # Check if alerting is enabled + if ( + proxy_logging_obj is not None + and proxy_logging_obj.slack_alerting_instance is not None + ): + + # Virtual Key Events + if function_name in management_function_to_event_name: + _event_name: AlertType = management_function_to_event_name[function_name] + + key_event = VirtualKeyEvent( + created_by_user_id=user_api_key_dict.user_id or "Unknown", + created_by_user_role=user_api_key_dict.user_role or "Unknown", + created_by_key_alias=user_api_key_dict.key_alias, + request_kwargs=request_kwargs, + ) + + # replace all "_" with " " and capitalize + event_name = _event_name.replace("_", " ").title() + await proxy_logging_obj.slack_alerting_instance.send_virtual_key_event_slack( + key_event=key_event, + event_name=event_name, + alert_type=_event_name, + ) + + +def management_endpoint_wrapper(func): + """ + This wrapper does the following: + + 1. Log I/O, Exceptions to OTEL + 2. Create an Audit log for success calls + """ + + @wraps(func) + async def wrapper(*args, **kwargs): + start_time = datetime.now() + _http_request: Optional[Request] = None + try: + result = await func(*args, **kwargs) + end_time = datetime.now() + try: + if kwargs is None: + kwargs = {} + user_api_key_dict: UserAPIKeyAuth = ( + kwargs.get("user_api_key_dict") or UserAPIKeyAuth() + ) + + await send_management_endpoint_alert( + request_kwargs=kwargs, + user_api_key_dict=user_api_key_dict, + function_name=func.__name__, + ) + _http_request = kwargs.get("http_request", None) + parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None) + if parent_otel_span is not None: + from litellm.proxy.proxy_server import open_telemetry_logger + + if open_telemetry_logger is not None: + if _http_request: + _route = _http_request.url.path + _request_body: dict = await _read_request_body( + request=_http_request + ) + _response = dict(result) if result is not None else None + + logging_payload = ManagementEndpointLoggingPayload( + route=_route, + request_data=_request_body, + response=_response, + start_time=start_time, + end_time=end_time, + ) + + await open_telemetry_logger.async_management_endpoint_success_hook( # type: ignore + logging_payload=logging_payload, + parent_otel_span=parent_otel_span, + ) + + # Delete updated/deleted info from cache + _delete_api_key_from_cache(kwargs=kwargs) + _delete_user_id_from_cache(kwargs=kwargs) + _delete_team_id_from_cache(kwargs=kwargs) + _delete_customer_id_from_cache(kwargs=kwargs) + except Exception as e: + # Non-Blocking Exception + verbose_logger.debug("Error in management endpoint wrapper: %s", str(e)) + pass + + return result + except Exception as e: + end_time = datetime.now() + + if kwargs is None: + kwargs = {} + user_api_key_dict: UserAPIKeyAuth = ( + kwargs.get("user_api_key_dict") or UserAPIKeyAuth() + ) + parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None) + if parent_otel_span is not None: + from litellm.proxy.proxy_server import open_telemetry_logger + + if open_telemetry_logger is not None: + _http_request = kwargs.get("http_request") + if _http_request: + _route = _http_request.url.path + _request_body: dict = await _read_request_body( + request=_http_request + ) + logging_payload = ManagementEndpointLoggingPayload( + route=_route, + request_data=_request_body, + response=None, + start_time=start_time, + end_time=end_time, + exception=e, + ) + + await open_telemetry_logger.async_management_endpoint_failure_hook( # type: ignore + logging_payload=logging_payload, + parent_otel_span=parent_otel_span, + ) + + raise e + + return wrapper |