diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/auth/user_api_key_auth.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/auth/user_api_key_auth.py | 1337 |
1 files changed, 1337 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/user_api_key_auth.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/user_api_key_auth.py new file mode 100644 index 00000000..ace0bf49 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/user_api_key_auth.py @@ -0,0 +1,1337 @@ +""" +This file handles authentication for the LiteLLM Proxy. + +it checks if the user passed a valid API Key to the LiteLLM Proxy + +Returns a UserAPIKeyAuth object if the API key is valid + +""" + +import asyncio +import re +import secrets +from datetime import datetime, timezone +from typing import Optional, cast + +import fastapi +from fastapi import HTTPException, Request, WebSocket, status +from fastapi.security.api_key import APIKeyHeader + +import litellm +from litellm._logging import verbose_logger, verbose_proxy_logger +from litellm._service_logger import ServiceLogging +from litellm.caching import DualCache +from litellm.litellm_core_utils.dd_tracing import tracer +from litellm.proxy._types import * +from litellm.proxy.auth.auth_checks import ( + _cache_key_object, + _handle_failed_db_connection_for_get_key_object, + _virtual_key_max_budget_check, + _virtual_key_soft_budget_check, + can_key_call_model, + common_checks, + get_end_user_object, + get_key_object, + get_team_object, + get_user_object, + is_valid_fallback_model, +) +from litellm.proxy.auth.auth_utils import ( + _get_request_ip_address, + get_end_user_id_from_request_body, + get_request_route, + is_pass_through_provider_route, + pre_db_read_auth_checks, + route_in_additonal_public_routes, + should_run_auth_on_pass_through_provider_route, +) +from litellm.proxy.auth.handle_jwt import JWTAuthManager, JWTHandler +from litellm.proxy.auth.oauth2_check import check_oauth2_token +from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request +from litellm.proxy.auth.route_checks import RouteChecks +from litellm.proxy.auth.service_account_checks import service_account_checks +from litellm.proxy.common_utils.http_parsing_utils import _read_request_body +from litellm.proxy.utils import PrismaClient, ProxyLogging +from litellm.types.services import ServiceTypes + +user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL + + +api_key_header = APIKeyHeader( + name=SpecialHeaders.openai_authorization.value, + auto_error=False, + description="Bearer token", +) +azure_api_key_header = APIKeyHeader( + name=SpecialHeaders.azure_authorization.value, + auto_error=False, + description="Some older versions of the openai Python package will send an API-Key header with just the API key ", +) +anthropic_api_key_header = APIKeyHeader( + name=SpecialHeaders.anthropic_authorization.value, + auto_error=False, + description="If anthropic client used.", +) +google_ai_studio_api_key_header = APIKeyHeader( + name=SpecialHeaders.google_ai_studio_authorization.value, + auto_error=False, + description="If google ai studio client used.", +) +azure_apim_header = APIKeyHeader( + name=SpecialHeaders.azure_apim_authorization.value, + auto_error=False, + description="The default name of the subscription key header of Azure", +) + + +def _get_bearer_token( + api_key: str, +): + if api_key.startswith("Bearer "): # ensure Bearer token passed in + api_key = api_key.replace("Bearer ", "") # extract the token + elif api_key.startswith("Basic "): + api_key = api_key.replace("Basic ", "") # handle langfuse input + elif api_key.startswith("bearer "): + api_key = api_key.replace("bearer ", "") + else: + api_key = "" + return api_key + + +def _is_ui_route( + route: str, + user_obj: Optional[LiteLLM_UserTable] = None, +) -> bool: + """ + - Check if the route is a UI used route + """ + # this token is only used for managing the ui + allowed_routes = LiteLLMRoutes.ui_routes.value + # check if the current route startswith any of the allowed routes + if ( + route is not None + and isinstance(route, str) + and any(route.startswith(allowed_route) for allowed_route in allowed_routes) + ): + # Do something if the current route starts with any of the allowed routes + return True + elif any( + RouteChecks._route_matches_pattern(route=route, pattern=allowed_route) + for allowed_route in allowed_routes + ): + return True + return False + + +def _is_api_route_allowed( + route: str, + request: Request, + request_data: dict, + api_key: str, + valid_token: Optional[UserAPIKeyAuth], + user_obj: Optional[LiteLLM_UserTable] = None, +) -> bool: + """ + - Route b/w api token check and normal token check + """ + _user_role = _get_user_role(user_obj=user_obj) + + if valid_token is None: + raise Exception("Invalid proxy server token passed. valid_token=None.") + + if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin + RouteChecks.non_proxy_admin_allowed_routes_check( + user_obj=user_obj, + _user_role=_user_role, + route=route, + request=request, + request_data=request_data, + api_key=api_key, + valid_token=valid_token, + ) + return True + + +def _is_allowed_route( + route: str, + token_type: Literal["ui", "api"], + request: Request, + request_data: dict, + api_key: str, + valid_token: Optional[UserAPIKeyAuth], + user_obj: Optional[LiteLLM_UserTable] = None, +) -> bool: + """ + - Route b/w ui token check and normal token check + """ + + if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj): + return True + else: + return _is_api_route_allowed( + route=route, + request=request, + request_data=request_data, + api_key=api_key, + valid_token=valid_token, + user_obj=user_obj, + ) + + +async def user_api_key_auth_websocket(websocket: WebSocket): + # Accept the WebSocket connection + + request = Request(scope={"type": "http"}) + request._url = websocket.url + + query_params = websocket.query_params + + model = query_params.get("model") + + async def return_body(): + return_string = f'{{"model": "{model}"}}' + # return string as bytes + return return_string.encode() + + request.body = return_body # type: ignore + + # Extract the Authorization header + authorization = websocket.headers.get("authorization") + + # If no Authorization header, try the api-key header + if not authorization: + api_key = websocket.headers.get("api-key") + if not api_key: + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + raise HTTPException(status_code=403, detail="No API key provided") + else: + # Extract the API key from the Bearer token + if not authorization.startswith("Bearer "): + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + raise HTTPException( + status_code=403, detail="Invalid Authorization header format" + ) + + api_key = authorization[len("Bearer ") :].strip() + + # Call user_api_key_auth with the extracted API key + # Note: You'll need to modify this to work with WebSocket context if needed + try: + return await user_api_key_auth(request=request, api_key=f"Bearer {api_key}") + except Exception as e: + verbose_proxy_logger.exception(e) + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + raise HTTPException(status_code=403, detail=str(e)) + + +def update_valid_token_with_end_user_params( + valid_token: UserAPIKeyAuth, end_user_params: dict +) -> UserAPIKeyAuth: + valid_token.end_user_id = end_user_params.get("end_user_id") + valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit") + valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit") + valid_token.allowed_model_region = end_user_params.get("allowed_model_region") + return valid_token + + +async def get_global_proxy_spend( + litellm_proxy_admin_name: str, + user_api_key_cache: DualCache, + prisma_client: Optional[PrismaClient], + token: str, + proxy_logging_obj: ProxyLogging, +) -> Optional[float]: + global_proxy_spend = None + if litellm.max_budget > 0: # user set proxy max budget + # check cache + global_proxy_spend = await user_api_key_cache.async_get_cache( + key="{}:spend".format(litellm_proxy_admin_name) + ) + if global_proxy_spend is None and prisma_client is not None: + # get from db + sql_query = ( + """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";""" + ) + + response = await prisma_client.db.query_raw(query=sql_query) + + global_proxy_spend = response[0]["total_spend"] + + await user_api_key_cache.async_set_cache( + key="{}:spend".format(litellm_proxy_admin_name), + value=global_proxy_spend, + ) + if global_proxy_spend is not None: + user_info = CallInfo( + user_id=litellm_proxy_admin_name, + max_budget=litellm.max_budget, + spend=global_proxy_spend, + token=token, + ) + asyncio.create_task( + proxy_logging_obj.budget_alerts( + type="proxy_budget", + user_info=user_info, + ) + ) + return global_proxy_spend + + +def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str: + is_admin = jwt_handler.is_admin(scopes=scopes) + if is_admin: + return LitellmUserRoles.PROXY_ADMIN + else: + return LitellmUserRoles.TEAM + + +def get_model_from_request(request_data: dict, route: str) -> Optional[str]: + + # First try to get model from request_data + model = request_data.get("model") + + # If model not in request_data, try to extract from route + if model is None: + # Parse model from route that follows the pattern /openai/deployments/{model}/* + match = re.match(r"/openai/deployments/([^/]+)", route) + if match: + model = match.group(1) + + return model + + +async def _user_api_key_auth_builder( # noqa: PLR0915 + request: Request, + api_key: str, + azure_api_key_header: str, + anthropic_api_key_header: Optional[str], + google_ai_studio_api_key_header: Optional[str], + azure_apim_header: Optional[str], + request_data: dict, +) -> UserAPIKeyAuth: + + from litellm.proxy.proxy_server import ( + general_settings, + jwt_handler, + litellm_proxy_admin_name, + llm_model_list, + llm_router, + master_key, + model_max_budget_limiter, + open_telemetry_logger, + prisma_client, + proxy_logging_obj, + user_api_key_cache, + user_custom_auth, + ) + + parent_otel_span: Optional[Span] = None + start_time = datetime.now() + route: str = get_request_route(request=request) + try: + + # get the request body + + await pre_db_read_auth_checks( + request_data=request_data, + request=request, + route=route, + ) + pass_through_endpoints: Optional[List[dict]] = general_settings.get( + "pass_through_endpoints", None + ) + passed_in_key: Optional[str] = None + if isinstance(api_key, str): + passed_in_key = api_key + api_key = _get_bearer_token(api_key=api_key) + elif isinstance(azure_api_key_header, str): + api_key = azure_api_key_header + elif isinstance(anthropic_api_key_header, str): + api_key = anthropic_api_key_header + elif isinstance(google_ai_studio_api_key_header, str): + api_key = google_ai_studio_api_key_header + elif isinstance(azure_apim_header, str): + api_key = azure_apim_header + elif pass_through_endpoints is not None: + for endpoint in pass_through_endpoints: + if endpoint.get("path", "") == route: + headers: Optional[dict] = endpoint.get("headers", None) + if headers is not None: + header_key: str = headers.get("litellm_user_api_key", "") + if request.headers.get(key=header_key) is not None: + api_key = request.headers.get(key=header_key) + + # if user wants to pass LiteLLM_Master_Key as a custom header, example pass litellm keys as X-LiteLLM-Key: Bearer sk-1234 + custom_litellm_key_header_name = general_settings.get("litellm_key_header_name") + if custom_litellm_key_header_name is not None: + api_key = get_api_key_from_custom_header( + request=request, + custom_litellm_key_header_name=custom_litellm_key_header_name, + ) + + if open_telemetry_logger is not None: + parent_otel_span = ( + open_telemetry_logger.create_litellm_proxy_request_started_span( + start_time=start_time, + headers=dict(request.headers), + ) + ) + + ### USER-DEFINED AUTH FUNCTION ### + if user_custom_auth is not None: + response = await user_custom_auth(request=request, api_key=api_key) # type: ignore + return UserAPIKeyAuth.model_validate(response) + + ### LITELLM-DEFINED AUTH FUNCTION ### + #### IF JWT #### + """ + LiteLLM supports using JWTs. + + Enable this in proxy config, by setting + ``` + general_settings: + enable_jwt_auth: true + ``` + """ + + ######## Route Checks Before Reading DB / Cache for "token" ################ + if ( + route in LiteLLMRoutes.public_routes.value # type: ignore + or route_in_additonal_public_routes(current_route=route) + ): + # check if public endpoint + return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY) + elif is_pass_through_provider_route(route=route): + if should_run_auth_on_pass_through_provider_route(route=route) is False: + return UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY + ) + + ########## End of Route Checks Before Reading DB / Cache for "token" ######## + + if general_settings.get("enable_oauth2_auth", False) is True: + # return UserAPIKeyAuth object + # helper to check if the api_key is a valid oauth2 token + from litellm.proxy.proxy_server import premium_user + + if premium_user is not True: + raise ValueError( + "Oauth2 token validation is only available for premium users" + + CommonProxyErrors.not_premium_user.value + ) + + return await check_oauth2_token(token=api_key) + + if general_settings.get("enable_oauth2_proxy_auth", False) is True: + return await handle_oauth2_proxy_request(request=request) + + if general_settings.get("enable_jwt_auth", False) is True: + from litellm.proxy.proxy_server import premium_user + + if premium_user is not True: + raise ValueError( + f"JWT Auth is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}" + ) + is_jwt = jwt_handler.is_jwt(token=api_key) + verbose_proxy_logger.debug("is_jwt: %s", is_jwt) + if is_jwt: + result = await JWTAuthManager.auth_builder( + request_data=request_data, + general_settings=general_settings, + api_key=api_key, + jwt_handler=jwt_handler, + route=route, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + parent_otel_span=parent_otel_span, + ) + + is_proxy_admin = result["is_proxy_admin"] + team_id = result["team_id"] + team_object = result["team_object"] + user_id = result["user_id"] + user_object = result["user_object"] + end_user_id = result["end_user_id"] + end_user_object = result["end_user_object"] + org_id = result["org_id"] + token = result["token"] + + global_proxy_spend = await get_global_proxy_spend( + litellm_proxy_admin_name=litellm_proxy_admin_name, + user_api_key_cache=user_api_key_cache, + prisma_client=prisma_client, + token=token, + proxy_logging_obj=proxy_logging_obj, + ) + + if is_proxy_admin: + return UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, + ) + # run through common checks + _ = await common_checks( + request_body=request_data, + team_object=team_object, + user_object=user_object, + end_user_object=end_user_object, + general_settings=general_settings, + global_proxy_spend=global_proxy_spend, + route=route, + llm_router=llm_router, + proxy_logging_obj=proxy_logging_obj, + valid_token=None, + ) + + # return UserAPIKeyAuth object + return UserAPIKeyAuth( + api_key=None, + team_id=team_id, + team_tpm_limit=( + team_object.tpm_limit if team_object is not None else None + ), + team_rpm_limit=( + team_object.rpm_limit if team_object is not None else None + ), + team_models=team_object.models if team_object is not None else [], + user_role=LitellmUserRoles.INTERNAL_USER, + user_id=user_id, + org_id=org_id, + parent_otel_span=parent_otel_span, + end_user_id=end_user_id, + ) + + #### ELSE #### + ## CHECK PASS-THROUGH ENDPOINTS ## + is_mapped_pass_through_route: bool = False + for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: # type: ignore + if route.startswith(mapped_route): + is_mapped_pass_through_route = True + if is_mapped_pass_through_route: + if request.headers.get("litellm_user_api_key") is not None: + api_key = request.headers.get("litellm_user_api_key") or "" + if pass_through_endpoints is not None: + for endpoint in pass_through_endpoints: + if isinstance(endpoint, dict) and endpoint.get("path", "") == route: + ## IF AUTH DISABLED + if endpoint.get("auth") is not True: + return UserAPIKeyAuth() + ## IF AUTH ENABLED + ### IF CUSTOM PARSER REQUIRED + if ( + endpoint.get("custom_auth_parser") is not None + and endpoint.get("custom_auth_parser") == "langfuse" + ): + """ + - langfuse returns {'Authorization': 'Basic YW55dGhpbmc6YW55dGhpbmc'} + - check the langfuse public key if it contains the litellm api key + """ + import base64 + + api_key = api_key.replace("Basic ", "").strip() + decoded_bytes = base64.b64decode(api_key) + decoded_str = decoded_bytes.decode("utf-8") + api_key = decoded_str.split(":")[0] + else: + headers = endpoint.get("headers", None) + if headers is not None: + header_key = headers.get("litellm_user_api_key", "") + if ( + isinstance(request.headers, dict) + and request.headers.get(key=header_key) is not None # type: ignore + ): + api_key = request.headers.get(key=header_key) # type: ignore + if master_key is None: + if isinstance(api_key, str): + return UserAPIKeyAuth( + api_key=api_key, + user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, + ) + else: + return UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, + ) + elif api_key is None: # only require api key if master key is set + raise Exception("No api key passed in.") + elif api_key == "": + # missing 'Bearer ' prefix + raise Exception( + f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}" + ) + + if route == "/user/auth": + if general_settings.get("allow_user_auth", False) is True: + return UserAPIKeyAuth() + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="'allow_user_auth' not set or set to False", + ) + + ## Check END-USER OBJECT + _end_user_object = None + end_user_params = {} + + end_user_id = get_end_user_id_from_request_body(request_data) + if end_user_id: + try: + end_user_params["end_user_id"] = end_user_id + + # get end-user object + _end_user_object = await get_end_user_object( + end_user_id=end_user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + if _end_user_object is not None: + end_user_params["allowed_model_region"] = ( + _end_user_object.allowed_model_region + ) + if _end_user_object.litellm_budget_table is not None: + budget_info = _end_user_object.litellm_budget_table + if budget_info.tpm_limit is not None: + end_user_params["end_user_tpm_limit"] = ( + budget_info.tpm_limit + ) + if budget_info.rpm_limit is not None: + end_user_params["end_user_rpm_limit"] = ( + budget_info.rpm_limit + ) + if budget_info.max_budget is not None: + end_user_params["end_user_max_budget"] = ( + budget_info.max_budget + ) + except Exception as e: + if isinstance(e, litellm.BudgetExceededError): + raise e + verbose_proxy_logger.debug( + "Unable to find user in db. Error - {}".format(str(e)) + ) + pass + + ### CHECK IF ADMIN ### + # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead + ### CHECK IF ADMIN ### + # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead + ## Check CACHE + try: + valid_token = await get_key_object( + hashed_token=hash_token(api_key), + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + check_cache_only=True, + ) + except Exception: + verbose_logger.debug("api key not found in cache.") + valid_token = None + + if ( + valid_token is not None + and isinstance(valid_token, UserAPIKeyAuth) + and valid_token.user_role == LitellmUserRoles.PROXY_ADMIN + ): + # update end-user params on valid token + valid_token = update_valid_token_with_end_user_params( + valid_token=valid_token, end_user_params=end_user_params + ) + valid_token.parent_otel_span = parent_otel_span + + return valid_token + + if ( + valid_token is not None + and isinstance(valid_token, UserAPIKeyAuth) + and valid_token.team_id is not None + ): + ## UPDATE TEAM VALUES BASED ON CACHED TEAM OBJECT - allows `/team/update` values to work for cached token + try: + team_obj: LiteLLM_TeamTableCachedObj = await get_team_object( + team_id=valid_token.team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + check_cache_only=True, + ) + + if ( + team_obj.last_refreshed_at is not None + and valid_token.last_refreshed_at is not None + and team_obj.last_refreshed_at > valid_token.last_refreshed_at + ): + team_obj_dict = team_obj.__dict__ + + for k, v in team_obj_dict.items(): + field_name = f"team_{k}" + if field_name in valid_token.__fields__: + setattr(valid_token, field_name, v) + except Exception as e: + verbose_logger.debug( + e + ) # moving from .warning to .debug as it spams logs when team missing from cache. + + try: + is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore + except Exception: + is_master_key_valid = False + + ## VALIDATE MASTER KEY ## + try: + assert isinstance(master_key, str) + except Exception: + raise HTTPException( + status_code=500, + detail={ + "Master key must be a valid string. Current type={}".format( + type(master_key) + ) + }, + ) + + if is_master_key_valid: + _user_api_key_obj = await _return_user_api_key_auth_obj( + user_obj=None, + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key=master_key, + parent_otel_span=parent_otel_span, + valid_token_dict={ + **end_user_params, + "user_id": litellm_proxy_admin_name, + }, + route=route, + start_time=start_time, + ) + asyncio.create_task( + _cache_key_object( + hashed_token=hash_token(master_key), + user_api_key_obj=_user_api_key_obj, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + ) + + _user_api_key_obj = update_valid_token_with_end_user_params( + valid_token=_user_api_key_obj, end_user_params=end_user_params + ) + + return _user_api_key_obj + + ## IF it's not a master key + ## Route should not be in master_key_only_routes + if route in LiteLLMRoutes.master_key_only_routes.value: # type: ignore + raise Exception( + f"Tried to access route={route}, which is only for MASTER KEY" + ) + + ## Check DB + if isinstance( + api_key, str + ): # if generated token, make sure it starts with sk-. + assert api_key.startswith( + "sk-" + ), "LiteLLM Virtual Key expected. Received={}, expected to start with 'sk-'.".format( + api_key + ) # prevent token hashes from being used + else: + verbose_logger.warning( + "litellm.proxy.proxy_server.user_api_key_auth(): Warning - Key={} is not a string.".format( + api_key + ) + ) + + if ( + prisma_client is None + ): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error + return await _handle_failed_db_connection_for_get_key_object( + e=Exception("No connected db.") + ) + + ## check for cache hit (In-Memory Cache) + _user_role = None + if api_key.startswith("sk-"): + api_key = hash_token(token=api_key) + + if valid_token is None: + try: + valid_token = await get_key_object( + hashed_token=api_key, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + # update end-user params on valid token + # These can change per request - it's important to update them here + valid_token.end_user_id = end_user_params.get("end_user_id") + valid_token.end_user_tpm_limit = end_user_params.get( + "end_user_tpm_limit" + ) + valid_token.end_user_rpm_limit = end_user_params.get( + "end_user_rpm_limit" + ) + valid_token.allowed_model_region = end_user_params.get( + "allowed_model_region" + ) + # update key budget with temp budget increase + valid_token = _update_key_budget_with_temp_budget_increase( + valid_token + ) # updating it here, allows all downstream reporting / checks to use the updated budget + except Exception: + verbose_logger.info( + "litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to find token={} in cache or `LiteLLM_VerificationTokenTable`. Defaulting 'valid_token' to None'".format( + api_key + ) + ) + valid_token = None + + if valid_token is None: + raise Exception( + "Invalid proxy server token passed. Received API Key (hashed)={}. Unable to find token in cache or `LiteLLM_VerificationTokenTable`".format( + api_key + ) + ) + + user_obj: Optional[LiteLLM_UserTable] = None + valid_token_dict: dict = {} + if valid_token is not None: + # Got Valid Token from Cache, DB + # Run checks for + # 1. If token can call model + ## 1a. If token can call fallback models (if client-side fallbacks given) + # 2. If user_id for this token is in budget + # 3. If the user spend within their own team is within budget + # 4. If 'user' passed to /chat/completions, /embeddings endpoint is in budget + # 5. If token is expired + # 6. If token spend is under Budget for the token + # 7. If token spend per model is under budget per model + # 8. If token spend is under team budget + # 9. If team spend is under team budget + + ## base case ## key is disabled + if valid_token.blocked is True: + raise Exception( + "Key is blocked. Update via `/key/unblock` if you're admin." + ) + config = valid_token.config + + if config != {}: + model_list = config.get("model_list", []) + new_model_list = model_list + verbose_proxy_logger.debug( + f"\n new llm router model list {new_model_list}" + ) + elif ( + isinstance(valid_token.models, list) + and "all-team-models" in valid_token.models + ): + # Do not do any validation at this step + # the validation will occur when checking the team has access to this model + pass + else: + model = get_model_from_request(request_data, route) + fallback_models = cast( + Optional[List[ALL_FALLBACK_MODEL_VALUES]], + request_data.get("fallbacks", None), + ) + + if model is not None: + await can_key_call_model( + model=model, + llm_model_list=llm_model_list, + valid_token=valid_token, + llm_router=llm_router, + ) + + if fallback_models is not None: + for m in fallback_models: + await can_key_call_model( + model=m["model"] if isinstance(m, dict) else m, + llm_model_list=llm_model_list, + valid_token=valid_token, + llm_router=llm_router, + ) + await is_valid_fallback_model( + model=m["model"] if isinstance(m, dict) else m, + llm_router=llm_router, + user_model=None, + ) + + # Check 2. If user_id for this token is in budget - done in common_checks() + if valid_token.user_id is not None: + try: + user_obj = await get_user_object( + user_id=valid_token.user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + user_id_upsert=False, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + except Exception as e: + verbose_logger.debug( + "litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to get user from db/cache. Setting user_obj to None. Exception received - {}".format( + str(e) + ) + ) + user_obj = None + + # Check 3. Check if user is in their team budget + if valid_token.team_member_spend is not None: + if prisma_client is not None: + + _cache_key = f"{valid_token.team_id}_{valid_token.user_id}" + + team_member_info = await user_api_key_cache.async_get_cache( + key=_cache_key + ) + if team_member_info is None: + # read from DB + _user_id = valid_token.user_id + _team_id = valid_token.team_id + + if _user_id is not None and _team_id is not None: + team_member_info = await prisma_client.db.litellm_teammembership.find_first( + where={ + "user_id": _user_id, + "team_id": _team_id, + }, # type: ignore + include={"litellm_budget_table": True}, + ) + await user_api_key_cache.async_set_cache( + key=_cache_key, + value=team_member_info, + ) + + if ( + team_member_info is not None + and team_member_info.litellm_budget_table is not None + ): + team_member_budget = ( + team_member_info.litellm_budget_table.max_budget + ) + if team_member_budget is not None and team_member_budget > 0: + if valid_token.team_member_spend > team_member_budget: + raise litellm.BudgetExceededError( + current_cost=valid_token.team_member_spend, + max_budget=team_member_budget, + ) + + # Check 3. If token is expired + if valid_token.expires is not None: + current_time = datetime.now(timezone.utc) + if isinstance(valid_token.expires, datetime): + expiry_time = valid_token.expires + else: + expiry_time = datetime.fromisoformat(valid_token.expires) + if ( + expiry_time.tzinfo is None + or expiry_time.tzinfo.utcoffset(expiry_time) is None + ): + expiry_time = expiry_time.replace(tzinfo=timezone.utc) + verbose_proxy_logger.debug( + f"Checking if token expired, expiry time {expiry_time} and current time {current_time}" + ) + if expiry_time < current_time: + # Token exists but is expired. + raise ProxyException( + message=f"Authentication Error - Expired Key. Key Expiry time {expiry_time} and current time {current_time}", + type=ProxyErrorTypes.expired_key, + code=400, + param=api_key, + ) + + # Check 4. Token Spend is under budget + await _virtual_key_max_budget_check( + valid_token=valid_token, + proxy_logging_obj=proxy_logging_obj, + user_obj=user_obj, + ) + + # Check 5. Soft Budget Check + await _virtual_key_soft_budget_check( + valid_token=valid_token, + proxy_logging_obj=proxy_logging_obj, + ) + + # Check 5. Token Model Spend is under Model budget + max_budget_per_model = valid_token.model_max_budget + current_model = request_data.get("model", None) + + if ( + max_budget_per_model is not None + and isinstance(max_budget_per_model, dict) + and len(max_budget_per_model) > 0 + and prisma_client is not None + and current_model is not None + and valid_token.token is not None + ): + ## GET THE SPEND FOR THIS MODEL + await model_max_budget_limiter.is_key_within_model_budget( + user_api_key_dict=valid_token, + model=current_model, + ) + + # Check 6: Additional Common Checks across jwt + key auth + if valid_token.team_id is not None: + _team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable( + team_id=valid_token.team_id, + max_budget=valid_token.team_max_budget, + spend=valid_token.team_spend, + tpm_limit=valid_token.team_tpm_limit, + rpm_limit=valid_token.team_rpm_limit, + blocked=valid_token.team_blocked, + models=valid_token.team_models, + metadata=valid_token.team_metadata, + ) + else: + _team_obj = None + + # Check 7: Check if key is a service account key + await service_account_checks( + valid_token=valid_token, + request_data=request_data, + ) + + user_api_key_cache.set_cache( + key=valid_token.team_id, value=_team_obj + ) # save team table in cache - used for tpm/rpm limiting - tpm_rpm_limiter.py + + global_proxy_spend = None + if ( + litellm.max_budget > 0 and prisma_client is not None + ): # user set proxy max budget + # check cache + global_proxy_spend = await user_api_key_cache.async_get_cache( + key="{}:spend".format(litellm_proxy_admin_name) + ) + if global_proxy_spend is None: + # get from db + sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";""" + + response = await prisma_client.db.query_raw(query=sql_query) + + global_proxy_spend = response[0]["total_spend"] + await user_api_key_cache.async_set_cache( + key="{}:spend".format(litellm_proxy_admin_name), + value=global_proxy_spend, + ) + + if global_proxy_spend is not None: + call_info = CallInfo( + token=valid_token.token, + spend=global_proxy_spend, + max_budget=litellm.max_budget, + user_id=litellm_proxy_admin_name, + team_id=valid_token.team_id, + ) + asyncio.create_task( + proxy_logging_obj.budget_alerts( + type="proxy_budget", + user_info=call_info, + ) + ) + _ = await common_checks( + request_body=request_data, + team_object=_team_obj, + user_object=user_obj, + end_user_object=_end_user_object, + general_settings=general_settings, + global_proxy_spend=global_proxy_spend, + route=route, + llm_router=llm_router, + proxy_logging_obj=proxy_logging_obj, + valid_token=valid_token, + ) + # Token passed all checks + if valid_token is None: + raise HTTPException(401, detail="Invalid API key") + if valid_token.token is None: + raise HTTPException(401, detail="Invalid API key, no token associated") + api_key = valid_token.token + + # Add hashed token to cache + asyncio.create_task( + _cache_key_object( + hashed_token=api_key, + user_api_key_obj=valid_token, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + ) + + valid_token_dict = valid_token.model_dump(exclude_none=True) + valid_token_dict.pop("token", None) + + if _end_user_object is not None: + valid_token_dict.update(end_user_params) + + # check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions + # sso/login, ui/login, /key functions and /user functions + # this will never be allowed to call /chat/completions + token_team = getattr(valid_token, "team_id", None) + token_type: Literal["ui", "api"] = ( + "ui" + if token_team is not None and token_team == "litellm-dashboard" + else "api" + ) + _is_route_allowed = _is_allowed_route( + route=route, + token_type=token_type, + user_obj=user_obj, + request=request, + request_data=request_data, + api_key=api_key, + valid_token=valid_token, + ) + if not _is_route_allowed: + raise HTTPException(401, detail="Invalid route for UI token") + + if valid_token is None: + # No token was found when looking up in the DB + raise Exception("Invalid proxy server token passed") + if valid_token_dict is not None: + return await _return_user_api_key_auth_obj( + user_obj=user_obj, + api_key=api_key, + parent_otel_span=parent_otel_span, + valid_token_dict=valid_token_dict, + route=route, + start_time=start_time, + ) + else: + raise Exception() + except Exception as e: + requester_ip = _get_request_ip_address( + request=request, + use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False), + ) + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format( + str(e), + requester_ip, + ), + extra={"requester_ip": requester_ip}, + ) + + # Log this exception to OTEL, Datadog etc + user_api_key_dict = UserAPIKeyAuth( + parent_otel_span=parent_otel_span, + api_key=api_key, + ) + asyncio.create_task( + proxy_logging_obj.post_call_failure_hook( + request_data=request_data, + original_exception=e, + user_api_key_dict=user_api_key_dict, + error_type=ProxyErrorTypes.auth_error, + route=route, + ) + ) + + if isinstance(e, litellm.BudgetExceededError): + raise ProxyException( + message=e.message, + type=ProxyErrorTypes.budget_exceeded, + param=None, + code=400, + ) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + code=status.HTTP_401_UNAUTHORIZED, + ) + + +@tracer.wrap() +async def user_api_key_auth( + request: Request, + api_key: str = fastapi.Security(api_key_header), + azure_api_key_header: str = fastapi.Security(azure_api_key_header), + anthropic_api_key_header: Optional[str] = fastapi.Security( + anthropic_api_key_header + ), + google_ai_studio_api_key_header: Optional[str] = fastapi.Security( + google_ai_studio_api_key_header + ), + azure_apim_header: Optional[str] = fastapi.Security(azure_apim_header), +) -> UserAPIKeyAuth: + """ + Parent function to authenticate user api key / jwt token. + """ + + request_data = await _read_request_body(request=request) + + user_api_key_auth_obj = await _user_api_key_auth_builder( + request=request, + api_key=api_key, + azure_api_key_header=azure_api_key_header, + anthropic_api_key_header=anthropic_api_key_header, + google_ai_studio_api_key_header=google_ai_studio_api_key_header, + azure_apim_header=azure_apim_header, + request_data=request_data, + ) + + end_user_id = get_end_user_id_from_request_body(request_data) + if end_user_id is not None: + user_api_key_auth_obj.end_user_id = end_user_id + + return user_api_key_auth_obj + + +async def _return_user_api_key_auth_obj( + user_obj: Optional[LiteLLM_UserTable], + api_key: str, + parent_otel_span: Optional[Span], + valid_token_dict: dict, + route: str, + start_time: datetime, + user_role: Optional[LitellmUserRoles] = None, +) -> UserAPIKeyAuth: + end_time = datetime.now() + + asyncio.create_task( + user_api_key_service_logger_obj.async_service_success_hook( + service=ServiceTypes.AUTH, + call_type=route, + start_time=start_time, + end_time=end_time, + duration=end_time.timestamp() - start_time.timestamp(), + parent_otel_span=parent_otel_span, + ) + ) + + retrieved_user_role = ( + user_role or _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER + ) + + user_api_key_kwargs = { + "api_key": api_key, + "parent_otel_span": parent_otel_span, + "user_role": retrieved_user_role, + **valid_token_dict, + } + if user_obj is not None: + user_api_key_kwargs.update( + user_tpm_limit=user_obj.tpm_limit, + user_rpm_limit=user_obj.rpm_limit, + user_email=user_obj.user_email, + ) + if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj): + user_api_key_kwargs.update( + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + return UserAPIKeyAuth(**user_api_key_kwargs) + else: + return UserAPIKeyAuth(**user_api_key_kwargs) + + +def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]): + if user_obj is None: + return False + + if ( + user_obj.user_role is not None + and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value + ): + return True + + if ( + user_obj.user_role is not None + and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value + ): + return True + + return False + + +def _get_user_role( + user_obj: Optional[LiteLLM_UserTable], +) -> Optional[LitellmUserRoles]: + if user_obj is None: + return None + + _user = user_obj + + _user_role = _user.user_role + try: + role = LitellmUserRoles(_user_role) + except ValueError: + return LitellmUserRoles.INTERNAL_USER + + return role + + +def get_api_key_from_custom_header( + request: Request, custom_litellm_key_header_name: str +) -> str: + """ + Get API key from custom header + + Args: + request (Request): Request object + custom_litellm_key_header_name (str): Custom header name + + Returns: + Optional[str]: API key + """ + api_key: str = "" + # use this as the virtual key passed to litellm proxy + custom_litellm_key_header_name = custom_litellm_key_header_name.lower() + _headers = {k.lower(): v for k, v in request.headers.items()} + verbose_proxy_logger.debug( + "searching for custom_litellm_key_header_name= %s, in headers=%s", + custom_litellm_key_header_name, + _headers, + ) + custom_api_key = _headers.get(custom_litellm_key_header_name) + if custom_api_key: + api_key = _get_bearer_token(api_key=custom_api_key) + verbose_proxy_logger.debug( + "Found custom API key using header: {}, setting api_key={}".format( + custom_litellm_key_header_name, api_key + ) + ) + else: + verbose_proxy_logger.exception( + f"No LiteLLM Virtual Key pass. Please set header={custom_litellm_key_header_name}: Bearer <api_key>" + ) + return api_key + + +def _get_temp_budget_increase(valid_token: UserAPIKeyAuth): + valid_token_metadata = valid_token.metadata + if ( + "temp_budget_increase" in valid_token_metadata + and "temp_budget_expiry" in valid_token_metadata + ): + expiry = datetime.fromisoformat(valid_token_metadata["temp_budget_expiry"]) + if expiry > datetime.now(): + return valid_token_metadata["temp_budget_increase"] + return None + + +def _update_key_budget_with_temp_budget_increase( + valid_token: UserAPIKeyAuth, +) -> UserAPIKeyAuth: + if valid_token.max_budget is None: + return valid_token + temp_budget_increase = _get_temp_budget_increase(valid_token) or 0.0 + valid_token.max_budget = valid_token.max_budget + temp_budget_increase + return valid_token |