diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints')
12 files changed, 9481 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/budget_management_endpoints.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/budget_management_endpoints.py new file mode 100644 index 00000000..20aa1c6b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/budget_management_endpoints.py @@ -0,0 +1,287 @@ +""" +BUDGET MANAGEMENT + +All /budget management endpoints + +/budget/new +/budget/info +/budget/update +/budget/delete +/budget/settings +/budget/list +""" + +#### BUDGET TABLE MANAGEMENT #### +from fastapi import APIRouter, Depends, HTTPException + +from litellm.proxy._types import * +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.utils import jsonify_object + +router = APIRouter() + + +@router.post( + "/budget/new", + tags=["budget management"], + dependencies=[Depends(user_api_key_auth)], +) +async def new_budget( + budget_obj: BudgetNewRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Create a new budget object. Can apply this to teams, orgs, end-users, keys. + + Parameters: + - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.) + - budget_id: Optional[str] - The id of the budget. If not provided, a new id will be generated. + - max_budget: Optional[float] - The max budget for the budget. + - soft_budget: Optional[float] - The soft budget for the budget. + - max_parallel_requests: Optional[int] - The max number of parallel requests for the budget. + - tpm_limit: Optional[int] - The tokens per minute limit for the budget. + - rpm_limit: Optional[int] - The requests per minute limit for the budget. + - model_max_budget: Optional[dict] - Specify max budget for a given model. 
Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d", "tpm_limit": 100000, "rpm_limit": 100000}} + """ + from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + budget_obj_json = budget_obj.model_dump(exclude_none=True) + budget_obj_jsonified = jsonify_object(budget_obj_json) # json dump any dictionaries + response = await prisma_client.db.litellm_budgettable.create( + data={ + **budget_obj_jsonified, # type: ignore + "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + } # type: ignore + ) + + return response + + +@router.post( + "/budget/update", + tags=["budget management"], + dependencies=[Depends(user_api_key_auth)], +) +async def update_budget( + budget_obj: BudgetNewRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Update an existing budget object. + + Parameters: + - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.) + - budget_id: Optional[str] - The id of the budget. If not provided, a new id will be generated. + - max_budget: Optional[float] - The max budget for the budget. + - soft_budget: Optional[float] - The soft budget for the budget. + - max_parallel_requests: Optional[int] - The max number of parallel requests for the budget. + - tpm_limit: Optional[int] - The tokens per minute limit for the budget. + - rpm_limit: Optional[int] - The requests per minute limit for the budget. + - model_max_budget: Optional[dict] - Specify max budget for a given model. 
Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d", "tpm_limit": 100000, "rpm_limit": 100000}} + """ + from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + if budget_obj.budget_id is None: + raise HTTPException(status_code=400, detail={"error": "budget_id is required"}) + + response = await prisma_client.db.litellm_budgettable.update( + where={"budget_id": budget_obj.budget_id}, + data={ + **budget_obj.model_dump(exclude_none=True), # type: ignore + "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + }, # type: ignore + ) + + return response + + +@router.post( + "/budget/info", + tags=["budget management"], + dependencies=[Depends(user_api_key_auth)], +) +async def info_budget(data: BudgetRequest): + """ + Get the budget id specific information + + Parameters: + - budgets: List[str] - The list of budget ids to get information for + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if len(data.budgets) == 0: + raise HTTPException( + status_code=400, + detail={ + "error": f"Specify list of budget id's to query. Passed in={data.budgets}" + }, + ) + response = await prisma_client.db.litellm_budgettable.find_many( + where={"budget_id": {"in": data.budgets}}, + ) + + return response + + +@router.get( + "/budget/settings", + tags=["budget management"], + dependencies=[Depends(user_api_key_auth)], +) +async def budget_settings( + budget_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Get list of configurable params + current value for a budget item + description of each field + + Used on Admin UI. 
+ + Query Parameters: + - budget_id: str - The budget id to get information for + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: + raise HTTPException( + status_code=400, + detail={ + "error": "{}, your role={}".format( + CommonProxyErrors.not_allowed_access.value, + user_api_key_dict.user_role, + ) + }, + ) + + ## get budget item from db + db_budget_row = await prisma_client.db.litellm_budgettable.find_first( + where={"budget_id": budget_id} + ) + + if db_budget_row is not None: + db_budget_row_dict = db_budget_row.model_dump(exclude_none=True) + else: + db_budget_row_dict = {} + + allowed_args = { + "max_parallel_requests": {"type": "Integer"}, + "tpm_limit": {"type": "Integer"}, + "rpm_limit": {"type": "Integer"}, + "budget_duration": {"type": "String"}, + "max_budget": {"type": "Float"}, + "soft_budget": {"type": "Float"}, + } + + return_val = [] + + for field_name, field_info in BudgetNewRequest.model_fields.items(): + if field_name in allowed_args: + + _stored_in_db = True + + _response_obj = ConfigList( + field_name=field_name, + field_type=allowed_args[field_name]["type"], + field_description=field_info.description or "", + field_value=db_budget_row_dict.get(field_name, None), + stored_in_db=_stored_in_db, + field_default_value=field_info.default, + ) + return_val.append(_response_obj) + + return return_val + + +@router.get( + "/budget/list", + tags=["budget management"], + dependencies=[Depends(user_api_key_auth)], +) +async def list_budget( + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """List all the created budgets in proxy db. 
Used on Admin UI.""" + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: + raise HTTPException( + status_code=400, + detail={ + "error": "{}, your role={}".format( + CommonProxyErrors.not_allowed_access.value, + user_api_key_dict.user_role, + ) + }, + ) + + response = await prisma_client.db.litellm_budgettable.find_many() + + return response + + +@router.post( + "/budget/delete", + tags=["budget management"], + dependencies=[Depends(user_api_key_auth)], +) +async def delete_budget( + data: BudgetDeleteRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Delete budget + + Parameters: + - id: str - The budget id to delete + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: + raise HTTPException( + status_code=400, + detail={ + "error": "{}, your role={}".format( + CommonProxyErrors.not_allowed_access.value, + user_api_key_dict.user_role, + ) + }, + ) + + response = await prisma_client.db.litellm_budgettable.delete( + where={"budget_id": data.id} + ) + + return response diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/common_utils.py new file mode 100644 index 00000000..d80a06c5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/common_utils.py @@ -0,0 +1,41 @@ +from typing import Any, Union + +from litellm.proxy._types import ( + GenerateKeyRequest, + LiteLLM_ManagementEndpoint_MetadataFields_Premium, + LiteLLM_TeamTable, + UserAPIKeyAuth, +) +from 
litellm.proxy.utils import _premium_user_check + + +def _is_user_team_admin( + user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable +) -> bool: + for member in team_obj.members_with_roles: + if ( + member.user_id is not None and member.user_id == user_api_key_dict.user_id + ) and member.role == "admin": + + return True + + return False + + +def _set_object_metadata_field( + object_data: Union[LiteLLM_TeamTable, GenerateKeyRequest], + field_name: str, + value: Any, +) -> None: + """ + Helper function to set metadata fields that require premium user checks + + Args: + object_data: The team data object to modify + field_name: Name of the metadata field to set + value: Value to set for the field + """ + if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium: + _premium_user_check() + object_data.metadata = object_data.metadata or {} + object_data.metadata[field_name] = value diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/customer_endpoints.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/customer_endpoints.py new file mode 100644 index 00000000..976ff858 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/customer_endpoints.py @@ -0,0 +1,620 @@ +""" +CUSTOMER MANAGEMENT + +All /customer management endpoints + +/customer/new +/customer/info +/customer/update +/customer/delete +""" + +#### END-USER/CUSTOMER MANAGEMENT #### +import traceback +from typing import List, Optional + +import fastapi +from fastapi import APIRouter, Depends, HTTPException, Request, status + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import * +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth + +router = APIRouter() + + +@router.post( + "/end_user/block", + tags=["Customer Management"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +@router.post( + "/customer/block", + 
tags=["Customer Management"], + dependencies=[Depends(user_api_key_auth)], +) +async def block_user(data: BlockUsers): + """ + [BETA] Reject calls with this end-user id + + Parameters: + - user_ids (List[str], required): The unique `user_id`s for the users to block + + (any /chat/completion call with this user={end-user-id} param, will be rejected.) + + ``` + curl -X POST "http://0.0.0.0:8000/user/block" + -H "Authorization: Bearer sk-1234" + -d '{ + "user_ids": [<user_id>, ...] + }' + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + try: + records = [] + if prisma_client is not None: + for id in data.user_ids: + record = await prisma_client.db.litellm_endusertable.upsert( + where={"user_id": id}, # type: ignore + data={ + "create": {"user_id": id, "blocked": True}, # type: ignore + "update": {"blocked": True}, + }, + ) + records.append(record) + else: + raise HTTPException( + status_code=500, + detail={"error": "Postgres DB Not connected"}, + ) + + return {"blocked_users": records} + except Exception as e: + verbose_proxy_logger.error(f"An error occurred - {str(e)}") + raise HTTPException(status_code=500, detail={"error": str(e)}) + + +@router.post( + "/end_user/unblock", + tags=["Customer Management"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +@router.post( + "/customer/unblock", + tags=["Customer Management"], + dependencies=[Depends(user_api_key_auth)], +) +async def unblock_user(data: BlockUsers): + """ + [BETA] Unblock calls with this user id + + Example + ``` + curl -X POST "http://0.0.0.0:8000/user/unblock" + -H "Authorization: Bearer sk-1234" + -d '{ + "user_ids": [<user_id>, ...] 
+ }' + ``` + """ + from enterprise.enterprise_hooks.blocked_user_list import ( + _ENTERPRISE_BlockedUserList, + ) + + if ( + not any(isinstance(x, _ENTERPRISE_BlockedUserList) for x in litellm.callbacks) + or litellm.blocked_user_list is None + ): + raise HTTPException( + status_code=400, + detail={ + "error": "Blocked user check was never set. This call has no effect." + }, + ) + + if isinstance(litellm.blocked_user_list, list): + for id in data.user_ids: + litellm.blocked_user_list.remove(id) + else: + raise HTTPException( + status_code=500, + detail={ + "error": "`blocked_user_list` must be set as a list. Filepaths can't be updated." + }, + ) + + return {"blocked_users": litellm.blocked_user_list} + + +def new_budget_request(data: NewCustomerRequest) -> Optional[BudgetNewRequest]: + """ + Return a new budget object if new budget params are passed. + """ + budget_params = BudgetNewRequest.model_fields.keys() + budget_kv_pairs = {} + + # Get the actual values from the data object using getattr + for field_name in budget_params: + if field_name == "budget_id": + continue + value = getattr(data, field_name, None) + if value is not None: + budget_kv_pairs[field_name] = value + + if budget_kv_pairs: + return BudgetNewRequest(**budget_kv_pairs) + return None + + +@router.post( + "/end_user/new", + tags=["Customer Management"], + include_in_schema=False, + dependencies=[Depends(user_api_key_auth)], +) +@router.post( + "/customer/new", + tags=["Customer Management"], + dependencies=[Depends(user_api_key_auth)], +) +async def new_end_user( + data: NewCustomerRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Allow creating a new Customer + + + Parameters: + - user_id: str - The unique identifier for the user. + - alias: Optional[str] - A human-friendly alias for the user. + - blocked: bool - Flag to allow or disallow requests for this end-user. Default is False. + - max_budget: Optional[float] - The maximum budget allocated to the user. 
Either 'max_budget' or 'budget_id' should be provided, not both. + - budget_id: Optional[str] - The identifier for an existing budget allocated to the user. Either 'max_budget' or 'budget_id' should be provided, not both. + - allowed_model_region: Optional[Union[Literal["eu"], Literal["us"]]] - Require all user requests to use models in this specific region. + - default_model: Optional[str] - If no equivalent model in the allowed region, default all requests to this model. + - metadata: Optional[dict] = Metadata for customer, store information for customer. Example metadata = {"data_training_opt_out": True} + - budget_duration: Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). + - tpm_limit: Optional[int] - [Not Implemented Yet] Specify tpm limit for a given customer (Tokens per minute) + - rpm_limit: Optional[int] - [Not Implemented Yet] Specify rpm limit for a given customer (Requests per minute) + - model_max_budget: Optional[dict] - [Not Implemented Yet] Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d"}} + - max_parallel_requests: Optional[int] - [Not Implemented Yet] Specify max parallel requests for a given customer. + - soft_budget: Optional[float] - [Not Implemented Yet] Get alerts when customer crosses given budget, doesn't block requests. + + + - Allow specifying allowed regions + - Allow specifying default model + + Example curl: + ``` + curl --location 'http://0.0.0.0:4000/customer/new' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "user_id" : "ishaan-jaff-3", + "allowed_region": "eu", + "budget_id": "free_tier", + "default_model": "azure/gpt-3.5-turbo-eu" <- all calls from this user, use this model? 
+ }' + + # return end-user object + ``` + + NOTE: This used to be called `/end_user/new`, we will still be maintaining compatibility for /end_user/XXX for these endpoints + """ + """ + Validation: + - check if default model exists + - create budget object if not already created + + - Add user to end user table + + Return + - end-user object + - currently allowed models + """ + from litellm.proxy.proxy_server import ( + litellm_proxy_admin_name, + llm_router, + prisma_client, + ) + + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + try: + + ## VALIDATION ## + if data.default_model is not None: + if llm_router is None: + raise HTTPException( + status_code=422, + detail={"error": CommonProxyErrors.no_llm_router.value}, + ) + elif data.default_model not in llm_router.get_model_names(): + raise HTTPException( + status_code=422, + detail={ + "error": "Default Model not on proxy. Configure via `/model/new` or config.yaml. 
Default_model={}, proxy_model_names={}".format( + data.default_model, set(llm_router.get_model_names()) + ) + }, + ) + + new_end_user_obj: Dict = {} + + ## CREATE BUDGET ## if set + _new_budget = new_budget_request(data) + if _new_budget is not None: + try: + budget_record = await prisma_client.db.litellm_budgettable.create( + data={ + **_new_budget.model_dump(exclude_unset=True), + "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, # type: ignore + "updated_by": user_api_key_dict.user_id + or litellm_proxy_admin_name, + } + ) + except Exception as e: + raise HTTPException(status_code=422, detail={"error": str(e)}) + + new_end_user_obj["budget_id"] = budget_record.budget_id + elif data.budget_id is not None: + new_end_user_obj["budget_id"] = data.budget_id + + _user_data = data.dict(exclude_none=True) + + for k, v in _user_data.items(): + if k not in BudgetNewRequest.model_fields.keys(): + new_end_user_obj[k] = v + + ## WRITE TO DB ## + end_user_record = await prisma_client.db.litellm_endusertable.create( + data=new_end_user_obj, # type: ignore + include={"litellm_budget_table": True}, + ) + + return end_user_record + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.management_endpoints.customer_endpoints.new_end_user(): Exception occured - {}".format( + str(e) + ) + ) + if "Unique constraint failed on the fields: (`user_id`)" in str(e): + raise ProxyException( + message=f"Customer already exists, passed user_id={data.user_id}. 
Please pass a new user_id.", + type="bad_request", + code=400, + param="user_id", + ) + + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Internal Server Error({str(e)})"), + type="internal_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Internal Server Error, " + str(e), + type="internal_error", + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +@router.get( + "/customer/info", + tags=["Customer Management"], + dependencies=[Depends(user_api_key_auth)], + response_model=LiteLLM_EndUserTable, +) +@router.get( + "/end_user/info", + tags=["Customer Management"], + include_in_schema=False, + dependencies=[Depends(user_api_key_auth)], +) +async def end_user_info( + end_user_id: str = fastapi.Query( + description="End User ID in the request parameters" + ), +): + """ + Get information about an end-user. An `end_user` is a customer (external user) of the proxy. 
+ + Parameters: + - end_user_id (str, required): The unique identifier for the end-user + + Example curl: + ``` + curl -X GET 'http://localhost:4000/customer/info?end_user_id=test-litellm-user-4' \ + -H 'Authorization: Bearer sk-1234' + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + user_info = await prisma_client.db.litellm_endusertable.find_first( + where={"user_id": end_user_id}, include={"litellm_budget_table": True} + ) + + if user_info is None: + raise HTTPException( + status_code=400, + detail={"error": "End User Id={} does not exist in db".format(end_user_id)}, + ) + return user_info.model_dump(exclude_none=True) + + +@router.post( + "/customer/update", + tags=["Customer Management"], + dependencies=[Depends(user_api_key_auth)], +) +@router.post( + "/end_user/update", + tags=["Customer Management"], + include_in_schema=False, + dependencies=[Depends(user_api_key_auth)], +) +async def update_end_user( + data: UpdateCustomerRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Example curl + + Parameters: + - user_id: str + - alias: Optional[str] = None # human-friendly alias + - blocked: bool = False # allow/disallow requests for this end-user + - max_budget: Optional[float] = None + - budget_id: Optional[str] = None # give either a budget_id or max_budget + - allowed_model_region: Optional[AllowedModelRegion] = ( + None # require all user requests to use models in this specific region + ) + - default_model: Optional[str] = ( + None # if no equivalent model in allowed region - default all requests to this model + ) + + Example curl: + ``` + curl --location 'http://0.0.0.0:4000/customer/update' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "user_id": "test-litellm-user-4", + "budget_id": "paid_tier" + }' 
+ + See below for all params + ``` + """ + + from litellm.proxy.proxy_server import prisma_client + + try: + data_json: dict = data.json() + # get the row from db + if prisma_client is None: + raise Exception("Not connected to DB!") + + # get non default values for key + non_default_values = {} + for k, v in data_json.items(): + if v is not None and v not in ( + [], + {}, + 0, + ): # models default to [], spend defaults to 0, we should not reset these values + non_default_values[k] = v + + ## ADD USER, IF NEW ## + verbose_proxy_logger.debug("/customer/update: Received data = %s", data) + if data.user_id is not None and len(data.user_id) > 0: + non_default_values["user_id"] = data.user_id # type: ignore + verbose_proxy_logger.debug("In update customer, user_id condition block.") + response = await prisma_client.db.litellm_endusertable.update( + where={"user_id": data.user_id}, data=non_default_values # type: ignore + ) + if response is None: + raise ValueError( + f"Failed updating customer data. User ID does not exist passed user_id={data.user_id}" + ) + verbose_proxy_logger.debug( + f"received response from updating prisma client. 
response={response}" + ) + return response + else: + raise ValueError(f"user_id is required, passed user_id = {data.user_id}") + + # update based on remaining passed in values + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.update_end_user(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Internal Server Error({str(e)})"), + type="internal_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Internal Server Error, " + str(e), + type="internal_error", + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + pass + + +@router.post( + "/customer/delete", + tags=["Customer Management"], + dependencies=[Depends(user_api_key_auth)], +) +@router.post( + "/end_user/delete", + tags=["Customer Management"], + include_in_schema=False, + dependencies=[Depends(user_api_key_auth)], +) +async def delete_end_user( + data: DeleteCustomerRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Delete multiple end-users. 
+ + Parameters: + - user_ids (List[str], required): The unique `user_id`s for the users to delete + + Example curl: + ``` + curl --location 'http://0.0.0.0:4000/customer/delete' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "user_ids" :["ishaan-jaff-5"] + }' + + See below for all params + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + try: + if prisma_client is None: + raise Exception("Not connected to DB!") + + verbose_proxy_logger.debug("/customer/delete: Received data = %s", data) + if ( + data.user_ids is not None + and isinstance(data.user_ids, list) + and len(data.user_ids) > 0 + ): + response = await prisma_client.db.litellm_endusertable.delete_many( + where={"user_id": {"in": data.user_ids}} + ) + if response is None: + raise ValueError( + f"Failed deleting customer data. User ID does not exist passed user_id={data.user_ids}" + ) + if response != len(data.user_ids): + raise ValueError( + f"Failed deleting all customer data. User ID does not exist passed user_id={data.user_ids}. Deleted {response} customers, passed {len(data.user_ids)} customers" + ) + verbose_proxy_logger.debug( + f"received response from updating prisma client. 
response={response}" + ) + return { + "deleted_customers": response, + "message": "Successfully deleted customers with ids: " + + str(data.user_ids), + } + else: + raise ValueError(f"user_id is required, passed user_id = {data.user_ids}") + + # update based on remaining passed in values + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.delete_end_user(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Internal Server Error({str(e)})"), + type="internal_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Internal Server Error, " + str(e), + type="internal_error", + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + pass + + +@router.get( + "/customer/list", + tags=["Customer Management"], + dependencies=[Depends(user_api_key_auth)], + response_model=List[LiteLLM_EndUserTable], +) +@router.get( + "/end_user/list", + tags=["Customer Management"], + include_in_schema=False, + dependencies=[Depends(user_api_key_auth)], +) +async def list_end_user( + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [Admin-only] List all available customers + + Example curl: + ``` + curl --location --request GET 'http://0.0.0.0:4000/customer/list' \ + --header 'Authorization: Bearer sk-1234' + ``` + + """ + from litellm.proxy.proxy_server import prisma_client + + if ( + user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN + and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY + ): + raise HTTPException( + status_code=401, + detail={ + "error": "Admin-only endpoint. 
Your user role={}".format( + user_api_key_dict.user_role + ) + }, + ) + + if prisma_client is None: + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + response = await prisma_client.db.litellm_endusertable.find_many( + include={"litellm_budget_table": True} + ) + + returned_response: List[LiteLLM_EndUserTable] = [] + for item in response: + returned_response.append(LiteLLM_EndUserTable(**item.model_dump())) + return returned_response diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/internal_user_endpoints.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/internal_user_endpoints.py new file mode 100644 index 00000000..43d8273d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -0,0 +1,1243 @@ +""" +Internal User Management Endpoints + + +These are members of a Team on LiteLLM + +/user/new +/user/update +/user/delete +/user/info +/user/list +""" + +import asyncio +import traceback +import uuid +from datetime import datetime, timedelta, timezone +from typing import Any, List, Optional, Union, cast + +import fastapi +from fastapi import APIRouter, Depends, Header, HTTPException, Request, status + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.duration_parser import duration_in_seconds +from litellm.proxy._types import * +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.management_endpoints.key_management_endpoints import ( + generate_key_helper_fn, + prepare_metadata_fields, +) +from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update +from litellm.proxy.management_helpers.utils import management_endpoint_wrapper +from litellm.proxy.utils import handle_exception_on_proxy + +router = APIRouter() + + +async def create_internal_user_audit_log( + user_id: str, + 
action: AUDIT_ACTIONS, + litellm_changed_by: Optional[str], + user_api_key_dict: UserAPIKeyAuth, + litellm_proxy_admin_name: Optional[str], + before_value: Optional[str] = None, + after_value: Optional[str] = None, +): + """ + Create an audit log for an internal user. + + Parameters: + - user_id: str - The id of the user to create the audit log for. + - action: AUDIT_ACTIONS - The action to create the audit log for. + - user_row: LiteLLM_UserTable - The user row to create the audit log for. + - litellm_changed_by: Optional[str] - The user id of the user who is changing the user. + - user_api_key_dict: UserAPIKeyAuth - The user api key dictionary. + - litellm_proxy_admin_name: Optional[str] - The name of the proxy admin. + """ + if not litellm.store_audit_logs: + return + + await create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.USER_TABLE_NAME, + object_id=user_id, + action=action, + updated_values=after_value, + before_value=before_value, + ) + ) + + +def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> dict: + if "user_id" in data_json and data_json["user_id"] is None: + data_json["user_id"] = str(uuid.uuid4()) + auto_create_key = data_json.pop("auto_create_key", True) + if auto_create_key is False: + data_json["table_name"] = ( + "user" # only create a user, don't create key if 'auto_create_key' set to False + ) + + is_internal_user = False + if data.user_role and data.user_role.is_internal_user_role: + is_internal_user = True + if litellm.default_internal_user_params: + for key, value in litellm.default_internal_user_params.items(): + if key == "available_teams": + continue + elif key not in data_json or data_json[key] is None: + data_json[key] = value + elif ( + key == "models" + and 
async def _check_duplicate_user_email(
    user_email: Optional[str], prisma_client: Any
) -> None:
    """
    Reject the request when a user row with `user_email` already exists.

    Args:
        user_email (Optional[str]): Email to look up; None or "" skips the check entirely
        prisma_client (Any): Database client instance

    Raises:
        Exception: If an email is given but the database is not connected
        HTTPException: (400) If a user with this email already exists
    """
    if not user_email:
        return
    if prisma_client is None:
        raise Exception("Database not connected")

    match = await prisma_client.db.litellm_usertable.find_first(
        where={"user_email": user_email}
    )
    if match is None:
        return

    raise HTTPException(
        status_code=400,
        detail={"error": f"User with email {user_email} already exists"},
    )
+ - user_alias: Optional[str] - A descriptive name for you to know who this user id refers to. + - teams: Optional[list] - specify a list of team id's a user belongs to. + - user_email: Optional[str] - Specify a user email. + - send_invite_email: Optional[bool] - Specify if an invite email should be sent. + - user_role: Optional[str] - Specify a user role - "proxy_admin", "proxy_admin_viewer", "internal_user", "internal_user_viewer", "team", "customer". Info about each role here: `https://github.com/BerriAI/litellm/litellm/proxy/_types.py#L20` + - max_budget: Optional[float] - Specify max budget for a given user. + - budget_duration: Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"), months ("1mo"). + - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models). Set to ['no-default-models'] to block all model access. Restricting user to only team-based model access. + - tpm_limit: Optional[int] - Specify tpm limit for a given user (Tokens per minute) + - rpm_limit: Optional[int] - Specify rpm limit for a given user (Requests per minute) + - auto_create_key: bool - Default=True. Flag used for returning a key as part of the /user/new response + - aliases: Optional[dict] - Model aliases for the user - [Docs](https://litellm.vercel.app/docs/proxy/virtual_keys#model-aliases) + - config: Optional[dict] - [DEPRECATED PARAM] User-specific config. + - allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request- + - blocked: Optional[bool] - [Not Implemented Yet] Whether the user is blocked. 
+ - guardrails: Optional[List[str]] - [Not Implemented Yet] List of active guardrails for the user + - permissions: Optional[dict] - [Not Implemented Yet] User-specific permissions, eg. turning off pii masking. + - metadata: Optional[dict] - Metadata for user, store information for user. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } + - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x. + - soft_budget: Optional[float] - Get alerts when user crosses given budget, doesn't block requests. + - model_max_budget: Optional[dict] - Model-specific max budget for user. [Docs](https://docs.litellm.ai/docs/proxy/users#add-model-specific-budgets-to-keys) + - model_rpm_limit: Optional[float] - Model-specific rpm limit for user. [Docs](https://docs.litellm.ai/docs/proxy/users#add-model-specific-limits-to-keys) + - model_tpm_limit: Optional[float] - Model-specific tpm limit for user. [Docs](https://docs.litellm.ai/docs/proxy/users#add-model-specific-limits-to-keys) + - spend: Optional[float] - Amount spent by user. Default is 0. Will be updated by proxy whenever user is used. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"), months ("1mo"). + - team_id: Optional[str] - [DEPRECATED PARAM] The team id of the user. Default is None. + - duration: Optional[str] - Duration for the key auto-created on `/user/new`. Default is None. + - key_alias: Optional[str] - Alias for the key auto-created on `/user/new`. Default is None. + + Returns: + - key: (str) The generated api key for the user + - expires: (datetime) Datetime object for when key expires. + - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. + - max_budget: (float|None) Max budget for given user. 
+ + Usage Example + + ```shell + curl -X POST "http://localhost:4000/user/new" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{ + "username": "new_user", + "email": "new_user@example.com" + }' + ``` + """ + try: + from litellm.proxy.proxy_server import ( + general_settings, + litellm_proxy_admin_name, + prisma_client, + proxy_logging_obj, + ) + + # Check for duplicate email + await _check_duplicate_user_email(data.user_email, prisma_client) + + data_json = data.json() # type: ignore + data_json = _update_internal_new_user_params(data_json, data) + response = await generate_key_helper_fn(request_type="user", **data_json) + # Admin UI Logic + # Add User to Team and Organization + # if team_id passed add this user to the team + if data_json.get("team_id", None) is not None: + from litellm.proxy.management_endpoints.team_endpoints import ( + team_member_add, + ) + + try: + await team_member_add( + data=TeamMemberAddRequest( + team_id=data_json.get("team_id", None), + member=Member( + user_id=data_json.get("user_id", None), + role="user", + user_email=data_json.get("user_email", None), + ), + ), + http_request=Request( + scope={"type": "http", "path": "/user/new"}, + ), + user_api_key_dict=user_api_key_dict, + ) + except HTTPException as e: + if e.status_code == 400 and ( + "already exists" in str(e) or "doesn't exist" in str(e) + ): + verbose_proxy_logger.debug( + "litellm.proxy.management_endpoints.internal_user_endpoints.new_user(): User already exists in team - {}".format( + str(e) + ) + ) + else: + verbose_proxy_logger.debug( + "litellm.proxy.management_endpoints.internal_user_endpoints.new_user(): Exception occured - {}".format( + str(e) + ) + ) + except Exception as e: + if "already exists" in str(e) or "doesn't exist" in str(e): + verbose_proxy_logger.debug( + "litellm.proxy.management_endpoints.internal_user_endpoints.new_user(): User already exists in team - {}".format( + str(e) + ) + ) + else: + raise e + + if 
data.send_invite_email is True: + # check if user has setup email alerting + if "email" not in general_settings.get("alerting", []): + raise ValueError( + "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`" + ) + + event = WebhookEvent( + event="internal_user_created", + event_group="internal_user", + event_message="Welcome to LiteLLM Proxy", + token=response.get("token", ""), + spend=response.get("spend", 0.0), + max_budget=response.get("max_budget", 0.0), + user_id=response.get("user_id", None), + user_email=response.get("user_email", None), + team_id=response.get("team_id", "Default Team"), + key_alias=response.get("key_alias", None), + ) + + # If user configured email alerting - send an Email letting their end-user know the key was created + asyncio.create_task( + proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email( + webhook_event=event, + ) + ) + + try: + if prisma_client is None: + raise Exception(CommonProxyErrors.db_not_connected_error.value) + user_row: BaseModel = await prisma_client.db.litellm_usertable.find_first( + where={"user_id": response["user_id"]} + ) + + user_row_litellm_typed = LiteLLM_UserTable( + **user_row.model_dump(exclude_none=True) + ) + asyncio.create_task( + create_internal_user_audit_log( + user_id=user_row_litellm_typed.user_id, + action="created", + litellm_changed_by=user_api_key_dict.user_id, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + before_value=None, + after_value=user_row_litellm_typed.model_dump_json( + exclude_none=True + ), + ) + ) + except Exception as e: + verbose_proxy_logger.warning( + "Unable to create audit log for user on `/user/new` - {}".format(str(e)) + ) + + return NewUserResponse( + key=response.get("token", ""), + expires=response.get("expires", None), + max_budget=response["max_budget"], + user_id=response["user_id"], + user_role=response.get("user_role", 
def get_team_from_list(
    team_list: Optional[Union[List[LiteLLM_TeamTable], List[TeamListResponseObject]]],
    team_id: str,
) -> Optional[Union[LiteLLM_TeamTable, LiteLLM_TeamMembership]]:
    """
    Return the first entry of `team_list` whose `team_id` attribute equals `team_id`.

    Returns None when `team_list` is None or contains no matching entry.
    """
    if team_list is None:
        return None

    return next(
        (candidate for candidate in team_list if candidate.team_id == team_id),
        None,
    )
user_id: Optional[str] = fastapi.Query( + default=None, description="User ID in the request parameters" + ), + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [10/07/2024] + Note: To get all users (+pagination), use `/user/list` endpoint. + + + Use this to get user information. (user row + all user key info) + + Example request + ``` + curl -X GET 'http://localhost:4000/user/info?user_id=krrish7%40berri.ai' \ + --header 'Authorization: Bearer sk-1234' + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + try: + if prisma_client is None: + raise Exception( + "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + if ( + user_id is None + and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN + ): + return await _get_user_info_for_proxy_admin() + elif user_id is None: + user_id = user_api_key_dict.user_id + ## GET USER ROW ## + if user_id is not None: + user_info = await prisma_client.get_data(user_id=user_id) + else: + user_info = None + ## GET ALL TEAMS ## + team_list = [] + team_id_list = [] + from litellm.proxy.management_endpoints.team_endpoints import list_team + + teams_1 = await list_team( + http_request=Request( + scope={"type": "http", "path": "/user/info"}, + ), + user_id=user_id, + user_api_key_dict=user_api_key_dict, + ) + + if teams_1 is not None and isinstance(teams_1, list): + team_list = teams_1 + for team in teams_1: + team_id_list.append(team.team_id) + + teams_2: Optional[Any] = None + if user_info is not None: + # *NEW* get all teams in user 'teams' field + teams_2 = await prisma_client.get_data( + team_id_list=user_info.teams, table_name="team", query_type="find_all" + ) + + if teams_2 is not None and isinstance(teams_2, list): + for team in teams_2: + if team.team_id not in team_id_list: + team_list.append(team) + team_id_list.append(team.team_id) + + elif ( + user_api_key_dict.user_id is not None and user_id 
async def _get_user_info_for_proxy_admin():
    """
    Admin UI helper - returns all teams and keys for a Proxy Admin calling `/user/info`
    without a `user_id`.

    A single raw SQL query loads every row of LiteLLM_TeamTable together with every
    LiteLLM_VerificationToken row whose team_id is NULL or differs from 'litellm-dashboard',
    so the UI gets teams and virtual keys in one query.

    Returns:
        UserInfoResponse with `user_id` and `user_info` set to None, the processed keys, and
        the teams sorted by team alias.

    Raises:
        Exception: If the database is not connected.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise Exception(
            "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
        )

    sql_query = """
        SELECT
            (SELECT json_agg(t.*) FROM "LiteLLM_TeamTable" t) as teams,
            (SELECT json_agg(k.*) FROM "LiteLLM_VerificationToken" k WHERE k.team_id != 'litellm-dashboard' OR k.team_id IS NULL) as keys
    """
    results = await prisma_client.db.query_raw(sql_query)

    verbose_proxy_logger.debug("results_keys: %s", results)

    row = results[0]

    # Either aggregate may come back empty/None; treat that as "no rows".
    key_rows: List = row["keys"] or []
    for key_row in key_rows:
        # normalise a null `models` to an empty list before casting
        if key_row.get("models") is None:
            key_row["models"] = []
    keys_in_db = [LiteLLM_VerificationToken(**key_row) for key_row in key_rows]

    team_rows: List = row["teams"] or []
    teams_in_db = sorted(
        (LiteLLM_TeamTable(**team_row) for team_row in team_rows),
        key=lambda team: getattr(team, "team_alias", "") or "",
    )

    return UserInfoResponse(
        user_id=None,
        user_info=None,
        keys=_process_keys_for_user_info(keys=keys_in_db, all_teams=teams_in_db),
        teams=teams_in_db,
    )
def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> dict:
    """
    Build the dict of column values to write for a `/user/update` request.

    Only fields the caller actually set are kept: None, [], {} and 0 are treated as
    "not provided" (models default to [], spend defaults to 0 - those must not be reset),
    and fields listed in LiteLLM_ManagementEndpoint_MetadataFields are left out.

    When the request sets the role to INTERNAL_USER, the proxy-wide
    `litellm.max_internal_user_budget` / `litellm.internal_user_budget_duration` are applied
    for whichever of `max_budget` / `budget_duration` the request did not set. Whenever a
    budget duration ends up being applied, `budget_reset_at` is set to now (UTC) + duration.
    """
    # NOTE: `False == 0` in Python, so boolean False values are skipped here as well.
    unset_markers = ([], {}, 0)
    non_default_values = {
        field: value
        for field, value in data_json.items()
        if value is not None
        and value not in unset_markers
        and field not in LiteLLM_ManagementEndpoint_MetadataFields
    }

    is_internal_user = data.user_role == LitellmUserRoles.INTERNAL_USER

    def _schedule_budget_reset() -> None:
        # reset time = now + the budget duration currently stored in non_default_values
        seconds = duration_in_seconds(duration=non_default_values["budget_duration"])
        non_default_values["budget_reset_at"] = datetime.now(timezone.utc) + timedelta(
            seconds=seconds
        )

    if "budget_duration" in non_default_values:
        _schedule_budget_reset()

    # applies internal user limits, if user role updated
    if (
        "max_budget" not in non_default_values
        and is_internal_user
        and litellm.max_internal_user_budget is not None
    ):
        non_default_values["max_budget"] = litellm.max_internal_user_budget

    # applies internal user limits, if user role updated
    if (
        "budget_duration" not in non_default_values
        and is_internal_user
        and litellm.internal_user_budget_duration is not None
    ):
        non_default_values["budget_duration"] = litellm.internal_user_budget_duration
        _schedule_budget_reset()

    return non_default_values
+@router.post( + "/user/update", + tags=["Internal User management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def user_update( + data: UpdateUserRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Example curl + + ``` + curl --location 'http://0.0.0.0:4000/user/update' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "user_id": "test-litellm-user-4", + "user_role": "proxy_admin_viewer" + }' + ``` + + Parameters: + - user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated. + - user_email: Optional[str] - Specify a user email. + - password: Optional[str] - Specify a user password. + - user_alias: Optional[str] - A descriptive name for you to know who this user id refers to. + - teams: Optional[list] - specify a list of team id's a user belongs to. + - send_invite_email: Optional[bool] - Specify if an invite email should be sent. + - user_role: Optional[str] - Specify a user role - "proxy_admin", "proxy_admin_viewer", "internal_user", "internal_user_viewer", "team", "customer". Info about each role here: `https://github.com/BerriAI/litellm/litellm/proxy/_types.py#L20` + - max_budget: Optional[float] - Specify max budget for a given user. + - budget_duration: Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"), months ("1mo"). + - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models) + - tpm_limit: Optional[int] - Specify tpm limit for a given user (Tokens per minute) + - rpm_limit: Optional[int] - Specify rpm limit for a given user (Requests per minute) + - auto_create_key: bool - Default=True. 
Flag used for returning a key as part of the /user/new response + - aliases: Optional[dict] - Model aliases for the user - [Docs](https://litellm.vercel.app/docs/proxy/virtual_keys#model-aliases) + - config: Optional[dict] - [DEPRECATED PARAM] User-specific config. + - allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request- + - blocked: Optional[bool] - [Not Implemented Yet] Whether the user is blocked. + - guardrails: Optional[List[str]] - [Not Implemented Yet] List of active guardrails for the user + - permissions: Optional[dict] - [Not Implemented Yet] User-specific permissions, eg. turning off pii masking. + - metadata: Optional[dict] - Metadata for user, store information for user. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } + - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x. + - soft_budget: Optional[float] - Get alerts when user crosses given budget, doesn't block requests. + - model_max_budget: Optional[dict] - Model-specific max budget for user. [Docs](https://docs.litellm.ai/docs/proxy/users#add-model-specific-budgets-to-keys) + - model_rpm_limit: Optional[float] - Model-specific rpm limit for user. [Docs](https://docs.litellm.ai/docs/proxy/users#add-model-specific-limits-to-keys) + - model_tpm_limit: Optional[float] - Model-specific tpm limit for user. [Docs](https://docs.litellm.ai/docs/proxy/users#add-model-specific-limits-to-keys) + - spend: Optional[float] - Amount spent by user. Default is 0. Will be updated by proxy whenever user is used. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"), months ("1mo"). + - team_id: Optional[str] - [DEPRECATED PARAM] The team id of the user. Default is None. 
+ - duration: Optional[str] - [NOT IMPLEMENTED]. + - key_alias: Optional[str] - [NOT IMPLEMENTED]. + + + """ + from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client + + try: + data_json: dict = data.json() + # get the row from db + if prisma_client is None: + raise Exception("Not connected to DB!") + + # get non default values for key + non_default_values = _update_internal_user_params( + data_json=data_json, data=data + ) + + existing_user_row: Optional[BaseModel] = None + if data.user_id is not None: + existing_user_row = await prisma_client.db.litellm_usertable.find_first( + where={"user_id": data.user_id} + ) + if existing_user_row is not None: + existing_user_row = LiteLLM_UserTable( + **existing_user_row.model_dump(exclude_none=True) + ) + + existing_metadata = ( + cast(Dict, getattr(existing_user_row, "metadata", {}) or {}) + if existing_user_row is not None + else {} + ) + + non_default_values = prepare_metadata_fields( + data=data, + non_default_values=non_default_values, + existing_metadata=existing_metadata or {}, + ) + + ## ADD USER, IF NEW ## + verbose_proxy_logger.debug("/user/update: Received data = %s", data) + response: Optional[Any] = None + if data.user_id is not None and len(data.user_id) > 0: + non_default_values["user_id"] = data.user_id # type: ignore + verbose_proxy_logger.debug("In update user, user_id condition block.") + response = await prisma_client.update_data( + user_id=data.user_id, + data=non_default_values, + table_name="user", + ) + verbose_proxy_logger.debug( + f"received response from updating prisma client. response={response}" + ) + elif data.user_email is not None: + non_default_values["user_id"] = str(uuid.uuid4()) + non_default_values["user_email"] = data.user_email + ## user email is not unique acc. 
to prisma schema -> future improvement + ### for now: check if it exists in db, if not - insert it + existing_user_rows = await prisma_client.get_data( + key_val={"user_email": data.user_email}, + table_name="user", + query_type="find_all", + ) + if existing_user_rows is None or ( + isinstance(existing_user_rows, list) and len(existing_user_rows) == 0 + ): + response = await prisma_client.insert_data( + data=non_default_values, table_name="user" + ) + elif isinstance(existing_user_rows, list) and len(existing_user_rows) > 0: + for existing_user in existing_user_rows: + response = await prisma_client.update_data( + user_id=existing_user.user_id, + data=non_default_values, + table_name="user", + ) + + if response is not None: # emit audit log + try: + user_row: BaseModel = ( + await prisma_client.db.litellm_usertable.find_first( + where={"user_id": response["user_id"]} + ) + ) + + user_row_litellm_typed = LiteLLM_UserTable( + **user_row.model_dump(exclude_none=True) + ) + asyncio.create_task( + create_internal_user_audit_log( + user_id=user_row_litellm_typed.user_id, + action="updated", + litellm_changed_by=user_api_key_dict.user_id, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + before_value=( + existing_user_row.model_dump_json(exclude_none=True) + if existing_user_row + else None + ), + after_value=user_row_litellm_typed.model_dump_json( + exclude_none=True + ), + ) + ) + except Exception as e: + verbose_proxy_logger.warning( + "Unable to create audit log for user on `/user/update` - {}".format( + str(e) + ) + ) + return response # type: ignore + # update based on remaining passed in values + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.user_update(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + 
async def get_user_key_counts(
    prisma_client,
    user_ids: Optional[List[str]] = None,
):
    """
    Helper function to get the count of keys for each user using Prisma's count method.

    A key is counted when its `team_id` is null or differs from UI_SESSION_TOKEN_TEAM_ID.

    Args:
        prisma_client: The Prisma client instance
        user_ids: List of user IDs to get key counts for

    Returns:
        Dictionary mapping user_id to key count. Empty when `user_ids` is None or empty.
    """
    if not user_ids:
        return {}

    # Imported only once there is something to count.
    from litellm.constants import UI_SESSION_TOKEN_TEAM_ID

    result = {}

    # One count query per distinct user id; dict.fromkeys drops repeated ids while keeping
    # first-seen order, so a duplicated id no longer costs an extra identical query.
    for user_id in dict.fromkeys(user_ids):
        result[user_id] = await prisma_client.db.litellm_verificationtoken.count(
            where={
                "user_id": user_id,
                "OR": [
                    {"team_id": None},
                    {"team_id": {"not": UI_SESSION_TOKEN_TEAM_ID}},
                ],
            }
        )

    return result
Can be one of: + - proxy_admin + - proxy_admin_viewer + - internal_user + - internal_user_viewer + user_ids: Optional[str] + Get list of users by user_ids. Comma separated list of user_ids. + page: int + The page number to return + page_size: int + The number of items per page + + Currently - admin-only endpoint. + + Example curl: + ``` + http://0.0.0.0:4000/user/list?user_ids=default_user_id,693c1a4a-1cc0-4c7c-afe8-b5d2c8d52e17 + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": f"No db connected. prisma client={prisma_client}"}, + ) + + # Calculate skip and take for pagination + skip = (page - 1) * page_size + + # Prepare the query conditions + # Build where conditions based on provided parameters + where_conditions: Dict[str, Any] = {} + + if role: + where_conditions["user_role"] = { + "contains": role, + "mode": "insensitive", # Case-insensitive search + } + + if user_ids and isinstance(user_ids, str): + user_id_list = [uid.strip() for uid in user_ids.split(",") if uid.strip()] + where_conditions["user_id"] = { + "in": user_id_list, # Now passing a list of strings as required by Prisma + } + + users: Optional[List[LiteLLM_UserTable]] = ( + await prisma_client.db.litellm_usertable.find_many( + where=where_conditions, + skip=skip, + take=page_size, + order={"created_at": "desc"}, + ) + ) + + # Get total count of user rows + total_count = await prisma_client.db.litellm_usertable.count( + where=where_conditions # type: ignore + ) + + # Get key count for each user + if users is not None: + user_key_counts = await get_user_key_counts( + prisma_client, [user.user_id for user in users] + ) + else: + user_key_counts = {} + + verbose_proxy_logger.debug(f"Total count of users: {total_count}") + + # Calculate total pages + total_pages = -(-total_count // page_size) # Ceiling division + + # Prepare response + user_list: List[LiteLLM_UserTableWithKeyCount] = [] + if 
@router.post(
    "/user/delete",
    tags=["Internal User management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def delete_user(
    data: DeleteUserRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    delete user and associated user keys

    ```
    curl --location 'http://0.0.0.0:4000/user/delete' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data-raw '{
        "user_ids": ["45e3e396-ee08-4a61-a88e-16b3ce7e0849"]
    }'
    ```

    Parameters:
    - user_ids: List[str] - The list of user id's to be deleted.

    Also deletes the keys, invitation links and organization memberships of those users.

    Returns the result of the user-table `delete_many` call.

    Raises:
    - HTTPException 500 if no database is connected
    - HTTPException 400 if `user_ids` is None
    - HTTPException 404 if any passed user id does not exist (nothing is deleted or logged)
    """
    from litellm.proxy.proxy_server import (
        create_audit_log_for_update,
        litellm_proxy_admin_name,
        prisma_client,
    )

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    if data.user_ids is None:
        raise HTTPException(status_code=400, detail={"error": "No user id passed in"})

    # check that all users passed exist - validation only, no side effects yet
    validated_rows = []
    for user_id in data.user_ids:
        user_row = await prisma_client.db.litellm_usertable.find_unique(
            where={"user_id": user_id}
        )

        if user_row is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"User not found, passed user_id={user_id}"},
            )
        validated_rows.append((user_id, user_row))

    # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
    # Done in a second pass so that a request rejected with 404 never leaves behind
    # "deleted" audit entries for the ids that happened to be checked first.
    if litellm.store_audit_logs is True:
        # make an audit log for each user deleted
        for user_id, user_row in validated_rows:
            _user_row = user_row.json(exclude_none=True)

            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.api_key,
                        table_name=LitellmTableNames.USER_TABLE_NAME,
                        object_id=user_id,
                        action="deleted",
                        updated_values="{}",
                        before_value=_user_row,
                    )
                )
            )

    # End of Audit logging

    ## DELETE ASSOCIATED KEYS
    await prisma_client.db.litellm_verificationtoken.delete_many(
        where={"user_id": {"in": data.user_ids}}
    )

    ## DELETE ASSOCIATED INVITATION LINKS
    await prisma_client.db.litellm_invitationlink.delete_many(
        where={"user_id": {"in": data.user_ids}}
    )

    ## DELETE ASSOCIATED ORGANIZATION MEMBERSHIPS
    await prisma_client.db.litellm_organizationmembership.delete_many(
        where={"user_id": {"in": data.user_ids}}
    )

    ## DELETE USERS
    deleted_users = await prisma_client.db.litellm_usertable.delete_many(
        where={"user_id": {"in": data.user_ids}}
    )

    return deleted_users
await prisma_client.db.litellm_organizationtable.find_unique( + where={"organization_id": organization_id} + ) + if organization_row is None: + raise Exception( + f"Organization not found, passed organization_id={organization_id}" + ) + + # Create a new organization membership entry + new_membership = await prisma_client.db.litellm_organizationmembership.create( + data={ + "user_id": user_id, + "organization_id": organization_id, + "user_role": user_role, + # Note: You can also set budget within an organization if needed + } + ) + + return new_membership + except Exception as e: + raise Exception(f"Failed to add user to organization: {str(e)}") + + +@router.get( + "/user/filter/ui", + tags=["Internal User management"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, + responses={ + 200: {"model": List[LiteLLM_UserTableFiltered]}, + }, +) +async def ui_view_users( + user_id: Optional[str] = fastapi.Query( + default=None, description="User ID in the request parameters" + ), + user_email: Optional[str] = fastapi.Query( + default=None, description="User email in the request parameters" + ), + page: int = fastapi.Query( + default=1, description="Page number for pagination", ge=1 + ), + page_size: int = fastapi.Query( + default=50, description="Number of items per page", ge=1, le=100 + ), + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [PROXY-ADMIN ONLY]Filter users based on partial match of user_id or email with pagination. 
+ + Args: + user_id (Optional[str]): Partial user ID to search for + user_email (Optional[str]): Partial email to search for + page (int): Page number for pagination (starts at 1) + page_size (int): Number of items per page (max 100) + user_api_key_dict (UserAPIKeyAuth): User authentication information + + Returns: + List[LiteLLM_SpendLogs]: Paginated list of matching user records + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + try: + # Calculate offset for pagination + skip = (page - 1) * page_size + + # Build where conditions based on provided parameters + where_conditions = {} + + if user_id: + where_conditions["user_id"] = { + "contains": user_id, + "mode": "insensitive", # Case-insensitive search + } + + if user_email: + where_conditions["user_email"] = { + "contains": user_email, + "mode": "insensitive", # Case-insensitive search + } + + # Query users with pagination and filters + users: Optional[List[BaseModel]] = ( + await prisma_client.db.litellm_usertable.find_many( + where=where_conditions, + skip=skip, + take=page_size, + order={"created_at": "desc"}, + ) + ) + + if not users: + return [] + + return [LiteLLM_UserTableFiltered(**user.model_dump()) for user in users] + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error searching users: {str(e)}") diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/key_management_endpoints.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/key_management_endpoints.py new file mode 100644 index 00000000..9141d9d1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -0,0 +1,2621 @@ +""" +KEY MANAGEMENT + +All /key management endpoints + +/key/generate +/key/info +/key/update +/key/delete +""" + +import asyncio +import copy +import json +import secrets 
import traceback
import uuid
from datetime import datetime, timedelta, timezone
from typing import List, Literal, Optional, Tuple, cast

import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.constants import UI_SESSION_TOKEN_TEAM_ID
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
    _cache_key_object,
    _delete_cache_key_object,
    get_key_object,
    get_team_object,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_endpoints.common_utils import (
    _is_user_team_admin,
    _set_object_metadata_field,
)
from litellm.proxy.management_endpoints.model_management_endpoints import (
    _add_model_to_db,
)
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.spend_tracking.spend_tracking_utils import _is_master_key
from litellm.proxy.utils import (
    PrismaClient,
    _hash_token_if_needed,
    handle_exception_on_proxy,
    jsonify_object,
)
from litellm.router import Router
from litellm.secret_managers.main import get_secret
from litellm.types.router import Deployment
from litellm.types.utils import (
    BudgetConfig,
    PersonalUIKeyGenerationConfig,
    TeamUIKeyGenerationConfig,
)


def _is_team_key(data: Union[GenerateKeyRequest, LiteLLM_VerificationToken]):
    """Return True when `data` carries a `team_id` (i.e. the key is a team key)."""
    return data.team_id is not None


def _get_user_in_team(
    team_table: LiteLLM_TeamTableCachedObj, user_id: Optional[str]
) -> Optional[Member]:
    """
    Return the entry of `team_table.members_with_roles` whose `user_id` matches.

    Returns None when `user_id` is None or no member matches.
    """
    if user_id is None:
        return None
    for member in team_table.members_with_roles:
        if member.user_id is not None and member.user_id == user_id:
            return member

    return None


def _is_allowed_to_make_key_request(
    user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], team_id: Optional[str]
) -> bool:
    """
    Check that the caller only creates keys for themselves / their own team.

    Relevant issue: https://github.com/BerriAI/litellm/issues/7336

    Returns:
    - True when the request is allowed.

    Raises:
    - AssertionError when `user_id` or `team_id` does not belong to the caller.
      The error is raised explicitly (not via `assert`) so the check is still
      enforced when Python runs with `-O`, which strips assert statements.
    """
    ## BASE CASE - PROXY ADMIN
    if (
        user_api_key_dict.user_role is not None
        and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
    ):
        return True

    if user_id is not None and user_id != user_api_key_dict.user_id:
        raise AssertionError(
            "User can only create keys for themselves. Got user_id={}, Your ID={}".format(
                user_id, user_api_key_dict.user_id
            )
        )

    if team_id is not None:
        if (
            user_api_key_dict.team_id is not None
            and user_api_key_dict.team_id == UI_TEAM_ID
        ):
            return True  # handle https://github.com/BerriAI/litellm/issues/7482
        if user_api_key_dict.team_id != team_id:
            raise AssertionError(
                "User can only create keys for their own team. Got={}, Your Team ID={}".format(
                    team_id, user_api_key_dict.team_id
                )
            )

    return True


def _team_key_generation_team_member_check(
    assigned_user_id: Optional[str],
    team_table: LiteLLM_TeamTableCachedObj,
    user_api_key_dict: UserAPIKeyAuth,
    team_key_generation: TeamUIKeyGenerationConfig,
):
    """
    Validate team membership for a team key request.

    - The user the key is assigned to (if any) must be a member of the team.
    - The caller must be a proxy admin, or a team member whose role is listed in
      `team_key_generation["allowed_team_member_roles"]` (when that is configured).

    Raises:
    - HTTPException(400) when any of the checks above fails.
    """
    if assigned_user_id is not None:
        key_assigned_user_in_team = _get_user_in_team(
            team_table=team_table, user_id=assigned_user_id
        )

        if key_assigned_user_in_team is None:
            raise HTTPException(
                status_code=400,
                detail=f"User={assigned_user_id} not assigned to team={team_table.team_id}",
            )

    key_creating_user_in_team = _get_user_in_team(
        team_table=team_table, user_id=user_api_key_dict.user_id
    )

    is_admin = (
        user_api_key_dict.user_role is not None
        and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
    )

    if is_admin:
        return True
    elif key_creating_user_in_team is None:
        raise HTTPException(
            status_code=400,
            detail=f"User={user_api_key_dict.user_id} not assigned to team={team_table.team_id}",
        )
    elif (
        "allowed_team_member_roles" in team_key_generation
        and key_creating_user_in_team.role
        not in team_key_generation["allowed_team_member_roles"]
    ):
        raise HTTPException(
            status_code=400,
            detail=f"Team member role {key_creating_user_in_team.role} not in allowed_team_member_roles={team_key_generation['allowed_team_member_roles']}",
        )
    return True


def _key_generation_required_param_check(
    data: GenerateKeyRequest, required_params: Optional[List[str]]
):
    """
    Ensure every name in `required_params` was explicitly set on the request.

    Only fields the client actually sent count (`exclude_unset=True`).

    Raises:
    - HTTPException(400) for the first required param missing from the request.
    """
    if required_params is None:
        return True

    data_dict = data.model_dump(exclude_unset=True)
    for param in required_params:
        if param not in data_dict:
            raise HTTPException(
                status_code=400,
                detail=f"Required param {param} not in data",
            )
    return True


def _team_key_generation_check(
    team_table: LiteLLM_TeamTableCachedObj,
    user_api_key_dict: UserAPIKeyAuth,
    data: GenerateKeyRequest,
):
    """
    Apply the `team_key_generation` settings to a team key request.

    Proxy admins bypass the check. When `litellm.key_generation_settings` has no
    `team_key_generation` entry, team members with role "admin" or "user" are allowed.
    """
    if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
        return True
    if (
        litellm.key_generation_settings is not None
        and "team_key_generation" in litellm.key_generation_settings
    ):
        _team_key_generation = litellm.key_generation_settings["team_key_generation"]
    else:
        _team_key_generation = TeamUIKeyGenerationConfig(
            allowed_team_member_roles=["admin", "user"],
        )

    _team_key_generation_team_member_check(
        assigned_user_id=data.user_id,
        team_table=team_table,
        user_api_key_dict=user_api_key_dict,
        team_key_generation=_team_key_generation,
    )
    _key_generation_required_param_check(
        data,
        _team_key_generation.get("required_params"),
    )

    return True


def _personal_key_membership_check(
    user_api_key_dict: UserAPIKeyAuth,
    personal_key_generation: Optional[PersonalUIKeyGenerationConfig],
):
    """
    Check the caller's role against `personal_key_generation["allowed_user_roles"]`.

    Returns True when no restriction is configured or the role is allowed.

    Raises:
    - HTTPException(400) when the caller's role is not in the allowed list.
    """
    if (
        personal_key_generation is None
        or "allowed_user_roles" not in personal_key_generation
    ):
        return True

    allowed_user_roles = personal_key_generation["allowed_user_roles"]
    if user_api_key_dict.user_role not in allowed_user_roles:
        # Report the roles from the config that was actually checked, rather than
        # re-reading the global settings.
        raise HTTPException(
            status_code=400,
            detail=f"Personal key creation has been restricted by admin. Allowed roles={allowed_user_roles}. Your role={user_api_key_dict.user_role}",
        )

    return True


def _personal_key_generation_check(
    user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
):
    """
    Apply the `personal_key_generation` settings to a personal (non-team) key request.

    Returns True immediately when no such settings are configured.
    """

    if (
        litellm.key_generation_settings is None
        or litellm.key_generation_settings.get("personal_key_generation") is None
    ):
        return True

    _personal_key_generation = litellm.key_generation_settings["personal_key_generation"]  # type: ignore

    _personal_key_membership_check(
        user_api_key_dict,
        personal_key_generation=_personal_key_generation,
    )

    _key_generation_required_param_check(
        data,
        _personal_key_generation.get("required_params"),
    )

    return True


def key_generation_check(
    team_table: Optional[LiteLLM_TeamTableCachedObj],
    user_api_key_dict: UserAPIKeyAuth,
    data: GenerateKeyRequest,
) -> bool:
    """
    Check if admin has restricted key creation to certain roles for teams or individuals

    Raises:
    - HTTPException(400) when a team key is requested, key generation settings are
      configured, and the team could not be loaded; or when a team / personal check fails.
    """

    ## check if key is for team or individual
    is_team_key = _is_team_key(data=data)
    if is_team_key:
        if team_table is None and litellm.key_generation_settings is not None:
            raise HTTPException(
                status_code=400,
                detail=f"Unable to find team object in database. Team ID: {data.team_id}",
            )
        elif team_table is None:
            return True  # assume user is assigning team_id without using the team table
        return _team_key_generation_check(
            team_table=team_table,
            user_api_key_dict=user_api_key_dict,
            data=data,
        )
    else:
        return _personal_key_generation_check(
            user_api_key_dict=user_api_key_dict, data=data
        )


def common_key_access_checks(
    user_api_key_dict: UserAPIKeyAuth,
    data: Union[GenerateKeyRequest, UpdateKeyRequest],
    llm_router: Optional[Router],
    premium_user: bool,
) -> Literal[True]:
    """
    Check if user is allowed to make a key request, for this key

    Raises:
    - HTTPException(403) when the ownership check fails (AssertionError).
    - HTTPException(500) when the ownership check raises any other exception.
    """
    try:
        _is_allowed_to_make_key_request(
            user_api_key_dict=user_api_key_dict,
            user_id=data.user_id,
            team_id=data.team_id,
        )
    except AssertionError as e:
        raise HTTPException(
            status_code=403,
            detail=str(e),
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=str(e),
        )

    _check_model_access_group(
        models=data.models,
        llm_router=llm_router,
        premium_user=premium_user,
    )
    return True


router = APIRouter()


@router.post(
    "/key/generate",
    tags=["key management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=GenerateKeyResponse,
)
@management_endpoint_wrapper
async def generate_key_fn(  # noqa: PLR0915
    data: GenerateKeyRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Generate an API key based on the provided data.

    Docs: https://docs.litellm.ai/docs/proxy/virtual_keys

    Parameters:
    - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
    - key_alias: Optional[str] - User defined key alias
    - key: Optional[str] - User defined key value. If not set, a 16-digit unique sk-key is created for you.
    - team_id: Optional[str] - The team id of the key
    - user_id: Optional[str] - The user id of the key
    - budget_id: Optional[str] - The budget id associated with the key. Created by calling `/budget/new`.
    - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
    - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
    - config: Optional[dict] - any key-specific configs, overrides config in config.yaml
    - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
    - send_invite_email: Optional[bool] - Whether to send an invite email to the user_id, with the generate key
    - max_budget: Optional[float] - Specify max budget for a given key.
    - budget_duration: Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
    - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
    - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
    - guardrails: Optional[List[str]] - List of active guardrails for the key
    - permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false}
    - model_max_budget: Optional[Dict[str, BudgetConfig]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}}. IF null or {} then no model specific budget.
    - model_rpm_limit: Optional[dict] - key-specific model rpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific rpm limit.
    - model_tpm_limit: Optional[dict] - key-specific model tpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific tpm limit.
    - allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request
    - blocked: Optional[bool] - Whether the key is blocked.
    - rpm_limit: Optional[int] - Specify rpm limit for a given key (Requests per minute)
    - tpm_limit: Optional[int] - Specify tpm limit for a given key (Tokens per minute)
    - soft_budget: Optional[float] - Specify soft budget for a given key. Will trigger a slack alert when this soft budget is reached.
    - tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing).
    - enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)

    Examples:

    1. Allow users to turn on/off pii masking

    ```bash
    curl --location 'http://0.0.0.0:4000/key/generate' \
        --header 'Authorization: Bearer sk-1234' \
        --header 'Content-Type: application/json' \
        --data '{
            "permissions": {"allow_pii_controls": true}
    }'
    ```

    Returns:
    - key: (str) The generated api key
    - expires: (datetime) Datetime object for when key expires.
    - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
    """
    try:
        from litellm.proxy.proxy_server import (
            litellm_proxy_admin_name,
            llm_router,
            premium_user,
            prisma_client,
            user_api_key_cache,
            user_custom_key_generate,
        )

        verbose_proxy_logger.debug("entered /key/generate")

        # Optional user-supplied hook that can veto key generation
        if user_custom_key_generate is not None:
            if asyncio.iscoroutinefunction(user_custom_key_generate):
                result = await user_custom_key_generate(data)  # type: ignore
            else:
                raise ValueError("user_custom_key_generate must be a coroutine")
            decision = result.get("decision", True)
            message = result.get("message", "Authentication Failed - Custom Auth Rule")
            if not decision:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN, detail=message
                )
        team_table: Optional[LiteLLM_TeamTableCachedObj] = None
        if data.team_id is not None:
            try:
                team_table = await get_team_object(
                    team_id=data.team_id,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    parent_otel_span=user_api_key_dict.parent_otel_span,
                    check_db_only=True,
                )
            except Exception as e:
                # Best-effort lookup: a missing team is handled by key_generation_check
                verbose_proxy_logger.debug(
                    f"Error getting team object in `/key/generate`: {e}"
                )
                team_table = None

        key_generation_check(
            team_table=team_table,
            user_api_key_dict=user_api_key_dict,
            data=data,
        )

        common_key_access_checks(
            user_api_key_dict=user_api_key_dict,
            data=data,
            llm_router=llm_router,
            premium_user=premium_user,
        )

        # check if user set default key/generate params on config.yaml
        if litellm.default_key_generate_params is not None:
            for elem in data:
                key, value = elem
                if value is None and key in [
                    "max_budget",
                    "user_id",
                    "team_id",
                    "max_parallel_requests",
                    "tpm_limit",
                    "rpm_limit",
                    "budget_duration",
                ]:
                    setattr(
                        data, key, litellm.default_key_generate_params.get(key, None)
                    )
                elif key == "models" and value == []:
                    setattr(data, key, litellm.default_key_generate_params.get(key, []))
                elif key == "metadata" and value == {}:
                    setattr(data, key, litellm.default_key_generate_params.get(key, {}))

        # enforce the upper bounds for key/generate params, if configured
        if litellm.upperbound_key_generate_params is not None:
            for elem in data:
                key, value = elem
                upperbound_value = getattr(
                    litellm.upperbound_key_generate_params, key, None
                )
                if upperbound_value is not None:
                    if value is None:
                        # Use the upperbound value if user didn't provide a value
                        setattr(data, key, upperbound_value)
                    else:
                        # Compare with upperbound for numeric fields
                        if key in [
                            "max_budget",
                            "max_parallel_requests",
                            "tpm_limit",
                            "rpm_limit",
                        ]:
                            if value > upperbound_value:
                                raise HTTPException(
                                    status_code=400,
                                    detail={
                                        "error": f"{key} is over max limit set in config - user_value={value}; max_value={upperbound_value}"
                                    },
                                )
                        # Compare durations
                        elif key in ["budget_duration", "duration"]:
                            upperbound_duration = duration_in_seconds(
                                duration=upperbound_value
                            )
                            user_duration = duration_in_seconds(duration=value)
                            if user_duration > upperbound_duration:
                                raise HTTPException(
                                    status_code=400,
                                    detail={
                                        "error": f"{key} is over max limit set in config - user_value={value}; max_value={upperbound_value}"
                                    },
                                )

        # TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable
        _budget_id = data.budget_id
        if prisma_client is not None and data.soft_budget is not None:
            # create the Budget Row for the LiteLLM Verification Token
            budget_row = LiteLLM_BudgetTable(
                soft_budget=data.soft_budget,
                model_max_budget=data.model_max_budget or {},
            )
            new_budget = prisma_client.jsonify_object(
                budget_row.json(exclude_none=True)
            )

            _budget = await prisma_client.db.litellm_budgettable.create(
                data={
                    **new_budget,  # type: ignore
                    "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
                    "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
                }
            )
            _budget_id = getattr(_budget, "budget_id", None)

        # ADD METADATA FIELDS
        # Set Management Endpoint Metadata Fields
        for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
            if getattr(data, field) is not None:
                _set_object_metadata_field(
                    object_data=data,
                    field_name=field,
                    value=getattr(data, field),
                )

        data_json = data.model_dump(exclude_unset=True, exclude_none=True)  # type: ignore

        # if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
        if "max_budget" in data_json:
            data_json["key_max_budget"] = data_json.pop("max_budget", None)
        if _budget_id is not None:
            data_json["budget_id"] = _budget_id

        if "budget_duration" in data_json:
            data_json["key_budget_duration"] = data_json.pop("budget_duration", None)

        if user_api_key_dict.user_id is not None:
            data_json["created_by"] = user_api_key_dict.user_id
            data_json["updated_by"] = user_api_key_dict.user_id

        # Set tags on the new key
        if "tags" in data_json:
            from litellm.proxy.proxy_server import premium_user

            if premium_user is not True and data_json["tags"] is not None:
                raise ValueError(
                    f"Only premium users can add tags to keys. {CommonProxyErrors.not_premium_user.value}"
                )

            # tags are stored under metadata["tags"], not as a top-level field
            _metadata = data_json.get("metadata")
            if not _metadata:
                data_json["metadata"] = {"tags": data_json["tags"]}
            else:
                data_json["metadata"]["tags"] = data_json["tags"]

            data_json.pop("tags")

        await _enforce_unique_key_alias(
            key_alias=data_json.get("key_alias", None),
            prisma_client=prisma_client,
        )

        response = await generate_key_helper_fn(
            request_type="key", **data_json, table_name="key"
        )

        response["soft_budget"] = (
            data.soft_budget
        )  # include the user-input soft budget in the response

        response = GenerateKeyResponse(**response)

        # Fire-and-forget: the hook runs in the background, the response is not delayed
        asyncio.create_task(
            KeyManagementEventHooks.async_key_generated_hook(
                data=data,
                response=response,
                user_api_key_dict=user_api_key_dict,
                litellm_changed_by=litellm_changed_by,
            )
        )

        return response
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.generate_key_fn(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)


def prepare_metadata_fields(
    data: BaseModel, non_default_values: dict, existing_metadata: dict
) -> dict:
    """
    Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated

    Copies every explicitly-set, non-None field of `data` that is listed in
    LiteLLM_ManagementEndpoint_MetadataFields into `non_default_values["metadata"]`
    (datetimes are stored as ISO-8601 strings). When the request did not include
    `metadata`, a copy of `existing_metadata` is used as the starting point.

    Returns:
    - `non_default_values`, mutated in place.
    """
    if "metadata" not in non_default_values:  # allow user to set metadata to none
        non_default_values["metadata"] = existing_metadata.copy()

    casted_metadata = cast(dict, non_default_values["metadata"])

    data_json = data.model_dump(exclude_unset=True, exclude_none=True)

    try:
        for k, v in data_json.items():
            if k in LiteLLM_ManagementEndpoint_MetadataFields:
                if isinstance(v, datetime):
                    casted_metadata[k] = v.isoformat()
                else:
                    casted_metadata[k] = v

    except Exception as e:
        # NOTE(review): errors here are logged and otherwise ignored. If `metadata`
        # was explicitly set to None, the assignment above fails and the metadata
        # field is dropped - confirm this is intended.
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.prepare_metadata_fields(): Exception occured - {}".format(
                str(e)
            )
        )

    non_default_values["metadata"] = casted_metadata
    return non_default_values


+def prepare_key_update_data( + data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row +): + data_json: dict = data.model_dump(exclude_unset=True) + data_json.pop("key", None) + non_default_values = {} + for k, v in data_json.items(): + if k in LiteLLM_ManagementEndpoint_MetadataFields: + continue + non_default_values[k] = v + + if "duration" in non_default_values: + duration = non_default_values.pop("duration") + if duration and (isinstance(duration, str)) and len(duration) > 0: + duration_s = duration_in_seconds(duration=duration) + expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + non_default_values["expires"] = expires + + if "budget_duration" in non_default_values: + budget_duration = non_default_values.pop("budget_duration") + if ( + budget_duration + and (isinstance(budget_duration, str)) + and len(budget_duration) > 0 + ): + duration_s = duration_in_seconds(duration=budget_duration) + key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + non_default_values["budget_reset_at"] = key_reset_at + non_default_values["budget_duration"] = budget_duration + + _metadata = existing_key_row.metadata or {} + + # validate model_max_budget + if "model_max_budget" in non_default_values: + validate_model_max_budget(non_default_values["model_max_budget"]) + + non_default_values = prepare_metadata_fields( + data=data, non_default_values=non_default_values, existing_metadata=_metadata + ) + + return non_default_values + + +@router.post( + "/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def update_key_fn( + request: Request, + data: UpdateKeyRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): 
+ """ + Update an existing API key's parameters. + + Parameters: + - key: str - The key to update + - key_alias: Optional[str] - User-friendly key alias + - user_id: Optional[str] - User ID associated with key + - team_id: Optional[str] - Team ID associated with key + - budget_id: Optional[str] - The budget id associated with the key. Created by calling `/budget/new`. + - models: Optional[list] - Model_name's a user is allowed to call + - tags: Optional[List[str]] - Tags for organizing keys (Enterprise only) + - enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests) + - spend: Optional[float] - Amount spent by key + - max_budget: Optional[float] - Max budget for key + - model_max_budget: Optional[Dict[str, BudgetConfig]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}} + - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.) + - soft_budget: Optional[float] - [TODO] Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached. + - max_parallel_requests: Optional[int] - Rate limit for parallel requests + - metadata: Optional[dict] - Metadata for key. Example {"team": "core-infra", "app": "app2"} + - tpm_limit: Optional[int] - Tokens per minute limit + - rpm_limit: Optional[int] - Requests per minute limit + - model_rpm_limit: Optional[dict] - Model-specific RPM limits {"gpt-4": 100, "claude-v1": 200} + - model_tpm_limit: Optional[dict] - Model-specific TPM limits {"gpt-4": 100000, "claude-v1": 200000} + - allowed_cache_controls: Optional[list] - List of allowed cache control values + - duration: Optional[str] - Key validity duration ("30d", "1h", etc.) 
+ - permissions: Optional[dict] - Key-specific permissions + - send_invite_email: Optional[bool] - Send invite email to user_id + - guardrails: Optional[List[str]] - List of active guardrails for the key + - blocked: Optional[bool] - Whether the key is blocked + - aliases: Optional[dict] - Model aliases for the key - [Docs](https://litellm.vercel.app/docs/proxy/virtual_keys#model-aliases) + - config: Optional[dict] - [DEPRECATED PARAM] Key-specific config. + - temp_budget_increase: Optional[float] - Temporary budget increase for the key (Enterprise only). + - temp_budget_expiry: Optional[str] - Expiry time for the temporary budget increase (Enterprise only). + + Example: + ```bash + curl --location 'http://0.0.0.0:4000/key/update' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "key": "sk-1234", + "key_alias": "my-key", + "user_id": "user-1234", + "team_id": "team-1234", + "max_budget": 100, + "metadata": {"any_key": "any-val"}, + }' + ``` + """ + from litellm.proxy.proxy_server import ( + llm_router, + premium_user, + prisma_client, + proxy_logging_obj, + user_api_key_cache, + ) + + try: + data_json: dict = data.model_dump(exclude_unset=True, exclude_none=True) + key = data_json.pop("key") + + # get the row from db + if prisma_client is None: + raise Exception("Not connected to DB!") + + common_key_access_checks( + user_api_key_dict=user_api_key_dict, + data=data, + llm_router=llm_router, + premium_user=premium_user, + ) + + existing_key_row = await prisma_client.get_data( + token=data.key, table_name="key", query_type="find_unique" + ) + + if existing_key_row is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team not found, passed team_id={data.team_id}"}, + ) + + non_default_values = prepare_key_update_data( + data=data, existing_key_row=existing_key_row + ) + + await _enforce_unique_key_alias( + key_alias=non_default_values.get("key_alias", None), + prisma_client=prisma_client, + 
@router.post(
    "/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def delete_key_fn(
    data: KeyRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Delete a key from the key management system.

    Parameters:
    - keys (List[str]): A list of keys or hashed keys to delete. Example {"keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]}
    - key_aliases (List[str]): A list of key aliases to delete. Can be passed instead of `keys`. Example {"key_aliases": ["alias1", "alias2"]}

    Returns:
    - deleted_keys (List[str]): The keys (or aliases) that were passed in for deletion. Example {"deleted_keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]}

    Example:
    ```bash
    curl --location 'http://0.0.0.0:4000/key/delete' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "keys": ["sk-QWrxEynunsNpV1zT48HIrw"]
    }'
    ```

    Raises:
        Whatever `handle_exception_on_proxy` builds from the underlying error
        (missing DB connection, empty request, failed deletion, ...).
    """
    try:
        from litellm.proxy.proxy_server import prisma_client, user_api_key_cache

        if prisma_client is None:
            raise Exception("Not connected to DB!")

        ## only allow user to delete keys they own
        verbose_proxy_logger.debug(
            f"user_api_key_dict.user_role: {user_api_key_dict.user_role}"
        )

        num_keys_to_be_deleted = 0
        deleted_keys = []
        if data.keys:
            number_deleted_keys, _keys_being_deleted = await delete_verification_tokens(
                tokens=data.keys,
                user_api_key_cache=user_api_key_cache,
                user_api_key_dict=user_api_key_dict,
            )
            num_keys_to_be_deleted = len(data.keys)
            deleted_keys = data.keys
        elif data.key_aliases:
            number_deleted_keys, _keys_being_deleted = await delete_key_aliases(
                key_aliases=data.key_aliases,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                user_api_key_dict=user_api_key_dict,
            )
            num_keys_to_be_deleted = len(data.key_aliases)
            deleted_keys = data.key_aliases
        else:
            # neither `keys` nor `key_aliases` was supplied (or both were empty)
            raise ValueError("Invalid request type")

        if number_deleted_keys is None:
            raise ProxyException(
                message="Failed to delete keys got None response from delete_verification_token",
                type=ProxyErrorTypes.internal_server_error,
                param="keys",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        verbose_proxy_logger.debug(f"/key/delete - deleted_keys={number_deleted_keys}")

        # Explicit comparison instead of `assert`: assertions are removed when
        # Python runs with -O, which would silently disable this validation.
        # NOTE(review): both operands are derived from the request payload
        # above, so this check cannot currently fail - confirm whether it was
        # meant to compare against the helper's result instead.
        if num_keys_to_be_deleted != len(deleted_keys):
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in. Keys passed in={num_keys_to_be_deleted}, Deleted keys ={number_deleted_keys}"
                },
            )

        verbose_proxy_logger.debug(
            f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}"
        )

        # Fire-and-forget: post-deletion hooks run in the background so the
        # response is not delayed by them.
        asyncio.create_task(
            KeyManagementEventHooks.async_key_deleted_hook(
                data=data,
                keys_being_deleted=_keys_being_deleted,
                user_api_key_dict=user_api_key_dict,
                litellm_changed_by=litellm_changed_by,
                response=number_deleted_keys,
            )
        )

        return {"deleted_keys": deleted_keys}
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.delete_key_fn(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
@router.get(
    "/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)
async def info_key_fn(
    key: Optional[str] = fastapi.Query(
        default=None, description="Key in the request parameters"
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Look up a single key and return its stored information.

    Parameters:
        key: Optional[str] = Query parameter naming the key to look up. When it
            is omitted, the key from the Authorization header is used instead.
        user_api_key_dict: UserAPIKeyAuth = Dependency representing the caller's API key

    Returns:
        Dict with the requested "key" and its "info" (the `token` field is
        removed from the info before it is returned).

    Example Curl:
    ```
    curl -X GET "http://0.0.0.0:4000/key/info?key=sk-02Wr4IAlN3NvPXvL5JVvDA" \
    -H "Authorization: Bearer sk-1234"
    ```

    Example Curl - no key passed, falls back to the key in the Authorization header:
    ```
    curl -X GET "http://0.0.0.0:4000/key/info" \
    -H "Authorization: Bearer sk-02Wr4IAlN3NvPXvL5JVvDA"
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    try:
        if prisma_client is None:
            raise Exception(
                "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )

        # fall back to the caller's own key when none was requested
        key = key or user_api_key_dict.api_key
        lookup_token: Optional[str] = (
            _hash_token_if_needed(token=key) if key is not None else key
        )

        key_row = await prisma_client.db.litellm_verificationtoken.find_unique(
            where={"token": lookup_token},  # type: ignore
            include={"litellm_budget_table": True},
        )
        if key_row is None:
            raise ProxyException(
                message="Key not found in database",
                type=ProxyErrorTypes.not_found_error,
                param="key",
                code=status.HTTP_404_NOT_FOUND,
            )

        caller_may_view = _can_user_query_key_info(
            user_api_key_dict=user_api_key_dict,
            key=key,
            key_info=key_row,
        )
        if caller_may_view is not True:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="You are not allowed to access this key's info. Your role={}".format(
                    user_api_key_dict.user_role
                ),
            )

        ## strip the hashed token from the payload before returning it ##
        try:
            info_payload = key_row.model_dump()  # noqa
        except Exception:
            # pydantic v1 objects expose .dict() instead
            info_payload = key_row.dict()
        info_payload.pop("token")
        return {"key": key, "info": info_payload}
    except Exception as e:
        raise handle_exception_on_proxy(e)
'token'. Exposed on `/key/generate` for setting key value yourself. + user_id: Optional[str] = None, + user_alias: Optional[str] = None, + team_id: Optional[str] = None, + user_email: Optional[str] = None, + user_role: Optional[str] = None, + max_parallel_requests: Optional[int] = None, + metadata: Optional[dict] = {}, + tpm_limit: Optional[int] = None, + rpm_limit: Optional[int] = None, + query_type: Literal["insert_data", "update_data"] = "insert_data", + update_key_values: Optional[dict] = None, + key_alias: Optional[str] = None, + allowed_cache_controls: Optional[list] = [], + permissions: Optional[dict] = {}, + model_max_budget: Optional[dict] = {}, + model_rpm_limit: Optional[dict] = None, + model_tpm_limit: Optional[dict] = None, + guardrails: Optional[list] = None, + teams: Optional[list] = None, + organization_id: Optional[str] = None, + table_name: Optional[Literal["key", "user"]] = None, + send_invite_email: Optional[bool] = None, + created_by: Optional[str] = None, + updated_by: Optional[str] = None, +): + from litellm.proxy.proxy_server import ( + litellm_proxy_budget_name, + premium_user, + prisma_client, + ) + + if prisma_client is None: + raise Exception( + "Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys " + ) + + if token is None: + if key is not None: + token = key + else: + token = f"sk-{secrets.token_urlsafe(16)}" + + if duration is None: # allow tokens that never expire + expires = None + else: + duration_s = duration_in_seconds(duration=duration) + expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + + if key_budget_duration is None: # one-time budget + key_reset_at = None + else: + duration_s = duration_in_seconds(duration=key_budget_duration) + key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + + if budget_duration is None: # one-time budget + reset_at = None + else: + duration_s = duration_in_seconds(duration=budget_duration) + reset_at = 
datetime.now(timezone.utc) + timedelta(seconds=duration_s) + + aliases_json = json.dumps(aliases) + config_json = json.dumps(config) + permissions_json = json.dumps(permissions) + + # Add model_rpm_limit and model_tpm_limit to metadata + if model_rpm_limit is not None: + metadata = metadata or {} + metadata["model_rpm_limit"] = model_rpm_limit + if model_tpm_limit is not None: + metadata = metadata or {} + metadata["model_tpm_limit"] = model_tpm_limit + if guardrails is not None: + metadata = metadata or {} + metadata["guardrails"] = guardrails + + metadata_json = json.dumps(metadata) + validate_model_max_budget(model_max_budget) + model_max_budget_json = json.dumps(model_max_budget) + user_role = user_role + tpm_limit = tpm_limit + rpm_limit = rpm_limit + allowed_cache_controls = allowed_cache_controls + + try: + # Create a new verification token (you may want to enhance this logic based on your needs) + user_data = { + "max_budget": max_budget, + "user_email": user_email, + "user_id": user_id, + "user_alias": user_alias, + "team_id": team_id, + "organization_id": organization_id, + "user_role": user_role, + "spend": spend, + "models": models, + "metadata": metadata_json, + "max_parallel_requests": max_parallel_requests, + "tpm_limit": tpm_limit, + "rpm_limit": rpm_limit, + "budget_duration": budget_duration, + "budget_reset_at": reset_at, + "allowed_cache_controls": allowed_cache_controls, + } + if teams is not None: + user_data["teams"] = teams + key_data = { + "token": token, + "key_alias": key_alias, + "expires": expires, + "models": models, + "aliases": aliases_json, + "config": config_json, + "spend": spend, + "max_budget": key_max_budget, + "user_id": user_id, + "team_id": team_id, + "max_parallel_requests": max_parallel_requests, + "metadata": metadata_json, + "tpm_limit": tpm_limit, + "rpm_limit": rpm_limit, + "budget_duration": key_budget_duration, + "budget_reset_at": key_reset_at, + "allowed_cache_controls": allowed_cache_controls, + "permissions": 
permissions_json, + "model_max_budget": model_max_budget_json, + "budget_id": budget_id, + "blocked": blocked, + "created_by": created_by, + "updated_by": updated_by, + } + + if ( + get_secret("DISABLE_KEY_NAME", False) is True + ): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much) + pass + else: + key_data["key_name"] = f"sk-...{token[-4:]}" + saved_token = copy.deepcopy(key_data) + if isinstance(saved_token["aliases"], str): + saved_token["aliases"] = json.loads(saved_token["aliases"]) + if isinstance(saved_token["config"], str): + saved_token["config"] = json.loads(saved_token["config"]) + if isinstance(saved_token["metadata"], str): + saved_token["metadata"] = json.loads(saved_token["metadata"]) + if isinstance(saved_token["permissions"], str): + if ( + "get_spend_routes" in saved_token["permissions"] + and premium_user is not True + ): + raise ValueError( + "get_spend_routes permission is only available for LiteLLM Enterprise users" + ) + + saved_token["permissions"] = json.loads(saved_token["permissions"]) + if isinstance(saved_token["model_max_budget"], str): + saved_token["model_max_budget"] = json.loads( + saved_token["model_max_budget"] + ) + + if saved_token.get("expires", None) is not None and isinstance( + saved_token["expires"], datetime + ): + saved_token["expires"] = saved_token["expires"].isoformat() + if prisma_client is not None: + if ( + table_name is None or table_name == "user" + ): # do not auto-create users for `/key/generate` + ## CREATE USER (If necessary) + if query_type == "insert_data": + user_row = await prisma_client.insert_data( + data=user_data, table_name="user" + ) + + if user_row is None: + raise Exception("Failed to create user") + ## use default user model list if no key-specific model list provided + if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore + key_data["models"] = user_row.models # type: ignore + elif query_type == "update_data": + 
async def _team_key_deletion_check(
    user_api_key_dict: UserAPIKeyAuth,
    key_info: LiteLLM_VerificationToken,
    prisma_client: PrismaClient,
    user_api_key_cache: DualCache,
):
    """
    Decide whether the caller may delete a team-owned key.

    Returns False when `key_info` is not a team key (or has no team_id).
    Otherwise the team is loaded (DB only) and the decision is delegated to
    `_team_key_generation_team_member_check`, using the
    `team_key_generation` settings from `litellm.key_generation_settings`
    when present, or a default that allows the "admin" and "user" roles.

    Raises:
        HTTPException (404): the key's team could not be loaded.
    """
    # guard: only team keys with a concrete team_id are handled here
    if not (_is_team_key(data=key_info) and key_info.team_id is not None):
        return False

    team_row = await get_team_object(
        team_id=key_info.team_id,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        check_db_only=True,
    )

    settings = litellm.key_generation_settings
    if settings is not None and "team_key_generation" in settings:
        generation_rules = litellm.key_generation_settings["team_key_generation"]
    else:
        generation_rules = TeamUIKeyGenerationConfig(
            allowed_team_member_roles=["admin", "user"],
        )

    if team_row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={
                "error": f"Team not found in db, and user not proxy admin. Team id = {key_info.team_id}"
            },
        )

    # team exists - let the membership check decide
    return _team_key_generation_team_member_check(
        assigned_user_id=user_api_key_dict.user_id,
        team_table=team_row,
        user_api_key_dict=user_api_key_dict,
        team_key_generation=generation_rules,
    )
Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]: + Optional[Dict]: + - Number of deleted tokens + List[LiteLLM_VerificationToken]: + - List of keys being deleted, this contains information about the key_alias, token, and user_id being deleted, + this is passed down to the KeyManagementEventHooks to delete the keys from the secret manager and handle audit logs + """ + from litellm.proxy.proxy_server import prisma_client + + try: + if prisma_client: + tokens = [_hash_token_if_needed(token=key) for key in tokens] + _keys_being_deleted: List[LiteLLM_VerificationToken] = ( + await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} + ) + ) + + # Assuming 'db' is your Prisma Client instance + # check if admin making request - don't filter by user-id + if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value: + deleted_tokens = await prisma_client.delete_data(tokens=tokens) + # else + else: + tasks = [] + deleted_tokens = [] + for key in _keys_being_deleted: + + async def _delete_key(key: LiteLLM_VerificationToken): + if await can_delete_verification_token( + key_info=key, + user_api_key_cache=user_api_key_cache, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + ): + await prisma_client.delete_data(tokens=[key.token]) + deleted_tokens.append(key.token) + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "error": "You are not authorized to delete this key" + }, + ) + + tasks.append(_delete_key(key)) + await asyncio.gather(*tasks) + + _num_deleted_tokens = len(deleted_tokens) + if _num_deleted_tokens != len(tokens): + failed_tokens = [ + token for token in tokens if token not in deleted_tokens + ] + raise Exception( + "Failed to delete all tokens. Failed to delete tokens: " + + str(failed_tokens) + ) + else: + raise Exception("DB not connected. 
async def delete_key_aliases(
    key_aliases: List[str],
    user_api_key_cache: DualCache,
    prisma_client: PrismaClient,
    user_api_key_dict: UserAPIKeyAuth,
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
    """
    Delete keys identified by alias.

    Looks up every verification-token row whose `key_alias` is in
    `key_aliases`, then hands the rows' tokens to
    `delete_verification_tokens` and returns its result unchanged.
    """
    alias_filter = {"key_alias": {"in": key_aliases}}
    matching_rows = await prisma_client.db.litellm_verificationtoken.find_many(
        where=alias_filter
    )

    tokens_for_aliases = []
    for row in matching_rows:
        tokens_for_aliases.append(row.token)

    deletion_result = await delete_verification_tokens(
        tokens=tokens_for_aliases,
        user_api_key_cache=user_api_key_cache,
        user_api_key_dict=user_api_key_dict,
    )
    return deletion_result
@router.post(
    "/key/{key:path}/regenerate",
    tags=["key management"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/key/regenerate",
    tags=["key management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def regenerate_key_fn(
    key: Optional[str] = None,
    data: Optional[RegenerateKeyRequest] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
) -> Optional[GenerateKeyResponse]:
    """
    Regenerate an existing API key while optionally updating its parameters.

    Parameters:
    - key: str (path parameter) - The key to regenerate
    - data: Optional[RegenerateKeyRequest] - Request body containing optional parameters to update
        - key_alias: Optional[str] - User-friendly key alias
        - user_id: Optional[str] - User ID associated with key
        - team_id: Optional[str] - Team ID associated with key
        - models: Optional[list] - Model_name's a user is allowed to call
        - tags: Optional[List[str]] - Tags for organizing keys (Enterprise only)
        - spend: Optional[float] - Amount spent by key
        - max_budget: Optional[float] - Max budget for key
        - model_max_budget: Optional[Dict[str, BudgetConfig]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}
        - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
        - soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached.
        - max_parallel_requests: Optional[int] - Rate limit for parallel requests
        - metadata: Optional[dict] - Metadata for key. Example {"team": "core-infra", "app": "app2"}
        - tpm_limit: Optional[int] - Tokens per minute limit
        - rpm_limit: Optional[int] - Requests per minute limit
        - model_rpm_limit: Optional[dict] - Model-specific RPM limits {"gpt-4": 100, "claude-v1": 200}
        - model_tpm_limit: Optional[dict] - Model-specific TPM limits {"gpt-4": 100000, "claude-v1": 200000}
        - allowed_cache_controls: Optional[list] - List of allowed cache control values
        - duration: Optional[str] - Key validity duration ("30d", "1h", etc.)
        - permissions: Optional[dict] - Key-specific permissions
        - guardrails: Optional[List[str]] - List of active guardrails for the key
        - blocked: Optional[bool] - Whether the key is blocked

    A key supplied in the request body (`data.key`) takes precedence over the
    path parameter. When the supplied key is the proxy master key, the master
    key is rotated to `data.new_master_key` instead.

    Returns:
    - GenerateKeyResponse containing the new key and its updated parameters

    Example:
    ```bash
    curl --location --request POST 'http://localhost:4000/key/sk-1234/regenerate' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data-raw '{
        "max_budget": 100,
        "metadata": {"team": "core-infra"},
        "models": ["gpt-4", "gpt-3.5-turbo"]
    }'
    ```

    Note: This is an Enterprise feature. It requires a premium license to use.
    """
    try:

        from litellm.proxy.proxy_server import (
            hash_token,
            master_key,
            premium_user,
            prisma_client,
            proxy_logging_obj,
            user_api_key_cache,
        )

        if premium_user is not True:
            raise ValueError(
                f"Regenerating Virtual Keys is an Enterprise feature, {CommonProxyErrors.not_premium_user.value}"
            )

        # Check if key exists, raise exception if key is not in the DB
        key = data.key if data and data.key else key
        if not key:
            raise HTTPException(status_code=400, detail={"error": "No key passed in."})

        ### 1. Create New copy that is duplicate of existing key
        ######################################################################

        # create duplicate of existing key
        # set token = new token generated
        # insert new token in DB

        # create hash of token
        if prisma_client is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"error": "DB not connected. prisma_client is None"},
            )

        _is_master_key_valid = _is_master_key(api_key=key, _master_key=master_key)

        # Master-key rotation path: re-encrypts stored values and returns early.
        if master_key is not None and data and _is_master_key_valid:
            if data.new_master_key is None:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail={"error": "New master key is required."},
                )
            await _rotate_master_key(
                prisma_client=prisma_client,
                user_api_key_dict=user_api_key_dict,
                current_master_key=master_key,
                new_master_key=data.new_master_key,
            )
            return GenerateKeyResponse(
                key=data.new_master_key,
                token=data.new_master_key,
                key_name=data.new_master_key,
                expires=None,
            )

        # `key` may be either the raw "sk-..." value or the already-hashed
        # token; only the raw form is hashed.
        if "sk" not in key:
            hashed_api_key = key
        else:
            hashed_api_key = hash_token(key)

        _key_in_db = await prisma_client.db.litellm_verificationtoken.find_unique(
            where={"token": hashed_api_key},
        )
        if _key_in_db is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail={"error": f"Key {key} not found."},
            )

        verbose_proxy_logger.debug("key_in_db: %s", _key_in_db)

        new_token = f"sk-{secrets.token_urlsafe(16)}"
        new_token_hash = hash_token(new_token)
        new_token_key_name = f"sk-...{new_token[-4:]}"

        # Prepare the update data
        update_data = {
            "token": new_token_hash,
            "key_name": new_token_key_name,
        }

        non_default_values = {}
        if data is not None:
            # Update with any provided parameters from GenerateKeyRequest
            non_default_values = prepare_key_update_data(
                data=data, existing_key_row=_key_in_db
            )
            verbose_proxy_logger.debug("non_default_values: %s", non_default_values)

        update_data.update(non_default_values)
        update_data = prisma_client.jsonify_object(data=update_data)

        ### 2. Update the token in the database
        ######################################################################
        updated_token = await prisma_client.db.litellm_verificationtoken.update(
            where={"token": hashed_api_key},
            data=update_data,  # type: ignore
        )

        updated_token_dict = {}
        if updated_token is not None:
            updated_token_dict = dict(updated_token)

        # expose the raw new key once; the stored hash is returned as token_id
        updated_token_dict["key"] = new_token
        updated_token_dict["token_id"] = updated_token_dict.pop("token")

        ### 3. remove existing key entry from cache
        ######################################################################
        # Evict under the hash the row was stored with. Re-hashing `key` here
        # is wrong when the caller passed an already-hashed key: the result is
        # a hash-of-a-hash that matches no cache entry, leaving the old key
        # cached. For a raw key, `hashed_api_key` is exactly hash_token(key).
        await _delete_cache_key_object(
            hashed_token=hashed_api_key,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )

        response = GenerateKeyResponse(
            **updated_token_dict,
        )

        # Fire-and-forget: rotation hooks run in the background.
        asyncio.create_task(
            KeyManagementEventHooks.async_key_rotated_hook(
                data=data,
                existing_key_row=_key_in_db,
                response=response,
                user_api_key_dict=user_api_key_dict,
                litellm_changed_by=litellm_changed_by,
            )
        )

        return response
    except Exception as e:
        verbose_proxy_logger.exception("Error regenerating key: %s", e)
        raise handle_exception_on_proxy(e)
No 'user_id' is associated with your API key.", + type=ProxyErrorTypes.bad_request_error, + param="user_id", + code=status.HTTP_403_FORBIDDEN, + ) + complete_user_info_db_obj: Optional[BaseModel] = ( + await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, + ) + ) + + if complete_user_info_db_obj is None: + raise ProxyException( + message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.", + type=ProxyErrorTypes.bad_request_error, + param="user_id", + code=status.HTTP_403_FORBIDDEN, + ) + + complete_user_info = LiteLLM_UserTable(**complete_user_info_db_obj.model_dump()) + + # internal user can only see their own keys + if user_id: + if complete_user_info.user_id != user_id: + raise ProxyException( + message="You are not authorized to check another user's keys", + type=ProxyErrorTypes.bad_request_error, + param="user_id", + code=status.HTTP_403_FORBIDDEN, + ) + + if team_id: + if team_id not in complete_user_info.teams: + raise ProxyException( + message="You are not authorized to check this team's keys", + type=ProxyErrorTypes.bad_request_error, + param="team_id", + code=status.HTTP_403_FORBIDDEN, + ) + + if organization_id: + if ( + complete_user_info.organization_memberships is None + or organization_id + not in [ + membership.organization_id + for membership in complete_user_info.organization_memberships + ] + ): + raise ProxyException( + message="You are not authorized to check this organization's keys", + type=ProxyErrorTypes.bad_request_error, + param="organization_id", + code=status.HTTP_403_FORBIDDEN, + ) + return complete_user_info + + +async def get_admin_team_ids( + complete_user_info: Optional[LiteLLM_UserTable], + user_api_key_dict: UserAPIKeyAuth, + prisma_client: PrismaClient, +) -> List[str]: + """ + Get all team IDs where the user is an admin. 
@router.get(
    "/key/list",
    tags=["key management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def list_keys(
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    page: int = Query(1, description="Page number", ge=1),
    size: int = Query(10, description="Page size", ge=1, le=100),
    user_id: Optional[str] = Query(None, description="Filter keys by user ID"),
    team_id: Optional[str] = Query(None, description="Filter keys by team ID"),
    organization_id: Optional[str] = Query(
        None, description="Filter keys by organization ID"
    ),
    key_alias: Optional[str] = Query(None, description="Filter keys by key alias"),
    return_full_object: bool = Query(False, description="Return full key object"),
    include_team_keys: bool = Query(
        False, description="Include all keys for teams that user is an admin of."
    ),
) -> KeyListResponseObject:
    """
    List all keys for a given user / team / organization.

    Returns:
        {
            "keys": List[str] or List[UserAPIKeyAuth],
            "total_count": int,
            "current_page": int,
            "total_pages": int,
        }
    """
    try:
        from litellm.proxy.proxy_server import prisma_client

        verbose_proxy_logger.debug("Entering list_keys function")

        if prisma_client is None:
            verbose_proxy_logger.error("Database not connected")
            raise Exception("Database not connected")

        # Rejects the request (403) when the caller asks for keys it may not see.
        complete_user_info = await validate_key_list_check(
            user_api_key_dict=user_api_key_dict,
            user_id=user_id,
            team_id=team_id,
            organization_id=organization_id,
            key_alias=key_alias,
            prisma_client=prisma_client,
        )

        admin_team_ids = (
            await get_admin_team_ids(
                complete_user_info=complete_user_info,
                user_api_key_dict=user_api_key_dict,
                prisma_client=prisma_client,
            )
            if include_team_keys
            else None
        )

        # Callers without an admin role are scoped to their own keys when no
        # explicit user filter was supplied.
        admin_roles = (
            LitellmUserRoles.PROXY_ADMIN.value,
            LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
        )
        if user_id is None and user_api_key_dict.user_role not in admin_roles:
            user_id = user_api_key_dict.user_id

        key_page = await _list_key_helper(
            prisma_client=prisma_client,
            page=page,
            size=size,
            user_id=user_id,
            team_id=team_id,
            key_alias=key_alias,
            return_full_object=return_full_object,
            organization_id=organization_id,
            admin_team_ids=admin_team_ids,
        )

        verbose_proxy_logger.debug("Successfully prepared response")

        return key_page

    except Exception as e:
        verbose_proxy_logger.exception(f"Error in list_keys: {e}")

        # HTTPException is tested first, exactly as before, so the mapping
        # below wins if an exception happens to satisfy both checks.
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"error({str(e)})"),
                type=ProxyErrorTypes.internal_server_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        if isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.internal_server_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
async def _list_key_helper(
    prisma_client: PrismaClient,
    page: int,
    size: int,
    user_id: Optional[str],
    team_id: Optional[str],
    organization_id: Optional[str],
    key_alias: Optional[str],
    exclude_team_id: Optional[str] = None,
    return_full_object: bool = False,
    admin_team_ids: Optional[
        List[str]
    ] = None,  # New parameter for teams where user is admin
) -> KeyListResponseObject:
    """
    Helper function to list keys

    Args:
        prisma_client: PrismaClient  # database client used for the lookup and the count
        page: int  # 1-based page number
        size: int  # page size
        user_id: Optional[str]
        team_id: Optional[str]
        organization_id: Optional[str]
        key_alias: Optional[str]
        exclude_team_id: Optional[str] # exclude a specific team_id (replaces the `team_id` filter when both are given)
        return_full_object: bool # when true, will return UserAPIKeyAuth objects instead of just the token
        admin_team_ids: Optional[List[str]] # list of team IDs where the user is an admin

    Returns:
        KeyListResponseObject
        {
            "keys": List[str] or List[UserAPIKeyAuth],  # Updated to reflect possible return types
            "total_count": int,
            "current_page": int,
            "total_pages": int,
        }
    """

    # Base filter applied to every query: never return UI session tokens.
    # Note that this filter itself lives under the "OR" key.
    where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {}
    where.update(_get_condition_to_filter_out_ui_session_tokens())

    # Build the OR conditions for user's keys and admin team keys
    or_conditions: List[Dict[str, Any]] = []

    # Base conditions for user's own keys
    user_condition: Dict[str, Any] = {}
    if user_id and isinstance(user_id, str):
        user_condition["user_id"] = user_id
    if team_id and isinstance(team_id, str):
        user_condition["team_id"] = team_id
    if key_alias and isinstance(key_alias, str):
        user_condition["key_alias"] = key_alias
    if exclude_team_id and isinstance(exclude_team_id, str):
        user_condition["team_id"] = {"not": exclude_team_id}
    if organization_id and isinstance(organization_id, str):
        user_condition["organization_id"] = organization_id

    if user_condition:
        or_conditions.append(user_condition)

    # Add condition for admin team keys if provided
    if admin_team_ids:
        or_conditions.append({"team_id": {"in": admin_team_ids}})

    if len(or_conditions) > 1:
        # The base filter already occupies where["OR"]; assigning the new
        # alternatives to the same key would silently drop it and let UI
        # session tokens back into the result. AND the two groups instead.
        where = {"AND": [where, {"OR": or_conditions}]}
    elif len(or_conditions) == 1:
        # A single condition has no "OR" key of its own, so merging is safe.
        where.update(or_conditions[0])

    verbose_proxy_logger.debug(f"Filter conditions: {where}")

    # Calculate skip for pagination
    skip = (page - 1) * size

    verbose_proxy_logger.debug(f"Pagination: skip={skip}, take={size}")

    # Fetch keys with pagination
    keys = await prisma_client.db.litellm_verificationtoken.find_many(
        where=where,  # type: ignore
        skip=skip,  # type: ignore
        take=size,  # type: ignore
        order=[
            {"created_at": "desc"},
            {"token": "desc"},  # fallback sort
        ],
    )

    verbose_proxy_logger.debug(f"Fetched {len(keys)} keys")

    # Get total count of keys (same filter, no pagination)
    total_count = await prisma_client.db.litellm_verificationtoken.count(
        where=where  # type: ignore
    )

    verbose_proxy_logger.debug(f"Total count of keys: {total_count}")

    # Calculate total pages
    total_pages = -(-total_count // size)  # Ceiling division

    # Prepare response
    key_list: List[Union[str, UserAPIKeyAuth]] = []
    for key in keys:
        key_fields = key.dict()
        if return_full_object is True:
            key_list.append(UserAPIKeyAuth(**key_fields))  # Return full key object
        else:
            key_list.append(key_fields.get("token"))  # Return only the token

    return KeyListResponseObject(
        keys=key_list,
        total_count=total_count,
        current_page=page,
        total_pages=total_pages,
    )


def _get_condition_to_filter_out_ui_session_tokens() -> Dict[str, Any]:
    """
    Condition to filter out UI session tokens

    Returns a filter matching rows whose team_id is null or differs from
    UI_SESSION_TOKEN_TEAM_ID.
    """
    return {
        "OR": [
            {"team_id": None},  # Include records where team_id is null
            {
                "team_id": {"not": UI_SESSION_TOKEN_TEAM_ID}
            },  # Include records where team_id != UI_SESSION_TOKEN_TEAM_ID
        ]
    }
BlockKeyRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +) -> Optional[LiteLLM_VerificationToken]: + """ + Block an Virtual key from making any requests. + + Parameters: + - key: str - The key to block. Can be either the unhashed key (sk-...) or the hashed key value + + Example: + ```bash + curl --location 'http://0.0.0.0:4000/key/block' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "key": "sk-Fn8Ej39NxjAXrvpUGKghGw" + }' + ``` + + Note: This is an admin-only endpoint. Only proxy admins can block keys. + """ + from litellm.proxy.proxy_server import ( + create_audit_log_for_update, + hash_token, + litellm_proxy_admin_name, + prisma_client, + proxy_logging_obj, + user_api_key_cache, + ) + + if prisma_client is None: + raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value)) + + if data.key.startswith("sk-"): + hashed_token = hash_token(token=data.key) + else: + hashed_token = data.key + + if litellm.store_audit_logs is True: + # make an audit log for key update + record = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": hashed_token} + ) + if record is None: + raise ProxyException( + message=f"Key {data.key} not found", + type=ProxyErrorTypes.bad_request_error, + param="key", + code=status.HTTP_404_NOT_FOUND, + ) + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=hashed_token, + 
@router.post(
    "/key/unblock", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def unblock_key(
    data: BlockKeyRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Unblock a Virtual key to allow it to make requests again.

    Parameters:
    - key: str - The key to unblock. Can be either the unhashed key (sk-...) or the hashed key value

    Example:
    ```bash
    curl --location 'http://0.0.0.0:4000/key/unblock' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "key": "sk-Fn8Ej39NxjAXrvpUGKghGw"
    }'
    ```

    Note: This is an admin-only endpoint. Only proxy admins can unblock keys.
    """
    from litellm.proxy.proxy_server import (
        create_audit_log_for_update,
        hash_token,
        litellm_proxy_admin_name,
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    def _key_not_found() -> ProxyException:
        # One 404 shape for both places a missing key can be detected.
        return ProxyException(
            message=f"Key {data.key} not found",
            type=ProxyErrorTypes.bad_request_error,
            param="key",
            code=status.HTTP_404_NOT_FOUND,
        )

    if prisma_client is None:
        raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value))

    # Keys are stored hashed; raw "sk-" keys are hashed before lookup, anything
    # else is treated as an already-hashed token.
    if data.key.startswith("sk-"):
        hashed_token = hash_token(token=data.key)
    else:
        hashed_token = data.key

    if litellm.store_audit_logs is True:
        # make an audit log for key update
        record = await prisma_client.db.litellm_verificationtoken.find_unique(
            where={"token": hashed_token}
        )
        if record is None:
            raise _key_not_found()
        asyncio.create_task(
            create_audit_log_for_update(
                request_data=LiteLLM_AuditLogs(
                    id=str(uuid.uuid4()),
                    updated_at=datetime.now(timezone.utc),
                    changed_by=litellm_changed_by
                    or user_api_key_dict.user_id
                    or litellm_proxy_admin_name,
                    changed_by_api_key=user_api_key_dict.api_key,
                    table_name=LitellmTableNames.KEY_TABLE_NAME,
                    object_id=hashed_token,
                    # NOTE(review): the accepted audit action values are
                    # declared outside this block, so the action is left
                    # as-is - confirm whether a dedicated "unblock" value
                    # exists.
                    action="blocked",
                    # Record the actual change so this entry can be told
                    # apart from a block operation.
                    updated_values='{"blocked": false}',
                    before_value=record.model_dump_json(),
                )
            )
        )

    record = await prisma_client.db.litellm_verificationtoken.update(
        where={"token": hashed_token}, data={"blocked": False}  # type: ignore
    )
    if record is None:
        # Without audit logging nothing above verified that the key exists;
        # report a missing key the same way the audit path does.
        raise _key_not_found()

    ## UPDATE KEY CACHE

    ### get cached object ###
    key_object = await get_key_object(
        hashed_token=hashed_token,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        parent_otel_span=None,
        proxy_logging_obj=proxy_logging_obj,
    )

    ### update cached object ###
    key_object.blocked = False

    ### store cached object ###
    await _cache_key_object(
        hashed_token=hashed_token,
        user_api_key_obj=key_object,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )

    return record
dependencies=[Depends(user_api_key_auth)], + response_model=KeyHealthResponse, +) +@management_endpoint_wrapper +async def key_health( + request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Check the health of the key + + Checks: + - If key based logging is configured correctly - sends a test log + + Usage + + Pass the key in the request header + + ```bash + curl -X POST "http://localhost:4000/key/health" \ + -H "Authorization: Bearer sk-1234" \ + -H "Content-Type: application/json" + ``` + + Response when logging callbacks are setup correctly: + + ```json + { + "key": "healthy", + "logging_callbacks": { + "callbacks": [ + "gcs_bucket" + ], + "status": "healthy", + "details": "No logger exceptions triggered, system is healthy. Manually check if logs were sent to ['gcs_bucket']" + } + } + ``` + + + Response when logging callbacks are not setup correctly: + ```json + { + "key": "unhealthy", + "logging_callbacks": { + "callbacks": [ + "gcs_bucket" + ], + "status": "unhealthy", + "details": "Logger exceptions triggered, system is unhealthy: Failed to load vertex credentials. Check to see if credentials containing partial/invalid information." 
async def test_key_logging(
    user_api_key_dict: UserAPIKeyAuth,
    request: Request,
    key_logging: List[Dict[str, Any]],
) -> LoggingCallbackStatus:
    """
    Test the key-based logging

    - Test that key logging is correctly formatted and all args are passed correctly
    - Make a mock completion call -> user can check if it's correctly logged
    - Check if any logger.exceptions were triggered -> if they were then returns it to the user client side

    Args:
        user_api_key_dict: Auth info of the key whose logging is being tested.
        request: Incoming request, forwarded when building the mock call.
        key_logging: Logging entries from the key's metadata; each must carry
            a "callback_name".

    Returns:
        A LoggingCallbackStatus with the callback names, a "healthy" /
        "unhealthy" status and a human readable detail string.

    Raises:
        ValueError: If an entry in `key_logging` has no "callback_name".
    """
    import logging
    from io import StringIO

    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
    from litellm.proxy.proxy_server import general_settings, proxy_config

    logging_callbacks: List[str] = []
    for callback in key_logging:
        if callback.get("callback_name") is not None:
            logging_callbacks.append(callback["callback_name"])
        else:
            raise ValueError("callback_name is required in key_logging")

    # Capture every ERROR-level record emitted while the mock call runs.
    log_capture_string = StringIO()
    ch = logging.StreamHandler(log_capture_string)
    ch.setLevel(logging.ERROR)
    logger = logging.getLogger()
    logger.addHandler(ch)

    try:
        try:
            data = {
                "model": "openai/litellm-key-health-test",
                "messages": [
                    {
                        "role": "user",
                        "content": "Hello, this is a test from litellm /key/health. No LLM API call was made for this",
                    }
                ],
                "mock_response": "test response",
            }
            data = await add_litellm_data_to_request(
                data=data,
                user_api_key_dict=user_api_key_dict,
                proxy_config=proxy_config,
                general_settings=general_settings,
                request=request,
            )
            await litellm.acompletion(
                **data
            )  # make mock completion call to trigger key based callbacks
        except Exception as e:
            return LoggingCallbackStatus(
                callbacks=logging_callbacks,
                status="unhealthy",
                details=f"Logging test failed: {str(e)}",
            )

        await asyncio.sleep(
            2
        )  # wait for callbacks to run, callbacks use batching so wait for the flush event

        # Check if any logger exceptions were triggered
        log_contents = log_capture_string.getvalue()
    finally:
        # Always detach the capture handler. Previously the early return on a
        # failed mock call left it attached to the root logger, so every
        # failing health check added one more permanent handler.
        logger.removeHandler(ch)

    if log_contents:
        return LoggingCallbackStatus(
            callbacks=logging_callbacks,
            status="unhealthy",
            details=f"Logger exceptions triggered, system is unhealthy: {log_contents}",
        )
    else:
        return LoggingCallbackStatus(
            callbacks=logging_callbacks,
            status="healthy",
            details=f"No logger exceptions triggered, system is healthy. Manually check if logs were sent to {logging_callbacks} ",
        )
{CommonProxyErrors.not_premium_user.value}" + ) + for _model, _budget_info in model_max_budget.items(): + assert isinstance(_model, str) + + # /CRUD endpoints can pass budget_limit as a string, so we need to convert it to a float + if "budget_limit" in _budget_info: + _budget_info["budget_limit"] = float(_budget_info["budget_limit"]) + BudgetConfig(**_budget_info) + except Exception as e: + raise ValueError( + f"Invalid model_max_budget: {str(e)}. Example of valid model_max_budget: https://docs.litellm.ai/docs/proxy/users" + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/model_management_endpoints.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/model_management_endpoints.py new file mode 100644 index 00000000..2a2b7eae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -0,0 +1,719 @@ +""" +Allow proxy admin to add/update/delete models in the db + +Currently most endpoints are in `proxy_server.py`, but those should be moved here over time. + +Endpoints here: + +model/{model_id}/update - PATCH endpoint for model update. 
def update_db_model(
    db_model: Deployment, updated_patch: updateDeployment
) -> PrismaCompatibleUpdateDBModel:
    """
    Merge a PATCH payload into a stored deployment and return a Prisma-ready update.

    The stored model name and litellm params are the starting point. Fields
    present on `updated_patch` are layered on top; patched litellm params are
    passed through `encrypt_value_helper` first. The result carries
    `litellm_params` and `model_info` as JSON strings.

    Args:
        db_model: Deployment currently stored in the DB.
        updated_patch: Fields supplied by the PATCH request.

    Returns:
        A PrismaCompatibleUpdateDBModel holding the merged values.
    """
    stored_params = db_model.litellm_params.model_dump(exclude_none=True)
    merged = DeploymentTypedDict(
        model_name=db_model.model_name,
        litellm_params=LiteLLMParamsTypedDict(**stored_params),  # type: ignore
    )

    # update model name
    if updated_patch.model_name:
        merged["model_name"] = updated_patch.model_name

    # update litellm params
    if updated_patch.litellm_params:
        patched_params = updated_patch.litellm_params.model_dump(exclude_none=True)
        # Encrypt any sensitive values
        encrypted_params = {}
        for param_name, param_value in patched_params.items():
            encrypted_params[param_name] = encrypt_value_helper(param_value)
        merged["litellm_params"].update(encrypted_params)  # type: ignore

    # update model info
    # NOTE(review): `merged` is seeded without the stored model_info, so only
    # the patched model_info fields end up in the result - confirm this is
    # the intended PATCH behaviour.
    if updated_patch.model_info:
        if "model_info" not in merged:
            merged["model_info"] = {}
        patched_info = updated_patch.model_info.model_dump(exclude_none=True)
        merged["model_info"].update(patched_info)

    # convert to prisma compatible format
    prisma_payload = PrismaCompatibleUpdateDBModel()
    if "model_name" in merged:
        prisma_payload["model_name"] = merged["model_name"]

    # Both structured fields are serialized to JSON strings.
    for json_field in ("litellm_params", "model_info"):
        if json_field in merged:
            prisma_payload[json_field] = json.dumps(merged[json_field])  # type: ignore

    return prisma_payload
+ + Args: + model_id: The ID of the model to update + patch_data: The fields to update and their new values + user_api_key_dict: User authentication information + + Returns: + Updated model information + + Raises: + ProxyException: For various error conditions including authentication and database errors + """ + from litellm.proxy.proxy_server import ( + litellm_proxy_admin_name, + llm_router, + prisma_client, + store_model_in_db, + ) + + try: + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + # Verify model exists and is stored in DB + if not store_model_in_db: + raise ProxyException( + message="Model updates only supported for DB-stored models", + type=ProxyErrorTypes.validation_error.value, + code=status.HTTP_400_BAD_REQUEST, + param=None, + ) + + # Fetch existing model + db_model = await get_db_model(model_id=model_id, prisma_client=prisma_client) + + if db_model is None: + # Check if model exists in config but not DB + if llm_router and llm_router.get_deployment(model_id=model_id) is not None: + raise ProxyException( + message="Cannot edit config-based model. 
async def _add_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
    new_encryption_key: Optional[str] = None,
    should_create_model_in_db: bool = True,
) -> Optional[LiteLLM_ProxyModelTable]:
    """
    Encrypt a deployment's litellm params and build (optionally persist) its DB row.

    Note: `model_params.litellm_params` is modified in place - each non-None
    value is replaced by its encrypted form.

    Args:
        model_params: Deployment to store.
        user_api_key_dict: Caller auth info; its user id is recorded as
            creator/updater, falling back to LITELLM_PROXY_ADMIN_NAME.
        prisma_client: Database client used to create the row.
        new_encryption_key: Optional key forwarded to `encrypt_value_helper`.
        should_create_model_in_db: When False, no DB write happens and a
            LiteLLM_ProxyModelTable is built from the same data instead.

    Returns:
        The created DB row, or the locally built LiteLLM_ProxyModelTable.
    """
    # encrypt litellm params #
    _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
    for k, v in _litellm_params_dict.items():
        model_params.litellm_params[k] = encrypt_value_helper(
            value=v, new_encryption_key=new_encryption_key
        )

    created_or_updated_by = user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME
    _data: dict = {
        "model_id": model_params.model_info.id,
        "model_name": model_params.model_name,
        "litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True),  # type: ignore
        "model_info": model_params.model_info.model_dump_json(  # type: ignore
            exclude_none=True
        ),
        "created_by": created_or_updated_by,
        "updated_by": created_or_updated_by,
    }

    if should_create_model_in_db:
        model_response = await prisma_client.db.litellm_proxymodeltable.create(
            data=_data  # type: ignore
        )
    else:
        model_response = LiteLLM_ProxyModelTable(**_data)
    return model_response
def check_if_team_id_matches_key(
    team_id: Optional[str], user_api_key_dict: UserAPIKeyAuth
) -> bool:
    """
    Decide whether the calling key may act on a model for `team_id`.

    Args:
        team_id: Team the model belongs to, or None for a model without a team.
        user_api_key_dict: Auth info of the calling key.

    Returns:
        True when the caller's role is PROXY_ADMIN, or when `team_id` is set
        and equals the key's own team id; False otherwise.
    """
    caller_role = user_api_key_dict.user_role

    # Proxy admins are always allowed.
    if caller_role and caller_role == LitellmUserRoles.PROXY_ADMIN:
        return True

    # Without a team only the admin role qualifies.
    if team_id is None:
        return caller_role == LitellmUserRoles.PROXY_ADMIN

    # With a team the key must belong to that same team.
    return user_api_key_dict.team_id == team_id
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@router.post(
    "/model/delete",
    description="Allows deleting models in the model list in the config.yaml",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_model(
    model_info: ModelInfoDelete,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964

    - Check if id in db
    - Delete

    The row is removed from the DB, the deployment is dropped from the router
    (when one is configured) and an audit log task is scheduled.

    Raises:
        ProxyException: when no DB is connected, DB model storage is disabled,
            the model id is not found, or any other error occurs.
    """
    # Imported once, outside the `try`, so an import failure surfaces
    # unwrapped exactly as it did before.
    from litellm.proxy.proxy_server import (
        llm_router,
        prisma_client,
        store_model_in_db,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )

        # Deleting is only supported for models stored in the DB.
        if store_model_in_db is not True:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
                },
            )

        # update DB
        result = await prisma_client.db.litellm_proxymodeltable.delete(
            where={"model_id": model_info.id}
        )

        if result is None:
            raise HTTPException(
                status_code=400,
                detail={"error": f"Model with id={model_info.id} not found in db"},
            )

        ## DELETE FROM ROUTER ##
        if llm_router is not None:
            llm_router.delete_deployment(id=model_info.id)

        ## CREATE AUDIT LOG ##
        asyncio.create_task(
            create_object_audit_log(
                object_id=model_info.id,
                action="deleted",
                user_api_key_dict=user_api_key_dict,
                table_name=LitellmTableNames.PROXY_MODEL_TABLE_NAME,
                before_value=result.model_dump_json(exclude_none=True),
                after_value=None,
                litellm_changed_by=user_api_key_dict.user_id,
                litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME,
            )
        )
        return {"message": f"Model: {result.model_id} deleted successfully"}

    except Exception as e:
        verbose_proxy_logger.exception(
            f"Failed to delete model. Due to error - {str(e)}"
        )
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys" + }, + ) + + if model_params.model_info.team_id is not None and premium_user is not True: + raise HTTPException( + status_code=403, + detail={"error": CommonProxyErrors.not_premium_user.value}, + ) + + if not check_if_team_id_matches_key( + team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict + ): + raise HTTPException( + status_code=403, + detail={"error": "Team ID does not match the API key's team ID"}, + ) + + model_response: Optional[LiteLLM_ProxyModelTable] = None + # update DB + if store_model_in_db is True: + """ + - store model_list in db + - store keys separately + """ + + try: + _original_litellm_model_name = model_params.model_name + if model_params.model_info.team_id is None: + model_response = await _add_model_to_db( + model_params=model_params, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + ) + else: + model_response = await _add_team_model_to_db( + model_params=model_params, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + ) + await proxy_config.add_deployment( + prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj + ) + # don't let failed slack alert block the /model/new response + _alerting = general_settings.get("alerting", []) or [] + if "slack" in _alerting: + # send notification - new model added + await proxy_logging_obj.slack_alerting_instance.model_added_alert( + model_name=model_params.model_name, + litellm_model_name=_original_litellm_model_name, + passed_model_info=model_params.model_info, + ) + except Exception as e: + verbose_proxy_logger.exception(f"Exception in add_new_model: {e}") + + else: + raise HTTPException( + status_code=500, + detail={ + "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature." + }, + ) + + if model_response is None: + raise HTTPException( + status_code=500, + detail={ + "error": "Failed to add model to db. 
#### MODEL MANAGEMENT ####
@router.post(
    "/model/update",
    description="Edit existing model params",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_model(
    model_params: updateDeployment,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Old endpoint for model update. Makes a PUT request.

    Use `/model/{model_id}/update` to PATCH the stored model in db.

    Flow:
    - look up the stored row by `model_info.id`
    - encrypt the submitted litellm_params
    - fall back to the stored value for every submitted key that is None
    - write the merged params back and schedule an audit log entry
    """
    from litellm.proxy.proxy_server import (
        LITELLM_PROXY_ADMIN_NAME,
        llm_router,
        prisma_client,
        store_model_in_db,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )

        # nothing to update unless models are stored in the DB
        if store_model_in_db is not True:
            return None

        submitted_info = getattr(model_params, "model_info", None)
        if submitted_info is None:
            raise Exception("model_info not provided")

        target_id = submitted_info.id
        if target_id is None:
            raise Exception("model_info.id not provided")

        stored_row = await prisma_client.db.litellm_proxymodeltable.find_unique(
            where={"model_id": target_id}
        )
        if stored_row is None:
            in_config = (
                llm_router is not None
                and llm_router.get_deployment(model_id=target_id) is not None
            )
            if in_config:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Can't edit model. Model in config. Store model in db via `/model/new`. to edit."
                    },
                )
            raise Exception("model not found")
        stored_params = dict(stored_row.litellm_params)

        if model_params.litellm_params is None:
            raise Exception("litellm_params not provided")

        ### ENCRYPT PARAMS ###
        submitted_params = model_params.litellm_params.dict(exclude_none=True)
        for param_name, param_value in submitted_params.items():
            model_params.litellm_params[param_name] = encrypt_value_helper(
                value=param_value
            )

        ### MERGE WITH EXISTING DATA ###
        # only keys present on the submitted params object are considered;
        # a None value falls back to the stored value, if there is one
        merged_params = {}
        for key, value in model_params.litellm_params.dict().items():
            if value is None:
                value = stored_params.get(key)
            if value is not None:
                merged_params[key] = value

        update_payload: dict = {
            "litellm_params": json.dumps(merged_params),  # type: ignore
            "updated_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
        }
        model_response = await prisma_client.db.litellm_proxymodeltable.update(
            where={"model_id": target_id},
            data=update_payload,  # type: ignore
        )

        ## CREATE AUDIT LOG ##
        before_json = None
        if isinstance(stored_row, BaseModel):
            before_json = stored_row.model_dump_json(exclude_none=True)
        after_json = None
        if isinstance(model_response, BaseModel):
            after_json = model_response.model_dump_json(exclude_none=True)

        asyncio.create_task(
            create_object_audit_log(
                object_id=target_id,
                action="updated",
                user_api_key_dict=user_api_key_dict,
                table_name=LitellmTableNames.PROXY_MODEL_TABLE_NAME,
                before_value=before_json,
                after_value=after_json,
                litellm_changed_by=user_api_key_dict.user_id,
                litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME,
            )
        )

        return model_response
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.update_model(): Exception occured - {}".format(
                str(e)
            )
        )
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
litellm._logging import verbose_proxy_logger +from litellm.proxy._types import * +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.management_endpoints.budget_management_endpoints import ( + new_budget, + update_budget, +) +from litellm.proxy.management_helpers.utils import ( + get_new_internal_user_defaults, + management_endpoint_wrapper, +) +from litellm.proxy.utils import PrismaClient + +router = APIRouter() + + +@router.post( + "/organization/new", + tags=["organization management"], + dependencies=[Depends(user_api_key_auth)], + response_model=NewOrganizationResponse, +) +async def new_organization( + data: NewOrganizationRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Allow orgs to own teams + + Set org level budgets + model access. + + Only admins can create orgs. + + # Parameters + + - organization_alias: *str* - The name of the organization. + - models: *List* - The models the organization has access to. + - budget_id: *Optional[str]* - The id for a budget (tpm/rpm/max budget) for the organization. + ### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ### + - max_budget: *Optional[float]* - Max budget for org + - tpm_limit: *Optional[int]* - Max tpm limit for org + - rpm_limit: *Optional[int]* - Max rpm limit for org + - max_parallel_requests: *Optional[int]* - [Not Implemented Yet] Max parallel requests for org + - soft_budget: *Optional[float]* - [Not Implemented Yet] Get a slack alert when this soft budget is reached. Don't block requests. + - model_max_budget: *Optional[dict]* - Max budget for a specific model + - budget_duration: *Optional[str]* - Frequency of reseting org budget + - metadata: *Optional[dict]* - Metadata for organization, store information for organization. Example metadata - {"extra_info": "some info"} + - blocked: *bool* - Flag indicating if the org is blocked or not - will stop all calls from keys with this org_id. 
+ - tags: *Optional[List[str]]* - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing). + - organization_id: *Optional[str]* - The organization id of the team. Default is None. Create via `/organization/new`. + - model_aliases: Optional[dict] - Model aliases for the team. [Docs](https://docs.litellm.ai/docs/proxy/team_based_routing#create-team-with-model-alias) + + Case 1: Create new org **without** a budget_id + + ```bash + curl --location 'http://0.0.0.0:4000/organization/new' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data '{ + "organization_alias": "my-secret-org", + "models": ["model1", "model2"], + "max_budget": 100 + }' + + + ``` + + Case 2: Create new org **with** a budget_id + + ```bash + curl --location 'http://0.0.0.0:4000/organization/new' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data '{ + "organization_alias": "my-secret-org", + "models": ["model1", "model2"], + "budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689" + }' + ``` + """ + + from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if ( + user_api_key_dict.user_role is None + or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN + ): + raise HTTPException( + status_code=401, + detail={ + "error": f"Only admins can create orgs. Your role is = {user_api_key_dict.user_role}" + }, + ) + + if data.budget_id is None: + """ + Every organization needs a budget attached. 
+ + If none provided, create one based on provided values + """ + budget_params = LiteLLM_BudgetTable.model_fields.keys() + + # Only include Budget Params when creating an entry in litellm_budgettable + _json_data = data.json(exclude_none=True) + _budget_data = {k: v for k, v in _json_data.items() if k in budget_params} + budget_row = LiteLLM_BudgetTable(**_budget_data) + + new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True)) + + _budget = await prisma_client.db.litellm_budgettable.create( + data={ + **new_budget, # type: ignore + "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + } + ) # type: ignore + + data.budget_id = _budget.budget_id + + """ + Ensure only models that user has access to, are given to org + """ + if len(user_api_key_dict.models) == 0: # user has access to all models + pass + else: + if len(data.models) == 0: + raise HTTPException( + status_code=400, + detail={ + "error": "User not allowed to give access to all models. Select models you want org to have access to." + }, + ) + for m in data.models: + if m not in user_api_key_dict.models: + raise HTTPException( + status_code=400, + detail={ + "error": f"User not allowed to give access to model={m}. 
@router.patch(
    "/organization/update",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_OrganizationTableWithMembers,
)
async def update_organization(
    data: LiteLLM_OrganizationTableUpdate,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an organization
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    caller_id = user_api_key_dict.user_id
    if caller_id is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Cannot associate a user_id to this action. Check `/key/info` to validate if 'user_id' is set."
            },
        )

    # attribute the change to the caller unless the request names someone else
    if data.updated_by is None:
        data.updated_by = caller_id

    changed_fields = data.model_dump(exclude_none=True)
    row_payload = prisma_client.jsonify_object(changed_fields)

    return await prisma_client.db.litellm_organizationtable.update(
        where={"organization_id": data.organization_id},
        data=row_payload,
        include={"members": True, "teams": True, "litellm_budget_table": True},
    )
@router.get(
    "/organization/list",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[LiteLLM_OrganizationTableWithMembers],
)
async def list_organization(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List organizations.

    Proxy admins get every organization; any other caller gets only the
    organizations they are a member of.

    ```
    curl --location --request GET 'http://0.0.0.0:4000/organization/list' \
        --header 'Authorization: Bearer sk-1234'
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    # a second, unreachable `prisma_client is None` check (status 400) used
    # to follow this one; the 500 below is what callers have always received
    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    # if proxy admin - get all orgs
    if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
        return await prisma_client.db.litellm_organizationtable.find_many(
            include={"members": True, "teams": True}
        )

    # if internal user - get orgs they are a member of
    org_memberships = await prisma_client.db.litellm_organizationmembership.find_many(
        where={"user_id": user_api_key_dict.user_id}
    )
    member_org_ids = [membership.organization_id for membership in org_memberships]
    return await prisma_client.db.litellm_organizationtable.find_many(
        where={"organization_id": {"in": member_org_ids}},
        include={"members": True, "teams": True},
    )
@router.post(
    "/organization/info",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def deprecated_info_organization(data: OrganizationRequest):
    """
    DEPRECATED: Use GET /organization/info instead
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    requested_ids = data.organizations
    if len(requested_ids) == 0:
        # an empty list would match nothing - reject it explicitly
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"Specify list of organization id's to query. Passed in={data.organizations}"
            },
        )

    return await prisma_client.db.litellm_organizationtable.find_many(
        where={"organization_id": {"in": requested_ids}},
        include={"litellm_budget_table": True},
    )
Add Internal User to the `LiteLLM_OrganizationMembership` table + """ + try: + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + # Check if organization exists + existing_organization_row = ( + await prisma_client.db.litellm_organizationtable.find_unique( + where={"organization_id": data.organization_id} + ) + ) + if existing_organization_row is None: + raise HTTPException( + status_code=404, + detail={ + "error": f"Organization not found for organization_id={getattr(data, 'organization_id', None)}" + }, + ) + + members: List[OrgMember] + if isinstance(data.member, List): + members = data.member + else: + members = [data.member] + + updated_users: List[LiteLLM_UserTable] = [] + updated_organization_memberships: List[LiteLLM_OrganizationMembershipTable] = [] + + for member in members: + updated_user, updated_organization_membership = ( + await add_member_to_organization( + member=member, + organization_id=data.organization_id, + prisma_client=prisma_client, + ) + ) + + updated_users.append(updated_user) + updated_organization_memberships.append(updated_organization_membership) + + return OrganizationAddMemberResponse( + organization_id=data.organization_id, + updated_users=updated_users, + updated_organization_memberships=updated_organization_memberships, + ) + except Exception as e: + verbose_proxy_logger.exception(f"Error adding member to organization: {e}") + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + 
async def find_member_if_email(
    user_email: str, prisma_client: PrismaClient
) -> LiteLLM_UserTable:
    """
    Find a member if the user_email is in LiteLLM_UserTable

    Parameters:
    - user_email: str - email to look up
    - prisma_client: PrismaClient - connected DB client

    Returns the matching row as a LiteLLM_UserTable.

    Raises HTTPException (400) when the lookup fails OR when no row matches.
    """

    def _unique_user_not_found() -> HTTPException:
        return HTTPException(
            status_code=400,
            detail={
                "error": f"Unique user not found for user_email={user_email}. Potential duplicate OR non-existent user_email in LiteLLM_UserTable. Use 'user_id' instead."
            },
        )

    try:
        existing_user_email_row: Optional[BaseModel] = (
            await prisma_client.db.litellm_usertable.find_unique(
                where={"user_email": user_email}
            )
        )
    except Exception:
        raise _unique_user_not_found()

    # a lookup that matches nothing yields None rather than raising; without
    # this guard the `.model_dump()` below fails with an AttributeError
    if existing_user_email_row is None:
        raise _unique_user_not_found()

    return LiteLLM_UserTable(**existing_user_email_row.model_dump())
existing_user_email_row.user_id + + try: + existing_organization_membership = ( + await prisma_client.db.litellm_organizationmembership.find_unique( + where={ + "user_id_organization_id": { + "user_id": data.user_id, + "organization_id": data.organization_id, + } + } + ) + ) + except Exception as e: + raise HTTPException( + status_code=400, + detail={ + "error": f"Error finding organization membership for user_id={data.user_id} in organization={data.organization_id}: {e}" + }, + ) + if existing_organization_membership is None: + raise HTTPException( + status_code=404, + detail={ + "error": f"Member not found in organization for user_id={data.user_id}" + }, + ) + + # Update member role + if data.role is not None: + await prisma_client.db.litellm_organizationmembership.update( + where={ + "user_id_organization_id": { + "user_id": data.user_id, + "organization_id": data.organization_id, + } + }, + data={"user_role": data.role}, + ) + if data.max_budget_in_organization is not None: + # if budget_id is None, create a new budget + budget_id = existing_organization_membership.budget_id or str(uuid.uuid4()) + if existing_organization_membership.budget_id is None: + new_budget_obj = BudgetNewRequest( + budget_id=budget_id, max_budget=data.max_budget_in_organization + ) + await new_budget( + budget_obj=new_budget_obj, user_api_key_dict=user_api_key_dict + ) + else: + # update budget table with new max_budget + await update_budget( + budget_obj=BudgetNewRequest( + budget_id=budget_id, max_budget=data.max_budget_in_organization + ), + user_api_key_dict=user_api_key_dict, + ) + + # update organization membership with new budget_id + await prisma_client.db.litellm_organizationmembership.update( + where={ + "user_id_organization_id": { + "user_id": data.user_id, + "organization_id": data.organization_id, + } + }, + data={"budget_id": budget_id}, + ) + final_organization_membership: Optional[BaseModel] = ( + await prisma_client.db.litellm_organizationmembership.find_unique( + 
async def add_member_to_organization(
    member: OrgMember,
    organization_id: str,
    prisma_client: PrismaClient,
) -> Tuple[LiteLLM_UserTable, LiteLLM_OrganizationMembershipTable]:
    """
    Add a member to an organization

    - Checks if member.user_id or member.user_email is in LiteLLM_UserTable
    - If not found, create a new user in LiteLLM_UserTable
    - Add user to organization in LiteLLM_OrganizationMembership

    Returns:
        (user row, organization membership row)

    Raises:
        ValueError: for any failure in this function. Errors raised inside
        the `try` (including HTTPException) are re-raised as ValueError.
    """

    try:
        user_object: Optional[LiteLLM_UserTable] = None
        existing_user_id_row = None
        existing_user_email_row = None
        ## Check if user exists in LiteLLM_UserTable - user exists - either the user_id or user_email is in LiteLLM_UserTable
        if member.user_id is not None:
            existing_user_id_row = await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": member.user_id}
            )

        if existing_user_id_row is None and member.user_email is not None:
            try:
                existing_user_email_row = (
                    await prisma_client.db.litellm_usertable.find_unique(
                        where={"user_email": member.user_email}
                    )
                )
            except Exception as e:
                raise ValueError(
                    f"Potential NON-Existent or Duplicate user email in DB: Error finding a unique instance of user_email={member.user_email} in LiteLLM_UserTable.: {e}"
                )

        ## If user does not exist, create a new user
        if existing_user_id_row is None and existing_user_email_row is None:
            # Create a new user - since user does not exist
            user_id: str = member.user_id or str(uuid.uuid4())
            new_user_defaults = get_new_internal_user_defaults(
                user_id=user_id,
                user_email=member.user_email,
            )

            _returned_user = await prisma_client.insert_data(data=new_user_defaults, table_name="user")  # type: ignore
            if _returned_user is not None:
                user_object = LiteLLM_UserTable(**_returned_user.model_dump())
        elif (
            # the lookup above yields a single row; calling len() on that row
            # raised TypeError and broke every add-by-email for an existing
            # user. Only measure the length if a collection ever comes back.
            isinstance(existing_user_email_row, (list, tuple))
            and len(existing_user_email_row) > 1
        ):
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Multiple users with this email found in db. Please use 'user_id' instead."
                },
            )
        elif existing_user_email_row is not None:
            user_object = LiteLLM_UserTable(**existing_user_email_row.model_dump())
        elif existing_user_id_row is not None:
            user_object = LiteLLM_UserTable(**existing_user_id_row.model_dump())
        # no trailing `else`: the branches above cover every combination of
        # the two lookups, so the old 404 branch could never run

        if user_object is None:
            raise ValueError(
                f"User does not exist in LiteLLM_UserTable. user_id={member.user_id} and user_email={member.user_email}"
            )

        # Add user to organization
        _organization_membership = (
            await prisma_client.db.litellm_organizationmembership.create(
                data={
                    "organization_id": organization_id,
                    "user_id": user_object.user_id,
                    "user_role": member.role,
                }
            )
        )
        organization_membership = LiteLLM_OrganizationMembershipTable(
            **_organization_membership.model_dump()
        )
        return user_object, organization_membership

    except Exception as e:
        import traceback

        traceback.print_exc()
        # chain the cause so the original error stays visible in tracebacks
        raise ValueError(
            f"Error adding member={member} to organization={organization_id}: {e}"
        ) from e
+ """ + + if ( + user_role != LitellmUserRoles.PROXY_ADMIN.value + and user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value + ): + return False + return True diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/team_callback_endpoints.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/team_callback_endpoints.py new file mode 100644 index 00000000..93d338a4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/team_callback_endpoints.py @@ -0,0 +1,383 @@ +""" +Endpoints to control callbacks per team + +Use this when each team should control its own callbacks +""" + +import json +import traceback +from typing import Optional + +from fastapi import APIRouter, Depends, Header, HTTPException, Request, status + +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ( + AddTeamCallback, + ProxyErrorTypes, + ProxyException, + TeamCallbackMetadata, + UserAPIKeyAuth, +) +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.management_helpers.utils import management_endpoint_wrapper + +router = APIRouter() + + +@router.post( + "/team/{team_id:path}/callback", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def add_team_callbacks( + data: AddTeamCallback, + http_request: Request, + team_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Add a success/failure callback to a team + + Use this if if you want different teams to have different success/failure callbacks + + Parameters: + - callback_name (Literal["langfuse", "langsmith", "gcs"], required): The name of the callback to add + - 
callback_type (Literal["success", "failure", "success_and_failure"], required): The type of callback to add. One of: + - "success": Callback for successful LLM calls + - "failure": Callback for failed LLM calls + - "success_and_failure": Callback for both successful and failed LLM calls + - callback_vars (StandardCallbackDynamicParams, required): A dictionary of variables to pass to the callback + - langfuse_public_key: The public key for the Langfuse callback + - langfuse_secret_key: The secret key for the Langfuse callback + - langfuse_secret: The secret for the Langfuse callback + - langfuse_host: The host for the Langfuse callback + - gcs_bucket_name: The name of the GCS bucket + - gcs_path_service_account: The path to the GCS service account + - langsmith_api_key: The API key for the Langsmith callback + - langsmith_project: The project for the Langsmith callback + - langsmith_base_url: The base URL for the Langsmith callback + + Example curl: + ``` + curl -X POST 'http:/localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/callback' \ + -H 'Content-Type: application/json' \ + -H 'Authorization: Bearer sk-1234' \ + -d '{ + "callback_name": "langfuse", + "callback_type": "success", + "callback_vars": {"langfuse_public_key": "pk-lf-xxxx1", "langfuse_secret_key": "sk-xxxxx"} + + }' + ``` + + This means for the team where team_id = dbe2f686-a686-4896-864a-4c3924458709, all LLM calls will be logged to langfuse using the public key pk-lf-xxxx1 and the secret key sk-xxxxx + + """ + try: + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + # Check if team_id exists already + _existing_team = await prisma_client.get_data( + team_id=team_id, table_name="team", query_type="find_unique" + ) + if _existing_team is None: + raise HTTPException( + status_code=400, + detail={ + "error": f"Team id = {team_id} does not exist. Please use a different team id." 
+ }, + ) + + # store team callback settings in metadata + team_metadata = _existing_team.metadata + team_callback_settings = team_metadata.get("callback_settings", {}) + # expect callback settings to be + team_callback_settings_obj = TeamCallbackMetadata(**team_callback_settings) + if data.callback_type == "success": + if team_callback_settings_obj.success_callback is None: + team_callback_settings_obj.success_callback = [] + + if data.callback_name in team_callback_settings_obj.success_callback: + raise ProxyException( + message=f"callback_name = {data.callback_name} already exists in failure_callback, for team_id = {team_id}. \n Existing failure_callback = {team_callback_settings_obj.success_callback}", + code=status.HTTP_400_BAD_REQUEST, + type=ProxyErrorTypes.bad_request_error, + param="callback_name", + ) + + team_callback_settings_obj.success_callback.append(data.callback_name) + elif data.callback_type == "failure": + if team_callback_settings_obj.failure_callback is None: + team_callback_settings_obj.failure_callback = [] + + if data.callback_name in team_callback_settings_obj.failure_callback: + raise ProxyException( + message=f"callback_name = {data.callback_name} already exists in failure_callback, for team_id = {team_id}. \n Existing failure_callback = {team_callback_settings_obj.failure_callback}", + code=status.HTTP_400_BAD_REQUEST, + type=ProxyErrorTypes.bad_request_error, + param="callback_name", + ) + team_callback_settings_obj.failure_callback.append(data.callback_name) + elif data.callback_type == "success_and_failure": + if team_callback_settings_obj.success_callback is None: + team_callback_settings_obj.success_callback = [] + if team_callback_settings_obj.failure_callback is None: + team_callback_settings_obj.failure_callback = [] + if data.callback_name in team_callback_settings_obj.success_callback: + raise ProxyException( + message=f"callback_name = {data.callback_name} already exists in success_callback, for team_id = {team_id}. 
\n Existing success_callback = {team_callback_settings_obj.success_callback}", + code=status.HTTP_400_BAD_REQUEST, + type=ProxyErrorTypes.bad_request_error, + param="callback_name", + ) + + if data.callback_name in team_callback_settings_obj.failure_callback: + raise ProxyException( + message=f"callback_name = {data.callback_name} already exists in failure_callback, for team_id = {team_id}. \n Existing failure_callback = {team_callback_settings_obj.failure_callback}", + code=status.HTTP_400_BAD_REQUEST, + type=ProxyErrorTypes.bad_request_error, + param="callback_name", + ) + + team_callback_settings_obj.success_callback.append(data.callback_name) + team_callback_settings_obj.failure_callback.append(data.callback_name) + for var, value in data.callback_vars.items(): + if team_callback_settings_obj.callback_vars is None: + team_callback_settings_obj.callback_vars = {} + team_callback_settings_obj.callback_vars[var] = value + + team_callback_settings_obj_dict = team_callback_settings_obj.model_dump() + + team_metadata["callback_settings"] = team_callback_settings_obj_dict + team_metadata_json = json.dumps(team_metadata) # update team_metadata + + new_team_row = await prisma_client.db.litellm_teamtable.update( + where={"team_id": team_id}, data={"metadata": team_metadata_json} # type: ignore + ) + + return { + "status": "success", + "data": new_team_row, + } + + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.add_team_callbacks(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Internal Server Error({str(e)})"), + type=ProxyErrorTypes.internal_server_error.value, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Internal Server Error, " + str(e), 
+ type=ProxyErrorTypes.internal_server_error.value, + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +@router.post( + "/team/{team_id}/disable_logging", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def disable_team_logging( + http_request: Request, + team_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Disable all logging callbacks for a team + + Parameters: + - team_id (str, required): The unique identifier for the team + + Example curl: + ``` + curl -X POST 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/disable_logging' \ + -H 'Authorization: Bearer sk-1234' + ``` + + + """ + try: + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + # Check if team exists + _existing_team = await prisma_client.get_data( + team_id=team_id, table_name="team", query_type="find_unique" + ) + if _existing_team is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team id = {team_id} does not exist."}, + ) + + # Update team metadata to disable logging + team_metadata = _existing_team.metadata + team_callback_settings = team_metadata.get("callback_settings", {}) + team_callback_settings_obj = TeamCallbackMetadata(**team_callback_settings) + + # Reset callbacks + team_callback_settings_obj.success_callback = [] + team_callback_settings_obj.failure_callback = [] + + # Update metadata + team_metadata["callback_settings"] = team_callback_settings_obj.model_dump() + team_metadata_json = json.dumps(team_metadata) + + # Update team in database + updated_team = await prisma_client.db.litellm_teamtable.update( + where={"team_id": team_id}, data={"metadata": team_metadata_json} # type: ignore + ) + + if updated_team is None: + raise HTTPException( + status_code=404, + detail={ + "error": f"Team id = 
{team_id} does not exist. Error updating team logging" + }, + ) + + return { + "status": "success", + "message": f"Logging disabled for team {team_id}", + "data": { + "team_id": updated_team.team_id, + "success_callbacks": [], + "failure_callbacks": [], + }, + } + + except Exception as e: + verbose_proxy_logger.error( + f"litellm.proxy.proxy_server.disable_team_logging(): Exception occurred - {str(e)}" + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Internal Server Error({str(e)})"), + type=ProxyErrorTypes.internal_server_error.value, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Internal Server Error, " + str(e), + type=ProxyErrorTypes.internal_server_error.value, + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +@router.get( + "/team/{team_id:path}/callback", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def get_team_callbacks( + http_request: Request, + team_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Get the success/failure callbacks and variables for a team + + Parameters: + - team_id (str, required): The unique identifier for the team + + Example curl: + ``` + curl -X GET 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/callback' \ + -H 'Authorization: Bearer sk-1234' + ``` + + This will return the callback settings for the team with id dbe2f686-a686-4896-864a-4c3924458709 + + Returns { + "status": "success", + "data": { + "team_id": team_id, + "success_callbacks": team_callback_settings_obj.success_callback, + "failure_callbacks": team_callback_settings_obj.failure_callback, + "callback_vars": team_callback_settings_obj.callback_vars, + }, + } + 
""" + try: + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + # Check if team_id exists + _existing_team = await prisma_client.get_data( + team_id=team_id, table_name="team", query_type="find_unique" + ) + if _existing_team is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team id = {team_id} does not exist."}, + ) + + # Retrieve team callback settings from metadata + team_metadata = _existing_team.metadata + team_callback_settings = team_metadata.get("callback_settings", {}) + + # Convert to TeamCallbackMetadata object for consistent structure + team_callback_settings_obj = TeamCallbackMetadata(**team_callback_settings) + + return { + "status": "success", + "data": { + "team_id": team_id, + "success_callbacks": team_callback_settings_obj.success_callback, + "failure_callbacks": team_callback_settings_obj.failure_callback, + "callback_vars": team_callback_settings_obj.callback_vars, + }, + } + + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.get_team_callbacks(): Exception occurred - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Internal Server Error({str(e)})"), + type=ProxyErrorTypes.internal_server_error.value, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Internal Server Error, " + str(e), + type=ProxyErrorTypes.internal_server_error.value, + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/team_endpoints.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/team_endpoints.py new 
file mode 100644 index 00000000..f5bcc6ba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/team_endpoints.py @@ -0,0 +1,1928 @@ +""" +TEAM MANAGEMENT + +All /team management endpoints + +/team/new +/team/info +/team/update +/team/delete +""" + +import asyncio +import json +import traceback +import uuid +from datetime import datetime, timedelta, timezone +from typing import List, Optional, Tuple, Union, cast + +import fastapi +from fastapi import APIRouter, Depends, Header, HTTPException, Request, status +from pydantic import BaseModel + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ( + BlockTeamRequest, + CommonProxyErrors, + DeleteTeamRequest, + LiteLLM_AuditLogs, + LiteLLM_ManagementEndpoint_MetadataFields_Premium, + LiteLLM_ModelTable, + LiteLLM_TeamMembership, + LiteLLM_TeamTable, + LiteLLM_TeamTableCachedObj, + LiteLLM_UserTable, + LitellmTableNames, + LitellmUserRoles, + Member, + NewTeamRequest, + ProxyErrorTypes, + ProxyException, + SpecialManagementEndpointEnums, + SpecialModelNames, + TeamAddMemberResponse, + TeamInfoResponseObject, + TeamListResponseObject, + TeamMemberAddRequest, + TeamMemberDeleteRequest, + TeamMemberUpdateRequest, + TeamMemberUpdateResponse, + TeamModelAddRequest, + TeamModelDeleteRequest, + UpdateTeamRequest, + UserAPIKeyAuth, +) +from litellm.proxy.auth.auth_checks import ( + allowed_route_check_inside_route, + get_team_object, +) +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.management_endpoints.common_utils import ( + _is_user_team_admin, + _set_object_metadata_field, +) +from litellm.proxy.management_helpers.utils import ( + add_new_member, + management_endpoint_wrapper, +) +from litellm.proxy.utils import ( + PrismaClient, + _premium_user_check, + handle_exception_on_proxy, +) +from litellm.router import Router + +router = APIRouter() + + +def _is_available_team(team_id: str, 
user_api_key_dict: UserAPIKeyAuth) -> bool: + if litellm.default_internal_user_params is None: + return False + if "available_teams" in litellm.default_internal_user_params: + return team_id in litellm.default_internal_user_params["available_teams"] + return False + + +async def get_all_team_memberships( + prisma_client: PrismaClient, team_id: List[str], user_id: Optional[str] = None +) -> List[LiteLLM_TeamMembership]: + """Get all team memberships for a given user""" + ## GET ALL MEMBERSHIPS ## + if not isinstance(user_id, str): + user_id = str(user_id) + + team_memberships = await prisma_client.db.litellm_teammembership.find_many( + where=( + {"user_id": user_id, "team_id": {"in": team_id}} + if user_id is not None + else {"team_id": {"in": team_id}} + ), + include={"litellm_budget_table": True}, + ) + + returned_tm: List[LiteLLM_TeamMembership] = [] + for tm in team_memberships: + returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump())) + + return returned_tm + + +#### TEAM MANAGEMENT #### +@router.post( + "/team/new", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], + response_model=LiteLLM_TeamTable, +) +@management_endpoint_wrapper +async def new_team( # noqa: PLR0915 + data: NewTeamRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Allow users to create a new team. Apply user permissions to their team. + + 👉 [Detailed Doc on setting team budgets](https://docs.litellm.ai/docs/proxy/team_budgets) + + + Parameters: + - team_alias: Optional[str] - User defined team alias + - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it. 
+ - members_with_roles: List[{"role": "admin" or "user", "user_id": "<user-id>"}] - A list of users and their roles in the team. Get user_id when making a new user via `/user/new`. + - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"extra_info": "some info"} + - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit + - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit + - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget + - budget_duration: Optional[str] - The duration of the budget for the team. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets) + - models: Optional[list] - A list of models associated with the team - all keys for this team_id will have at most, these models. If empty, assumes all models are allowed. + - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id. + - members: Optional[List] - Control team members via `/team/member/add` and `/team/member/delete`. + - tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing). + - organization_id: Optional[str] - The organization id of the team. Default is None. Create via `/organization/new`. + - model_aliases: Optional[dict] - Model aliases for the team. [Docs](https://docs.litellm.ai/docs/proxy/team_based_routing#create-team-with-model-alias) + - guardrails: Optional[List[str]] - Guardrails for the team. [Docs](https://docs.litellm.ai/docs/proxy/guardrails) + Returns: + - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id. 
+ + _deprecated_params: + - admins: list - A list of user_id's for the admin role + - users: list - A list of user_id's for the user role + + Example Request: + ``` + curl --location 'http://0.0.0.0:4000/team/new' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "team_alias": "my-new-team_2", + "members_with_roles": [{"role": "admin", "user_id": "user-1234"}, + {"role": "user", "user_id": "user-2434"}] + }' + + ``` + + ``` + curl --location 'http://0.0.0.0:4000/team/new' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "team_alias": "QA Prod Bot", + "max_budget": 0.000000001, + "budget_duration": "1d" + }' + ``` + """ + try: + from litellm.proxy.proxy_server import ( + create_audit_log_for_update, + duration_in_seconds, + litellm_proxy_admin_name, + prisma_client, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + data.team_id = str(uuid.uuid4()) + else: + # Check if team_id exists already + _existing_team_id = await prisma_client.get_data( + team_id=data.team_id, table_name="team", query_type="find_unique" + ) + if _existing_team_id is not None: + raise HTTPException( + status_code=400, + detail={ + "error": f"Team id = {data.team_id} already exists. Please use a different team id." + }, + ) + + if ( + user_api_key_dict.user_role is None + or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN + ): # don't restrict proxy admin + if ( + data.tpm_limit is not None + and user_api_key_dict.tpm_limit is not None + and data.tpm_limit > user_api_key_dict.tpm_limit + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}. 
User role={user_api_key_dict.user_role}" + }, + ) + + if ( + data.rpm_limit is not None + and user_api_key_dict.rpm_limit is not None + and data.rpm_limit > user_api_key_dict.rpm_limit + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}. User role={user_api_key_dict.user_role}" + }, + ) + + if ( + data.max_budget is not None + and user_api_key_dict.max_budget is not None + and data.max_budget > user_api_key_dict.max_budget + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}. User role={user_api_key_dict.user_role}" + }, + ) + + if data.models is not None and len(user_api_key_dict.models) > 0: + for m in data.models: + if m not in user_api_key_dict.models: + raise HTTPException( + status_code=400, + detail={ + "error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}. 
User id={user_api_key_dict.user_id}" + }, + ) + + if user_api_key_dict.user_id is not None: + creating_user_in_list = False + for member in data.members_with_roles: + if member.user_id == user_api_key_dict.user_id: + creating_user_in_list = True + + if creating_user_in_list is False: + data.members_with_roles.append( + Member(role="admin", user_id=user_api_key_dict.user_id) + ) + + ## ADD TO MODEL TABLE + _model_id = None + if data.model_aliases is not None and isinstance(data.model_aliases, dict): + litellm_modeltable = LiteLLM_ModelTable( + model_aliases=json.dumps(data.model_aliases), + created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + ) + model_dict = await prisma_client.db.litellm_modeltable.create( + {**litellm_modeltable.json(exclude_none=True)} # type: ignore + ) # type: ignore + + _model_id = model_dict.id + + ## ADD TO TEAM TABLE + complete_team_data = LiteLLM_TeamTable( + **data.json(), + model_id=_model_id, + ) + + # Set Management Endpoint Metadata Fields + for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium: + if getattr(data, field) is not None: + _set_object_metadata_field( + object_data=complete_team_data, + field_name=field, + value=getattr(data, field), + ) + + # If budget_duration is set, set `budget_reset_at` + if complete_team_data.budget_duration is not None: + duration_s = duration_in_seconds( + duration=complete_team_data.budget_duration + ) + reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + complete_team_data.budget_reset_at = reset_at + + complete_team_data_dict = complete_team_data.model_dump(exclude_none=True) + complete_team_data_dict = prisma_client.jsonify_team_object( + db_data=complete_team_data_dict + ) + team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create( + data=complete_team_data_dict, + include={"litellm_model_table": True}, # type: ignore + ) + + ## ADD TEAM ID TO USER TABLE ## + 
for user in complete_team_data.members_with_roles: + ## add team id to user row ## + await prisma_client.update_data( + user_id=user.user_id, + data={"user_id": user.user_id, "teams": [team_row.team_id]}, + update_key_values_custom_query={ + "teams": { + "push ": [team_row.team_id], + } + }, + ) + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _updated_values = complete_team_data.json(exclude_none=True) + + _updated_values = json.dumps(_updated_values, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.TEAM_TABLE_NAME, + object_id=data.team_id, + action="created", + updated_values=_updated_values, + before_value=None, + ) + ) + ) + + try: + return team_row.model_dump() + except Exception: + return team_row.dict() + except Exception as e: + raise handle_exception_on_proxy(e) + + +async def _update_model_table( + data: UpdateTeamRequest, + model_id: Optional[str], + prisma_client: PrismaClient, + user_api_key_dict: UserAPIKeyAuth, + litellm_proxy_admin_name: str, +) -> Optional[str]: + """ + Upsert model table and return the model id + """ + ## UPSERT MODEL TABLE + _model_id = model_id + if data.model_aliases is not None and isinstance(data.model_aliases, dict): + litellm_modeltable = LiteLLM_ModelTable( + model_aliases=json.dumps(data.model_aliases), + created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + ) + if model_id is None: + model_dict = await prisma_client.db.litellm_modeltable.create( + data={**litellm_modeltable.json(exclude_none=True)} # type: ignore + ) + else: + model_dict = await 
prisma_client.db.litellm_modeltable.upsert( + where={"id": model_id}, + data={ + "update": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore + "create": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore + }, + ) # type: ignore + + _model_id = model_dict.id + + return _model_id + + +@router.post( + "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def update_team( + data: UpdateTeamRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Use `/team/member_add` AND `/team/member/delete` to add/remove new team members + + You can now update team budget / rate limits via /team/update + + Parameters: + - team_id: str - The team id of the user. Required param. + - team_alias: Optional[str] - User defined team alias + - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } + - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit + - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit + - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget + - budget_duration: Optional[str] - The duration of the budget for the team. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets) + - models: Optional[list] - A list of models associated with the team - all keys for this team_id will have at most, these models. 
If empty, assumes all models are allowed. + - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id. + - tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing). + - organization_id: Optional[str] - The organization id of the team. Default is None. Create via `/organization/new`. + - model_aliases: Optional[dict] - Model aliases for the team. [Docs](https://docs.litellm.ai/docs/proxy/team_based_routing#create-team-with-model-alias) + - guardrails: Optional[List[str]] - Guardrails for the team. [Docs](https://docs.litellm.ai/docs/proxy/guardrails) + Example - update team TPM Limit + + ``` + curl --location 'http://0.0.0.0:4000/team/update' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "team_id": "8d916b1c-510d-4894-a334-1c16a93344f5", + "tpm_limit": 100 + }' + ``` + + Example - Update Team `max_budget` budget + ``` + curl --location 'http://0.0.0.0:4000/team/update' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "team_id": "8d916b1c-510d-4894-a334-1c16a93344f5", + "max_budget": 10 + }' + ``` + """ + from litellm.proxy.auth.auth_checks import _cache_team_object + from litellm.proxy.proxy_server import ( + create_audit_log_for_update, + duration_in_seconds, + litellm_proxy_admin_name, + prisma_client, + proxy_logging_obj, + user_api_key_cache, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + verbose_proxy_logger.debug("/team/update - %s", data) + + existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": data.team_id} 
+ ) + + if existing_team_row is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team not found, passed team_id={data.team_id}"}, + ) + + updated_kv = data.json(exclude_unset=True) + + # Check budget_duration and budget_reset_at + if data.budget_duration is not None: + duration_s = duration_in_seconds(duration=data.budget_duration) + reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + + # set the budget_reset_at in DB + updated_kv["budget_reset_at"] = reset_at + + # update team metadata fields + _team_metadata_fields = LiteLLM_ManagementEndpoint_MetadataFields_Premium + for field in _team_metadata_fields: + if field in updated_kv and updated_kv[field] is not None: + _update_team_metadata_field( + updated_kv=updated_kv, + field_name=field, + ) + + if "model_aliases" in updated_kv: + updated_kv.pop("model_aliases") + _model_id = await _update_model_table( + data=data, + model_id=existing_team_row.model_id, + prisma_client=prisma_client, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ) + if _model_id is not None: + updated_kv["model_id"] = _model_id + + updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv) + team_row: Optional[LiteLLM_TeamTable] = ( + await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, + data=updated_kv, + include={"litellm_model_table": True}, # type: ignore + ) + ) + + if team_row is None or team_row.team_id is None: + raise HTTPException( + status_code=400, + detail={"error": "Team doesn't exist. Got={}".format(team_row)}, + ) + + await _cache_team_object( + team_id=team_row.team_id, + team_table=LiteLLM_TeamTableCachedObj(**team_row.model_dump()), + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + # Enterprise Feature - Audit Logging. 
Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _before_value = existing_team_row.json(exclude_none=True) + _before_value = json.dumps(_before_value, default=str) + _after_value: str = json.dumps(updated_kv, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.TEAM_TABLE_NAME, + object_id=data.team_id, + action="updated", + updated_values=_after_value, + before_value=_before_value, + ) + ) + ) + + return {"team_id": team_row.team_id, "data": team_row} + + +def _check_team_member_admin_add( + member: Union[Member, List[Member]], + premium_user: bool, +): + if isinstance(member, Member) and member.role == "admin": + if premium_user is not True: + raise ValueError( + f"Assigning team admins is a premium feature. {CommonProxyErrors.not_premium_user.value}" + ) + elif isinstance(member, List): + for m in member: + if m.role == "admin": + if premium_user is not True: + raise ValueError( + f"Assigning team admins is a premium feature. Got={m}. {CommonProxyErrors.not_premium_user.value}. 
@router.post(
    "/team/member_add",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=TeamAddMemberResponse,
)
@management_endpoint_wrapper
async def team_member_add(
    data: TeamMemberAddRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [BETA]

    Add new members (either via user_email or user_id) to a team

    If user doesn't exist, new user row will also be added to User Table

    Only proxy_admin or admin of team, allowed to access this endpoint.
    ```

    curl -X POST 'http://0.0.0.0:4000/team/member_add' \
    -H 'Authorization: Bearer sk-1234' \
    -H 'Content-Type: application/json' \
    -d '{"team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849", "member": {"role": "user", "user_id": "krrish247652@berri.ai"}}'

    ```

    Parameters:
    - data.team_id: the team to add members to.
    - data.member: a single member or a list of members.
    - data.max_budget_in_team: forwarded to `add_new_member` for every member.

    Raises:
    - 400: validation failed (no db / team_id / member checks are done by
      `team_call_validation_checks`) or a member is already in the team.
    - 403: caller is neither proxy admin, team admin, nor joining an
      available team.
    - 404: team does not exist.
    - 500: `add_new_member` failed for one of the members.
    """
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        premium_user,
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # Already raises HTTPException itself; no re-raise wrapper needed.
    team_call_validation_checks(
        prisma_client=prisma_client,
        data=data,
        premium_user=premium_user,
    )

    # Validation above guarantees the client is connected.
    prisma_client = cast(PrismaClient, prisma_client)

    existing_team_row = await get_team_object(
        team_id=data.team_id,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        parent_otel_span=None,
        proxy_logging_obj=proxy_logging_obj,
        check_cache_only=False,
        check_db_only=True,
    )
    if existing_team_row is None:
        raise HTTPException(
            status_code=404,
            detail={
                "error": f"Team not found for team_id={getattr(data, 'team_id', None)}"
            },
        )

    complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())

    team_member_add_duplication_check(
        data=data,
        existing_team_row=complete_team_data,
    )

    ## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN

    if (
        hasattr(user_api_key_dict, "user_role")
        and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
        and not _is_user_team_admin(
            user_api_key_dict=user_api_key_dict, team_obj=complete_team_data
        )
        and not _is_available_team(
            team_id=complete_team_data.team_id,
            user_api_key_dict=user_api_key_dict,
        )
    ):
        raise HTTPException(
            status_code=403,
            detail={
                "error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
                    "/team/member_add",
                    complete_team_data.team_id,
                )
            },
        )

    # Normalise the "one member" and "many members" request shapes so both
    # go through exactly the same code path.
    if isinstance(data.member, Member):
        new_members: List[Member] = [data.member]
    elif isinstance(data.member, list):
        new_members = list(data.member)
    else:
        new_members = []

    updated_users: List[LiteLLM_UserTable] = []
    updated_team_memberships: List[LiteLLM_TeamMembership] = []

    ## VALIDATE IF NEW MEMBER ##
    # Members are added sequentially; the first failure aborts the request.
    for m in new_members:
        try:
            updated_user, updated_tm = await add_new_member(
                new_member=m,
                max_budget_in_team=data.max_budget_in_team,
                prisma_client=prisma_client,
                user_api_key_dict=user_api_key_dict,
                litellm_proxy_admin_name=litellm_proxy_admin_name,
                team_id=data.team_id,
            )
        except Exception as e:
            # Report the member that actually failed, not the whole payload.
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Unable to add user - {}, to team - {}, for reason - {}".format(
                        m, data.team_id, str(e)
                    )
                },
            ) from e

        updated_users.append(updated_user)
        if updated_tm is not None:
            updated_team_memberships.append(updated_tm)

    ## ADD TO TEAM ##
    # Members passed by e-mail only get their user_id back-filled from the
    # user rows returned by add_new_member (last matching row wins).
    for nm in new_members:
        if nm.user_id is None and nm.user_email is not None:
            for user in updated_users:
                if user.user_email is not None and user.user_email == nm.user_email:
                    nm.user_id = user.user_id

    complete_team_data.members_with_roles.extend(new_members)

    # ADD MEMBER TO TEAM
    _db_team_members = [m.model_dump() for m in complete_team_data.members_with_roles]
    updated_team = await prisma_client.db.litellm_teamtable.update(
        where={"team_id": data.team_id},
        data={"members_with_roles": json.dumps(_db_team_members)},  # type: ignore
    )

    # Check if updated_team is None
    if updated_team is None:
        raise HTTPException(
            status_code=404, detail={"error": f"Team with id {data.team_id} not found"}
        )
    return TeamAddMemberResponse(
        **updated_team.model_dump(),
        updated_users=updated_users,
        updated_team_memberships=updated_team_memberships,
    )
status_code=400, + detail={"error": "Either user_id or user_email needs to be passed in"}, + ) + + _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": data.team_id} + ) + + if _existing_team_row is None: + raise HTTPException( + status_code=400, + detail={"error": "Team id={} does not exist in db".format(data.team_id)}, + ) + existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump()) + + ## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN + + if ( + user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value + and not _is_user_team_admin( + user_api_key_dict=user_api_key_dict, team_obj=existing_team_row + ) + ): + raise HTTPException( + status_code=403, + detail={ + "error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format( + "/team/member_delete", existing_team_row.team_id + ) + }, + ) + + ## DELETE MEMBER FROM TEAM + is_member_in_team = False + new_team_members: List[Member] = [] + for m in existing_team_row.members_with_roles: + if ( + data.user_id is not None + and m.user_id is not None + and data.user_id == m.user_id + ): + is_member_in_team = True + continue + elif ( + data.user_email is not None + and m.user_email is not None + and data.user_email == m.user_email + ): + is_member_in_team = True + continue + new_team_members.append(m) + + if not is_member_in_team: + raise HTTPException(status_code=400, detail={"error": "User not found in team"}) + + existing_team_row.members_with_roles = new_team_members + + _db_new_team_members: List[dict] = [m.model_dump() for m in new_team_members] + + _ = await prisma_client.db.litellm_teamtable.update( + where={ + "team_id": data.team_id, + }, + data={"members_with_roles": json.dumps(_db_new_team_members)}, # type: ignore + ) + + ## DELETE TEAM ID from USER ROW, IF EXISTS ## + # get user row + key_val = {} + if data.user_id is not None: + key_val["user_id"] = data.user_id + elif data.user_email is not None: + key_val["user_email"] 
@router.post(
    "/team/member_update",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=TeamMemberUpdateResponse,
)
@management_endpoint_wrapper
async def team_member_update(
    data: TeamMemberUpdateRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [BETA]

    Update team member budgets and team member role

    Parameters:
    - data.team_id: team the member belongs to (required).
    - data.user_id / data.user_email: identifies the member; at least one
      is required. When only the e-mail is given the user_id is looked up
      in the team's `members_with_roles`.
    - data.max_budget_in_team: if set, the member's budget row is updated,
      or a new budget + team membership is created when none exists.
    - data.role: if set, the member's role in `members_with_roles` is replaced.

    Raises:
    - 400: missing team_id / user identifier, unknown team, or the user
      could not be resolved to a user_id.
    - 403: caller is neither proxy admin nor admin of the team.
    - 500: no database connected.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    if data.team_id is None:
        raise HTTPException(status_code=400, detail={"error": "No team id passed in"})

    if data.user_id is None and data.user_email is None:
        raise HTTPException(
            status_code=400,
            detail={"error": "Either user_id or user_email needs to be passed in"},
        )

    _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": data.team_id}
    )

    if _existing_team_row is None:
        raise HTTPException(
            status_code=400,
            detail={"error": "Team id={} does not exist in db".format(data.team_id)},
        )
    existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())

    ## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN

    if (
        user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
        and not _is_user_team_admin(
            user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
        )
    ):
        raise HTTPException(
            status_code=403,
            detail={
                # Report this endpoint's own route in the error.
                "error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
                    "/team/member_update", existing_team_row.team_id
                )
            },
        )

    returned_team_info: TeamInfoResponseObject = await team_info(
        http_request=http_request,
        team_id=data.team_id,
        user_api_key_dict=user_api_key_dict,
    )

    team_table = returned_team_info["team_info"]

    ## get user id
    received_user_id: Optional[str] = None
    if data.user_id is not None:
        received_user_id = data.user_id
    elif data.user_email is not None:
        for member in team_table.members_with_roles:
            if member.user_email is not None and member.user_email == data.user_email:
                received_user_id = member.user_id
                break

    if received_user_id is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "User id doesn't exist in team table. Data={}".format(data)
            },
        )

    ## find the relevant team membership
    identified_budget_id: Optional[str] = None
    for tm in returned_team_info["team_memberships"]:
        if tm.user_id == received_user_id:
            identified_budget_id = tm.budget_id
            break

    ### upsert new budget
    if data.max_budget_in_team is not None:
        if identified_budget_id is None:
            # No budget linked to this member yet: create the budget row,
            # then the membership row pointing at it.
            new_budget = await prisma_client.db.litellm_budgettable.create(
                data={
                    "max_budget": data.max_budget_in_team,
                    "created_by": user_api_key_dict.user_id or "",
                    "updated_by": user_api_key_dict.user_id or "",
                }
            )

            await prisma_client.db.litellm_teammembership.create(
                data={
                    "team_id": data.team_id,
                    "user_id": received_user_id,
                    "budget_id": new_budget.budget_id,
                },
            )
        else:
            await prisma_client.db.litellm_budgettable.update(
                where={"budget_id": identified_budget_id},
                data={"max_budget": data.max_budget_in_team},
            )

    ### update team member role
    if data.role is not None:
        team_members: List[Member] = []
        for member in team_table.members_with_roles:
            if member.user_id == received_user_id:
                team_members.append(
                    Member(
                        user_id=member.user_id,
                        role=data.role,
                        user_email=data.user_email or member.user_email,
                    )
                )
            else:
                team_members.append(member)

        team_table.members_with_roles = team_members

        _db_team_members: List[dict] = [m.model_dump() for m in team_members]
        await prisma_client.db.litellm_teamtable.update(
            where={"team_id": data.team_id},
            data={"members_with_roles": json.dumps(_db_team_members)},  # type: ignore
        )

    return TeamMemberUpdateResponse(
        team_id=data.team_id,
        user_id=received_user_id,
        user_email=data.user_email,
        max_budget_in_team=data.max_budget_in_team,
    )
Example: ["team-1234", "team-5678"] + + ``` + curl --location 'http://0.0.0.0:4000/team/delete' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "team_ids": ["8d916b1c-510d-4894-a334-1c16a93344f5"] + }' + ``` + """ + from litellm.proxy.proxy_server import ( + create_audit_log_for_update, + litellm_proxy_admin_name, + prisma_client, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_ids is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + + # check that all teams passed exist + team_rows: List[LiteLLM_TeamTable] = [] + for team_id in data.team_ids: + try: + team_row_base: Optional[BaseModel] = ( + await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} + ) + ) + if team_row_base is None: + raise Exception + except Exception: + raise HTTPException( + status_code=404, + detail={"error": f"Team not found, passed team_id={team_id}"}, + ) + team_row_pydantic = LiteLLM_TeamTable(**team_row_base.model_dump()) + team_rows.append(team_row_pydantic) + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + # we do this after the first for loop, since first for loop is for validation. 
def validate_membership(
    user_api_key_dict: UserAPIKeyAuth, team_table: LiteLLM_TeamTable
):
    """
    Ensure the caller is allowed to read `team_table`.

    Access is granted when the caller is a proxy admin (full or view-only),
    when the calling key belongs to the team itself, or when the caller's
    user_id is listed in the team's `members_with_roles`.

    Raises:
        HTTPException(403): the caller satisfies none of the above.
    """
    if user_api_key_dict.user_role in (
        LitellmUserRoles.PROXY_ADMIN.value,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
    ):
        return

    if (
        user_api_key_dict.team_id == team_table.team_id
    ):  # allow team keys to check their info
        return

    caller_user_id = user_api_key_dict.user_id
    # Ignore members stored without a user_id (e.g. e-mail-only entries):
    # otherwise a key with user_id=None would match them and be authorised.
    member_user_ids = {
        m.user_id for m in team_table.members_with_roles if m.user_id is not None
    }
    if caller_user_id is None or caller_user_id not in member_user_ids:
        raise HTTPException(
            status_code=403,
            detail={
                "error": "User={} not authorized to access this team={}".format(
                    user_api_key_dict.user_id, team_table.team_id
                )
            },
        )
No team id passed in."}, + ) + + try: + team_info: Optional[BaseModel] = ( + await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} + ) + ) + if team_info is None: + raise Exception + except Exception: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"message": f"Team not found, passed team id: {team_id}."}, + ) + validate_membership( + user_api_key_dict=user_api_key_dict, + team_table=LiteLLM_TeamTable(**team_info.model_dump()), + ) + + ## GET ALL KEYS ## + keys = await prisma_client.get_data( + team_id=team_id, + table_name="key", + query_type="find_all", + expires=datetime.now(), + ) + + if keys is None: + keys = [] + + if team_info is None: + ## make sure we still return a total spend ## + spend = 0 + for k in keys: + spend += getattr(k, "spend", 0) + team_info = {"spend": spend} + + ## REMOVE HASHED TOKEN INFO before returning ## + for key in keys: + try: + key = key.model_dump() # noqa + except Exception: + # if using pydantic v1 + key = key.dict() + key.pop("token", None) + + ## GET ALL MEMBERSHIPS ## + returned_tm = await get_all_team_memberships( + prisma_client, [team_id], user_id=None + ) + + if isinstance(team_info, dict): + _team_info = LiteLLM_TeamTable(**team_info) + elif isinstance(team_info, BaseModel): + _team_info = LiteLLM_TeamTable(**team_info.model_dump()) + else: + _team_info = LiteLLM_TeamTable() + + # ## UNFURL 'all-proxy-models' into the team_info.models list ## + # if llm_router is not None: + # _team_info = _unfurl_all_proxy_models(_team_info, llm_router) + response_object = TeamInfoResponseObject( + team_id=team_id, + team_info=_team_info, + keys=keys, + team_memberships=returned_tm, + ) + return response_object + + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.management_endpoints.team_endpoints.py::team_info - Exception occurred - {}\n{}".format( + e, traceback.format_exc() + ) + ) + if isinstance(e, HTTPException): + raise ProxyException( + 
@router.post(
    "/team/block", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def block_team(
    data: BlockTeamRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Blocks all calls from keys with this team id.

    Parameters:
    - team_id: str - Required. The unique identifier of the team to block.

    Example:
    ```
    curl --location 'http://0.0.0.0:4000/team/block' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "team_id": "team-1234"
    }'
    ```

    Returns:
    - The updated team record with blocked=True

    Raises:
    - 500: no database connected.
    - 404: no team exists with the given team_id.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        # Same structured error the other team endpoints return, instead of
        # a bare Exception.
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    record = await prisma_client.db.litellm_teamtable.update(
        where={"team_id": data.team_id}, data={"blocked": True}  # type: ignore
    )

    # update() yields None when no row matched the `where` clause.
    if record is None:
        raise HTTPException(
            status_code=404,
            detail={"error": f"Team not found, passed team_id={data.team_id}"},
        )

    return record
@router.get(
    "/team/available",
    tags=["team management"],
    # Declared on the route so FastAPI applies it; as a handler parameter it
    # was only exposed as a spurious request argument.
    response_model=List[LiteLLM_TeamTable],
)
async def list_available_teams(
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List the self-serve teams the calling user can still join.

    The candidate team ids come from
    `litellm.default_internal_user_params["available_teams"]`; teams the
    user is already a member of are filtered out before the team rows are
    loaded from the database.

    Raises:
    - 400: no database connected, or no available teams are configured.
    - 404: the calling key has no user_id or the user row does not exist.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    available_teams = cast(
        Optional[List[str]],
        (
            litellm.default_internal_user_params.get("available_teams")
            if litellm.default_internal_user_params is not None
            else None
        ),
    )
    if available_teams is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "No available teams for user to join. See how to set available teams here: https://docs.litellm.ai/docs/proxy/self_serve#all-settings-for-self-serve--sso-flow"
            },
        )

    # A key without a user_id cannot be looked up by the unique user_id
    # filter; treat it like an unknown user instead of failing in the query.
    if user_api_key_dict.user_id is None:
        raise HTTPException(
            status_code=404,
            detail={"error": "User not found"},
        )

    # filter out teams that the user is already a member of
    user_info = await prisma_client.db.litellm_usertable.find_unique(
        where={"user_id": user_api_key_dict.user_id}
    )
    if user_info is None:
        raise HTTPException(
            status_code=404,
            detail={"error": "User not found"},
        )
    user_info_correct_type = LiteLLM_UserTable(**user_info.model_dump())

    already_joined = set(user_info_correct_type.teams or [])
    joinable_team_ids = [
        team for team in available_teams if team not in already_joined
    ]

    available_teams_db = await prisma_client.db.litellm_teamtable.find_many(
        where={"team_id": {"in": joinable_team_ids}}
    )

    return [LiteLLM_TeamTable(**team.model_dump()) for team in available_teams_db]
Your user role={}".format( + user_api_key_dict.user_role + ) + }, + ) + + if prisma_client is None: + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + response = await prisma_client.db.litellm_teamtable.find_many( + include={ + "litellm_model_table": True, + } + ) + + filtered_response = [] + if user_id: + for team in response: + if team.members_with_roles: + for member in team.members_with_roles: + if ( + "user_id" in member + and member["user_id"] is not None + and member["user_id"] == user_id + ): + filtered_response.append(team) + + else: + filtered_response = response + + _team_ids = [team.team_id for team in filtered_response] + returned_tm = await get_all_team_memberships( + prisma_client, _team_ids, user_id=user_id + ) + + returned_responses: List[TeamListResponseObject] = [] + for team in filtered_response: + _team_memberships: List[LiteLLM_TeamMembership] = [] + for tm in returned_tm: + if tm.team_id == team.team_id: + _team_memberships.append(tm) + + # add all keys that belong to the team + keys = await prisma_client.db.litellm_verificationtoken.find_many( + where={"team_id": team.team_id} + ) + + try: + returned_responses.append( + TeamListResponseObject( + **team.model_dump(), + team_memberships=_team_memberships, + keys=keys, + ) + ) + except Exception as e: + team_exception = """Invalid team object for team_id: {}. team_object={}. 
+ Error: {} + """.format( + team.team_id, team.model_dump(), str(e) + ) + verbose_proxy_logger.exception(team_exception) + continue + # Sort the responses by team_alias + returned_responses.sort(key=lambda x: (getattr(x, "team_alias", "") or "")) + + if organization_id is not None: + if organization_id == SpecialManagementEndpointEnums.DEFAULT_ORGANIZATION.value: + returned_responses = [ + team for team in returned_responses if team.organization_id is None + ] + else: + returned_responses = [ + team + for team in returned_responses + if team.organization_id == organization_id + ] + + return returned_responses + + +async def get_paginated_teams( + prisma_client: PrismaClient, + page_size: int = 10, + page: int = 1, +) -> Tuple[List[LiteLLM_TeamTable], int]: + """ + Get paginated list of teams from team table + + Parameters: + prisma_client: PrismaClient - The database client + page_size: int - Number of teams per page + page: int - Page number (1-based) + + Returns: + Tuple[List[LiteLLM_TeamTable], int] - (list of teams, total count) + """ + try: + # Calculate skip for pagination + skip = (page - 1) * page_size + # Get total count + total_count = await prisma_client.db.litellm_teamtable.count() + + # Get paginated teams + teams = await prisma_client.db.litellm_teamtable.find_many( + skip=skip, take=page_size, order={"team_alias": "asc"} # Sort by team_alias + ) + return teams, total_count + except Exception as e: + verbose_proxy_logger.exception( + f"[Non-Blocking] Error getting paginated teams: {e}" + ) + return [], 0 + + +def _update_team_metadata_field(updated_kv: dict, field_name: str) -> None: + """ + Helper function to update metadata fields that require premium user checks in the update endpoint + + Args: + updated_kv: The key-value dict being used for the update + field_name: Name of the metadata field being updated + """ + if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium: + _premium_user_check() + + if field_name in updated_kv and 
updated_kv[field_name] is not None: + # remove field from updated_kv + _value = updated_kv.pop(field_name) + if "metadata" in updated_kv and updated_kv["metadata"] is not None: + updated_kv["metadata"][field_name] = _value + else: + updated_kv["metadata"] = {field_name: _value} + + +@router.get( + "/team/filter/ui", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, + responses={ + 200: {"model": List[LiteLLM_TeamTable]}, + }, +) +async def ui_view_teams( + team_id: Optional[str] = fastapi.Query( + default=None, description="Team ID in the request parameters" + ), + team_alias: Optional[str] = fastapi.Query( + default=None, description="Team alias in the request parameters" + ), + page: int = fastapi.Query( + default=1, description="Page number for pagination", ge=1 + ), + page_size: int = fastapi.Query( + default=50, description="Number of items per page", ge=1, le=100 + ), + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [PROXY-ADMIN ONLY] Filter teams based on partial match of team_id or team_alias with pagination. 
+ + Args: + user_id (Optional[str]): Partial user ID to search for + user_email (Optional[str]): Partial email to search for + page (int): Page number for pagination (starts at 1) + page_size (int): Number of items per page (max 100) + user_api_key_dict (UserAPIKeyAuth): User authentication information + + Returns: + List[LiteLLM_SpendLogs]: Paginated list of matching user records + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + try: + # Calculate offset for pagination + skip = (page - 1) * page_size + + # Build where conditions based on provided parameters + where_conditions = {} + + if team_id: + where_conditions["team_id"] = { + "contains": team_id, + "mode": "insensitive", # Case-insensitive search + } + + if team_alias: + where_conditions["team_alias"] = { + "contains": team_alias, + "mode": "insensitive", # Case-insensitive search + } + + # Query users with pagination and filters + teams = await prisma_client.db.litellm_teamtable.find_many( + where=where_conditions, + skip=skip, + take=page_size, + order={"created_at": "desc"}, + ) + + if not teams: + return [] + + return teams + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error searching teams: {str(e)}") + + +@router.post( + "/team/model/add", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def team_model_add( + data: TeamModelAddRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Add models to a team's allowed model list. Only proxy admin or team admin can add models. + + Parameters: + - team_id: str - Required. The team to add models to + - models: List[str] - Required. 
List of models to add to the team + + Example Request: + ``` + curl --location 'http://0.0.0.0:4000/team/model/add' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "team_id": "team-1234", + "models": ["gpt-4", "claude-2"] + }' + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + # Get existing team + team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": data.team_id} + ) + + if team_row is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team not found, passed team_id={data.team_id}"}, + ) + + team_obj = LiteLLM_TeamTable(**team_row.model_dump()) + + # Authorization check - only proxy admin or team admin can add models + if ( + user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value + and not _is_user_team_admin( + user_api_key_dict=user_api_key_dict, team_obj=team_obj + ) + ): + raise HTTPException( + status_code=403, + detail={"error": "Only proxy admin or team admin can modify team models"}, + ) + + # Get current models list + current_models = team_obj.models or [] + + # Add new models (avoid duplicates) + updated_models = list(set(current_models + data.models)) + + # Update team + updated_team = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, data={"models": updated_models} + ) + + return updated_team + + +@router.post( + "/team/model/delete", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def team_model_delete( + data: TeamModelDeleteRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Remove models from a team's allowed model list. Only proxy admin or team admin can remove models. + + Parameters: + - team_id: str - Required. 
The team to remove models from + - models: List[str] - Required. List of models to remove from the team + + Example Request: + ``` + curl --location 'http://0.0.0.0:4000/team/model/delete' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "team_id": "team-1234", + "models": ["gpt-4"] + }' + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + # Get existing team + team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": data.team_id} + ) + + if team_row is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team not found, passed team_id={data.team_id}"}, + ) + + team_obj = LiteLLM_TeamTable(**team_row.model_dump()) + + # Authorization check - only proxy admin or team admin can remove models + if ( + user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value + and not _is_user_team_admin( + user_api_key_dict=user_api_key_dict, team_obj=team_obj + ) + ): + raise HTTPException( + status_code=403, + detail={"error": "Only proxy admin or team admin can modify team models"}, + ) + + # Get current models list + current_models = team_obj.models or [] + + # Remove specified models + updated_models = [m for m in current_models if m not in data.models] + + # Update team + updated_team = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, data={"models": updated_models} + ) + + return updated_team diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/types.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/types.py new file mode 100644 index 00000000..0e811669 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/types.py @@ -0,0 +1,13 @@ +""" +Types for the management endpoints + +Might include fastapi/proxy requirements.txt 
related imports +""" + +from typing import List + +from fastapi_sso.sso.base import OpenID + + +class CustomOpenID(OpenID): + team_ids: List[str] diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/ui_sso.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/ui_sso.py new file mode 100644 index 00000000..86dec9fc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/ui_sso.py @@ -0,0 +1,781 @@ +""" +Has all /sso/* routes + +/sso/key/generate - handles user signing in with SSO and redirects to /sso/callback +/sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI +""" + +import asyncio +import os +import uuid +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.responses import RedirectResponse + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY +from litellm.proxy._types import ( + LiteLLM_UserTable, + LitellmUserRoles, + Member, + NewUserRequest, + NewUserResponse, + ProxyErrorTypes, + ProxyException, + SSOUserDefinedValues, + TeamMemberAddRequest, + UserAPIKeyAuth, +) +from litellm.proxy.auth.auth_checks import get_user_object +from litellm.proxy.auth.auth_utils import _has_user_setup_sso +from litellm.proxy.auth.handle_jwt import JWTHandler +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.common_utils.admin_ui_utils import ( + admin_ui_disabled, + html_form, + show_missing_vars_in_env, +) +from litellm.proxy.management_endpoints.internal_user_endpoints import new_user +from litellm.proxy.management_endpoints.sso_helper_utils import ( + check_is_admin_only_access, + has_admin_ui_access, +) +from litellm.proxy.management_endpoints.team_endpoints import team_member_add +from litellm.proxy.management_endpoints.types import CustomOpenID +from 
litellm.secret_managers.main import str_to_bool + +if TYPE_CHECKING: + from fastapi_sso.sso.base import OpenID +else: + from typing import Any as OpenID + +router = APIRouter() + + +@router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False) +async def google_login(request: Request): # noqa: PLR0915 + """ + Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env + PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/" + Example: + """ + from litellm.proxy.proxy_server import premium_user + + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) + + ####### Check if UI is disabled ####### + _disable_ui_flag = os.getenv("DISABLE_ADMIN_UI") + if _disable_ui_flag is not None: + is_disabled = str_to_bool(value=_disable_ui_flag) + if is_disabled: + return admin_ui_disabled() + + ####### Check if user is a Enterprise / Premium User ####### + if ( + microsoft_client_id is not None + or google_client_id is not None + or generic_client_id is not None + ): + if premium_user is not True: + raise ProxyException( + message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. 
Please unset this", + type=ProxyErrorTypes.auth_error, + param="premium_user", + code=status.HTTP_403_FORBIDDEN, + ) + + ####### Detect DB + MASTER KEY in .env ####### + missing_env_vars = show_missing_vars_in_env() + if missing_env_vars is not None: + return missing_env_vars + + # get url from request + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) + ui_username = os.getenv("UI_USERNAME") + if redirect_url.endswith("/"): + redirect_url += "sso/callback" + else: + redirect_url += "/sso/callback" + # Google SSO Auth + if google_client_id is not None: + from fastapi_sso.sso.google import GoogleSSO + + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: + raise ProxyException( + message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + google_sso = GoogleSSO( + client_id=google_client_id, + client_secret=google_client_secret, + redirect_uri=redirect_url, + ) + verbose_proxy_logger.info( + f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" + ) + with google_sso: + return await google_sso.get_login_redirect() + # Microsoft SSO Auth + elif microsoft_client_id is not None: + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. 
Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + with microsoft_sso: + return await microsoft_sso.get_login_redirect() + elif generic_client_id is not None: + from fastapi_sso.sso.base import DiscoveryDocument + from fastapi_sso.sso.generic import create_provider + + generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) + generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") + generic_authorization_endpoint = os.getenv( + "GENERIC_AUTHORIZATION_ENDPOINT", None + ) + generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) + generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) + if generic_client_secret is None: + raise ProxyException( + message="GENERIC_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_authorization_endpoint is None: + raise ProxyException( + message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_AUTHORIZATION_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_token_endpoint is None: + raise ProxyException( + message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_TOKEN_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_userinfo_endpoint is None: + raise ProxyException( + message="GENERIC_USERINFO_ENDPOINT not set. 
Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_USERINFO_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + verbose_proxy_logger.debug( + f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" + ) + verbose_proxy_logger.debug( + f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" + ) + discovery = DiscoveryDocument( + authorization_endpoint=generic_authorization_endpoint, + token_endpoint=generic_token_endpoint, + userinfo_endpoint=generic_userinfo_endpoint, + ) + SSOProvider = create_provider(name="oidc", discovery_document=discovery) + generic_sso = SSOProvider( + client_id=generic_client_id, + client_secret=generic_client_secret, + redirect_uri=redirect_url, + allow_insecure_http=True, + scope=generic_scope, + ) + with generic_sso: + # TODO: state should be a random string and added to the user session with cookie + # or a cryptographicly signed state that we can verify stateless + # For simplification we are using a static state, this is not perfect but some + # SSO providers do not allow stateless verification + redirect_params = {} + state = os.getenv("GENERIC_CLIENT_STATE", None) + + if state: + redirect_params["state"] = state + elif "okta" in generic_authorization_endpoint: + redirect_params["state"] = ( + uuid.uuid4().hex + ) # set state param for okta - required + return await generic_sso.get_login_redirect(**redirect_params) # type: ignore + elif ui_username is not None: + # No Google, Microsoft SSO + # Use UI Credentials set in .env + from fastapi.responses import HTMLResponse + + return HTMLResponse(content=html_form, status_code=200) + else: + from fastapi.responses import HTMLResponse + + return HTMLResponse(content=html_form, status_code=200) + + +def generic_response_convertor(response, jwt_handler: JWTHandler): + generic_user_id_attribute_name = os.getenv( + "GENERIC_USER_ID_ATTRIBUTE", 
"preferred_username" + ) + generic_user_display_name_attribute_name = os.getenv( + "GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub" + ) + generic_user_email_attribute_name = os.getenv( + "GENERIC_USER_EMAIL_ATTRIBUTE", "email" + ) + + generic_user_first_name_attribute_name = os.getenv( + "GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name" + ) + generic_user_last_name_attribute_name = os.getenv( + "GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name" + ) + + generic_provider_attribute_name = os.getenv( + "GENERIC_USER_PROVIDER_ATTRIBUTE", "provider" + ) + + verbose_proxy_logger.debug( + f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}" + ) + + return CustomOpenID( + id=response.get(generic_user_id_attribute_name), + display_name=response.get(generic_user_display_name_attribute_name), + email=response.get(generic_user_email_attribute_name), + first_name=response.get(generic_user_first_name_attribute_name), + last_name=response.get(generic_user_last_name_attribute_name), + provider=response.get(generic_provider_attribute_name), + team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)), + ) + + +async def get_generic_sso_response( + request: Request, + jwt_handler: JWTHandler, + generic_client_id: str, + redirect_url: str, +) -> Optional[OpenID]: + # make generic sso provider + from fastapi_sso.sso.base import DiscoveryDocument + from fastapi_sso.sso.generic import create_provider + + generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) + generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") + generic_authorization_endpoint = os.getenv("GENERIC_AUTHORIZATION_ENDPOINT", None) + generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) + generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) + generic_include_client_id = ( + os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true" + ) + if generic_client_secret 
is None: + raise ProxyException( + message="GENERIC_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_authorization_endpoint is None: + raise ProxyException( + message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_AUTHORIZATION_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_token_endpoint is None: + raise ProxyException( + message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_TOKEN_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_userinfo_endpoint is None: + raise ProxyException( + message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_USERINFO_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + verbose_proxy_logger.debug( + f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" + ) + verbose_proxy_logger.debug( + f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" + ) + + discovery = DiscoveryDocument( + authorization_endpoint=generic_authorization_endpoint, + token_endpoint=generic_token_endpoint, + userinfo_endpoint=generic_userinfo_endpoint, + ) + + def response_convertor(response, client): + return generic_response_convertor( + response=response, + jwt_handler=jwt_handler, + ) + + SSOProvider = create_provider( + name="oidc", + discovery_document=discovery, + response_convertor=response_convertor, + ) + generic_sso = SSOProvider( + client_id=generic_client_id, + client_secret=generic_client_secret, + redirect_uri=redirect_url, + allow_insecure_http=True, + scope=generic_scope, + ) + verbose_proxy_logger.debug("calling generic_sso.verify_and_process") + result = 
await generic_sso.verify_and_process( + request, params={"include_client_id": generic_include_client_id} + ) + verbose_proxy_logger.debug("generic result: %s", result) + return result + + +async def create_team_member_add_task(team_id, user_info): + """Create a task for adding a member to a team.""" + try: + member = Member(user_id=user_info.user_id, role="user") + team_member_add_request = TeamMemberAddRequest( + member=member, + team_id=team_id, + ) + return await team_member_add( + data=team_member_add_request, + user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), + http_request=Request(scope={"type": "http", "path": "/sso/callback"}), + ) + except Exception as e: + verbose_proxy_logger.debug( + f"[Non-Blocking] Error trying to add sso user to db: {e}" + ) + + +async def add_missing_team_member( + user_info: Union[NewUserResponse, LiteLLM_UserTable], sso_teams: List[str] +): + """ + - Get missing teams (diff b/w user_info.team_ids and sso_teams) + - Add missing user to missing teams + """ + if user_info.teams is None: + return + missing_teams = set(sso_teams) - set(user_info.teams) + missing_teams_list = list(missing_teams) + tasks = [] + tasks = [ + create_team_member_add_task(team_id, user_info) + for team_id in missing_teams_list + ] + + try: + await asyncio.gather(*tasks) + except Exception as e: + verbose_proxy_logger.debug( + f"[Non-Blocking] Error trying to add sso user to db: {e}" + ) + + +def get_disabled_non_admin_personal_key_creation(): + key_generation_settings = litellm.key_generation_settings + if key_generation_settings is None: + return False + personal_key_generation = ( + key_generation_settings.get("personal_key_generation") or {} + ) + allowed_user_roles = personal_key_generation.get("allowed_user_roles") or [] + return bool("proxy_admin" in allowed_user_roles) + + +@router.get("/sso/callback", tags=["experimental"], include_in_schema=False) +async def auth_callback(request: Request): # noqa: PLR0915 + """Verify 
login""" + from litellm.proxy.management_endpoints.key_management_endpoints import ( + generate_key_helper_fn, + ) + from litellm.proxy.proxy_server import ( + general_settings, + jwt_handler, + master_key, + premium_user, + prisma_client, + proxy_logging_obj, + ui_access_mode, + user_api_key_cache, + user_custom_sso, + ) + + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) + # get url from request + if master_key is None: + raise ProxyException( + message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.", + type=ProxyErrorTypes.auth_error, + param="master_key", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) + if redirect_url.endswith("/"): + redirect_url += "sso/callback" + else: + redirect_url += "/sso/callback" + + result = None + if google_client_id is not None: + from fastapi_sso.sso.google import GoogleSSO + + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: + raise ProxyException( + message="GOOGLE_CLIENT_SECRET not set. 
Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + google_sso = GoogleSSO( + client_id=google_client_id, + redirect_uri=redirect_url, + client_secret=google_client_secret, + ) + result = await google_sso.verify_and_process(request) + elif microsoft_client_id is not None: + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if microsoft_tenant is None: + raise ProxyException( + message="MICROSOFT_TENANT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="MICROSOFT_TENANT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + result = await microsoft_sso.verify_and_process(request) + elif generic_client_id is not None: + result = await get_generic_sso_response( + request=request, + jwt_handler=jwt_handler, + generic_client_id=generic_client_id, + redirect_url=redirect_url, + ) + # User is Authe'd in - generate key for the UI to access Proxy + user_email: Optional[str] = getattr(result, "email", None) + user_id: Optional[str] = getattr(result, "id", None) if result is not None else None + + if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None: + email_domain = user_email.split("@")[1] + allowed_domains = os.getenv("ALLOWED_EMAIL_DOMAINS").split(",") # type: ignore + if email_domain not in allowed_domains: + raise HTTPException( + status_code=401, + detail={ + "message": 
"The email domain={}, is not an allowed email domain={}. Contact your admin to change this.".format( + email_domain, allowed_domains + ) + }, + ) + + # generic client id + if generic_client_id is not None and result is not None: + generic_user_role_attribute_name = os.getenv( + "GENERIC_USER_ROLE_ATTRIBUTE", "role" + ) + user_id = getattr(result, "id", None) + user_email = getattr(result, "email", None) + user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore + + if user_id is None and result is not None: + _first_name = getattr(result, "first_name", "") or "" + _last_name = getattr(result, "last_name", "") or "" + user_id = _first_name + _last_name + + if user_email is not None and (user_id is None or len(user_id) == 0): + user_id = user_email + + user_info = None + user_id_models: List = [] + max_internal_user_budget = litellm.max_internal_user_budget + internal_user_budget_duration = litellm.internal_user_budget_duration + + # User might not be already created on first generation of key + # But if it is, we want their models preferences + default_ui_key_values: Dict[str, Any] = { + "duration": "24hr", + "key_max_budget": litellm.max_ui_session_budget, + "aliases": {}, + "config": {}, + "spend": 0, + "team_id": "litellm-dashboard", + } + user_defined_values: Optional[SSOUserDefinedValues] = None + + if user_custom_sso is not None: + if asyncio.iscoroutinefunction(user_custom_sso): + user_defined_values = await user_custom_sso(result) # type: ignore + else: + raise ValueError("user_custom_sso must be a coroutine function") + elif user_id is not None: + user_defined_values = SSOUserDefinedValues( + models=user_id_models, + user_id=user_id, + user_email=user_email, + max_budget=max_internal_user_budget, + user_role=None, + budget_duration=internal_user_budget_duration, + ) + + _user_id_from_sso = user_id + user_role = None + try: + if prisma_client is not None: + try: + user_info = await get_user_object( + user_id=user_id, + 
user_email=user_email, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + user_id_upsert=False, + parent_otel_span=None, + proxy_logging_obj=proxy_logging_obj, + sso_user_id=user_id, + ) + except Exception as e: + verbose_proxy_logger.debug(f"Error getting user object: {e}") + user_info = None + + verbose_proxy_logger.debug( + f"user_info: {user_info}; litellm.default_internal_user_params: {litellm.default_internal_user_params}" + ) + + if user_info is not None: + user_id = user_info.user_id + user_defined_values = SSOUserDefinedValues( + models=getattr(user_info, "models", user_id_models), + user_id=user_info.user_id, + user_email=getattr(user_info, "user_email", user_email), + user_role=getattr(user_info, "user_role", None), + max_budget=getattr( + user_info, "max_budget", max_internal_user_budget + ), + budget_duration=getattr( + user_info, "budget_duration", internal_user_budget_duration + ), + ) + + user_role = getattr(user_info, "user_role", None) + + # update id + await prisma_client.db.litellm_usertable.update_many( + where={"user_email": user_email}, data={"user_id": user_id} # type: ignore + ) + else: + verbose_proxy_logger.info( + "user not in DB, inserting user into LiteLLM DB" + ) + # user not in DB, insert User into LiteLLM DB + user_info = await insert_sso_user( + result_openid=result, + user_defined_values=user_defined_values, + ) + + user_role = ( + user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY + ) + sso_teams = getattr(result, "team_ids", []) + await add_missing_team_member(user_info=user_info, sso_teams=sso_teams) + + except Exception as e: + verbose_proxy_logger.debug( + f"[Non-Blocking] Error trying to add sso user to db: {e}" + ) + + if user_defined_values is None: + raise Exception( + "Unable to map user identity to known values. 'user_defined_values' is None. 
File an issue - https://github.com/BerriAI/litellm/issues" + ) + + verbose_proxy_logger.info( + f"user_defined_values for creating ui key: {user_defined_values}" + ) + + default_ui_key_values.update(user_defined_values) + default_ui_key_values["request_type"] = "key" + response = await generate_key_helper_fn( + **default_ui_key_values, # type: ignore + table_name="key", + ) + + key = response["token"] # type: ignore + user_id = response["user_id"] # type: ignore + + litellm_dashboard_ui = "/ui/" + user_role = user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value + if ( + os.getenv("PROXY_ADMIN_ID", None) is not None + and os.environ["PROXY_ADMIN_ID"] == user_id + ): + # checks if user is admin + user_role = LitellmUserRoles.PROXY_ADMIN.value + + verbose_proxy_logger.debug( + f"user_role: {user_role}; ui_access_mode: {ui_access_mode}" + ) + ## CHECK IF ROLE ALLOWED TO USE PROXY ## + is_admin_only_access = check_is_admin_only_access(ui_access_mode) + if is_admin_only_access: + has_access = has_admin_ui_access(user_role) + if not has_access: + raise HTTPException( + status_code=401, + detail={ + "error": f"User not allowed to access proxy. 
async def insert_sso_user(
    result_openid: Optional[OpenID],
    user_defined_values: Optional[SSOUserDefinedValues] = None,
) -> NewUserResponse:
    """
    Create a new user in the LiteLLM DB after a successful SSO login.

    NOTE: `user_defined_values` is modified in place. When
    `litellm.default_internal_user_params` is set it is merged into the dict,
    unset budget fields are filled in for users whose role is INTERNAL_USER,
    and a `user_role` of None is replaced with INTERNAL_USER_VIEW_ONLY.

    Args:
        result_openid (Optional[OpenID]): User information in OpenID format.
            When truthy, its `provider` is stored in the new user's metadata
            under "auth_provider".
        user_defined_values (Optional[SSOUserDefinedValues]): LiteLLM SSO
            values used to populate the new user request.

    Returns:
        NewUserResponse: The value returned by `new_user` for the created user.

    Raises:
        ValueError: If `user_defined_values` is None.
    """
    verbose_proxy_logger.debug(
        f"Inserting SSO user into DB. User values: {user_defined_values}"
    )

    if user_defined_values is None:
        raise ValueError("user_defined_values is None")

    # Proxy-wide defaults are layered on top of the SSO-derived values.
    if litellm.default_internal_user_params:
        user_defined_values.update(litellm.default_internal_user_params)  # type: ignore

    # Internal users get the globally configured budget when none was supplied.
    is_internal_user = (
        user_defined_values.get("user_role") == LitellmUserRoles.INTERNAL_USER.value
    )
    if is_internal_user and user_defined_values.get("max_budget") is None:
        user_defined_values["max_budget"] = litellm.max_internal_user_budget
    if is_internal_user and user_defined_values.get("budget_duration") is None:
        user_defined_values["budget_duration"] = litellm.internal_user_budget_duration

    # A user without a role is created as view-only.
    if user_defined_values["user_role"] is None:
        user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY

    request_fields = {
        field: user_defined_values[field]  # type: ignore
        for field in (
            "user_id",
            "user_email",
            "user_role",
            "max_budget",
            "budget_duration",
        )
    }
    sso_user_request = NewUserRequest(**request_fields)  # type: ignore

    if result_openid:
        sso_user_request.metadata = {"auth_provider": result_openid.provider}

    return await new_user(data=sso_user_request, user_api_key_dict=UserAPIKeyAuth())


@router.get(
    "/sso/get/ui_settings",
    tags=["experimental"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def get_ui_settings(request: Request):
    """
    Return the proxy settings exposed on the `/sso/get/ui_settings` route.

    The response is a dict with the keys PROXY_BASE_URL, PROXY_LOGOUT_URL,
    DEFAULT_TEAM_DISABLED, SSO_ENABLED, NUM_SPEND_LOGS_ROWS and
    DISABLE_EXPENSIVE_DB_QUERIES.

    Args:
        request (Request): The incoming request; not read by this handler.
    """
    from litellm.proxy.proxy_server import general_settings, proxy_state

    base_url = os.getenv("PROXY_BASE_URL", None)
    logout_url = os.getenv("PROXY_LOGOUT_URL", None)
    sso_enabled = _has_user_setup_sso()

    # Expensive queries are switched off once the tracked spend-log row count
    # exceeds the configured maximum.
    expensive_queries_disabled = (
        proxy_state.get_proxy_state_variable("spend_logs_row_count")
        > MAX_SPENDLOG_ROWS_TO_QUERY
    )

    # The environment variable can only turn the flag on; any value other
    # than "true" (case-insensitive) leaves the general_settings value as is.
    team_disabled = general_settings.get("default_team_disabled", False)
    env_override = os.environ.get("PROXY_DEFAULT_TEAM_DISABLED")
    if env_override is not None and env_override.lower() == "true":
        team_disabled = True

    ui_settings = {
        "PROXY_BASE_URL": base_url,
        "PROXY_LOGOUT_URL": logout_url,
        "DEFAULT_TEAM_DISABLED": team_disabled,
        "SSO_ENABLED": sso_enabled,
        "NUM_SPEND_LOGS_ROWS": proxy_state.get_proxy_state_variable(
            "spend_logs_row_count"
        ),
        "DISABLE_EXPENSIVE_DB_QUERIES": expensive_queries_disabled,
    }
    return ui_settings