diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/auth')
13 files changed, 5383 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py new file mode 100644 index 00000000..f029511d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py @@ -0,0 +1,1373 @@ +# What is this? +## Common auth checks between jwt + key based auth +""" +Got Valid Token from Cache, DB +Run checks for: + +1. If user can call model +2. If user is in budget +3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget +""" +import asyncio +import re +import time +import traceback +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast + +from fastapi import status +from pydantic import BaseModel + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.caching.dual_cache import LimitedSizeOrderedDict +from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider +from litellm.proxy._types import ( + DB_CONNECTION_ERROR_TYPES, + RBAC_ROLES, + CallInfo, + LiteLLM_EndUserTable, + LiteLLM_JWTAuth, + LiteLLM_OrganizationMembershipTable, + LiteLLM_OrganizationTable, + LiteLLM_TeamTable, + LiteLLM_TeamTableCachedObj, + LiteLLM_UserTable, + LiteLLMRoutes, + LitellmUserRoles, + ProxyErrorTypes, + ProxyException, + RoleBasedPermissions, + SpecialModelNames, + UserAPIKeyAuth, +) +from litellm.proxy.auth.route_checks import RouteChecks +from litellm.proxy.route_llm_request import route_request +from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics +from litellm.router import Router +from litellm.types.services import ServiceTypes + +from .auth_checks_organization import organization_role_based_access_check + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +last_db_access_time = LimitedSizeOrderedDict(max_size=100) +db_cache_expiry = 5 # 
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value


async def common_checks(
    request_body: dict,
    team_object: Optional[LiteLLM_TeamTable],
    user_object: Optional[LiteLLM_UserTable],
    end_user_object: Optional[LiteLLM_EndUserTable],
    global_proxy_spend: Optional[float],
    general_settings: dict,
    route: str,
    llm_router: Optional[Router],
    proxy_logging_obj: ProxyLogging,
    valid_token: Optional[UserAPIKeyAuth],
) -> bool:
    """
    Common checks across jwt + key-based auth.

    1. If team is blocked
    2. If team can call model
    3. If team is in budget
    4. If user passed in (JWT or key.user_id) - is in budget
    5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
    6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
    7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
    8. [OPTIONAL] If guardrails modified - is request allowed to change this
    9. Check if request body is safe
    10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks

    Returns True when every applicable check passes; raises on the first failure
    (Exception / ProxyException / litellm.BudgetExceededError / HTTPException).
    """
    _model = request_body.get("model", None)

    # 1. If team is blocked
    if team_object is not None and team_object.blocked is True:
        raise Exception(
            f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
        )

    # 2. If team can call model
    if _model and team_object:
        if not await can_team_access_model(
            model=_model,
            team_object=team_object,
            llm_router=llm_router,
            team_model_aliases=valid_token.team_model_aliases if valid_token else None,
        ):
            raise ProxyException(
                message=f"Team not allowed to access model. Team={team_object.team_id}, Model={_model}. Allowed team models = {team_object.models}",
                type=ProxyErrorTypes.team_model_access_denied,
                param="model",
                code=status.HTTP_401_UNAUTHORIZED,
            )

    ## 2.1 If user can call model (if personal key)
    # NOTE(review): _model may be None here; can_user_call_model appears to tolerate that — confirm.
    if team_object is None and user_object is not None:
        await can_user_call_model(
            model=_model,
            llm_router=llm_router,
            user_object=user_object,
        )

    # 3. If team is in budget (raises BudgetExceededError when over)
    await _team_max_budget_check(
        team_object=team_object,
        proxy_logging_obj=proxy_logging_obj,
        valid_token=valid_token,
    )

    # 4. If user is in budget
    ## 4.1 check personal budget, if personal key (no team attached)
    if (
        (team_object is None or team_object.team_id is None)
        and user_object is not None
        and user_object.max_budget is not None
    ):
        user_budget = user_object.max_budget
        if user_budget < user_object.spend:
            raise litellm.BudgetExceededError(
                current_cost=user_object.spend,
                max_budget=user_budget,
                message=f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}",
            )

    ## 4.2 check team member budget, if team key
    # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
    if end_user_object is not None and end_user_object.litellm_budget_table is not None:
        end_user_budget = end_user_object.litellm_budget_table.max_budget
        if end_user_budget is not None and end_user_object.spend > end_user_budget:
            raise litellm.BudgetExceededError(
                current_cost=end_user_object.spend,
                max_budget=end_user_budget,
                message=f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}",
            )

    # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
    if (
        general_settings.get("enforce_user_param", None) is not None
        and general_settings["enforce_user_param"] is True
    ):
        if RouteChecks.is_llm_api_route(route=route) and "user" not in request_body:
            raise Exception(
                f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
            )
    # 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
    if (
        litellm.max_budget > 0
        and global_proxy_spend is not None
        # only run global budget checks for OpenAI routes
        # Reason - the Admin UI should continue working if the proxy crosses it's global budget
        and RouteChecks.is_llm_api_route(route=route)
        and route != "/v1/models"
        and route != "/models"
    ):
        if global_proxy_spend > litellm.max_budget:
            raise litellm.BudgetExceededError(
                current_cost=global_proxy_spend, max_budget=litellm.max_budget
            )

    # 8. [OPTIONAL] If guardrails modified - is request allowed to change this
    _request_metadata: dict = request_body.get("metadata", {}) or {}
    if _request_metadata.get("guardrails"):
        # check if team allowed to modify guardrails
        from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails

        can_modify: bool = can_modify_guardrails(team_object)
        if can_modify is False:
            from fastapi import HTTPException

            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Your team does not have permission to modify guardrails."
                },
            )

    # 10 [OPTIONAL] Organization RBAC checks
    organization_role_based_access_check(
        user_object=user_object, route=route, request_body=request_body
    )

    return True
+ """ + + for allowed_route in allowed_routes: + if ( + allowed_route in LiteLLMRoutes.__members__ + and user_route in LiteLLMRoutes[allowed_route].value + ): + return True + elif allowed_route == user_route: + return True + return False + + +def allowed_routes_check( + user_role: Literal[ + LitellmUserRoles.PROXY_ADMIN, + LitellmUserRoles.TEAM, + LitellmUserRoles.INTERNAL_USER, + ], + user_route: str, + litellm_proxy_roles: LiteLLM_JWTAuth, +) -> bool: + """ + Check if user -> not admin - allowed to access these routes + """ + + if user_role == LitellmUserRoles.PROXY_ADMIN: + is_allowed = _allowed_routes_check( + user_route=user_route, + allowed_routes=litellm_proxy_roles.admin_allowed_routes, + ) + return is_allowed + + elif user_role == LitellmUserRoles.TEAM: + if litellm_proxy_roles.team_allowed_routes is None: + """ + By default allow a team to call openai + info routes + """ + is_allowed = _allowed_routes_check( + user_route=user_route, allowed_routes=["openai_routes", "info_routes"] + ) + return is_allowed + elif litellm_proxy_roles.team_allowed_routes is not None: + is_allowed = _allowed_routes_check( + user_route=user_route, + allowed_routes=litellm_proxy_roles.team_allowed_routes, + ) + return is_allowed + return False + + +def allowed_route_check_inside_route( + user_api_key_dict: UserAPIKeyAuth, + requested_user_id: Optional[str], +) -> bool: + ret_val = True + if ( + user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN + and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY + ): + ret_val = False + if requested_user_id is not None and user_api_key_dict.user_id is not None: + if user_api_key_dict.user_id == requested_user_id: + ret_val = True + return ret_val + + +def get_actual_routes(allowed_routes: list) -> list: + actual_routes: list = [] + for route_name in allowed_routes: + try: + route_value = LiteLLMRoutes[route_name].value + if isinstance(route_value, set): + actual_routes.extend(list(route_value)) + else: + 
actual_routes.extend(route_value) + + except KeyError: + actual_routes.append(route_name) + return actual_routes + + +@log_db_metrics +async def get_end_user_object( + end_user_id: Optional[str], + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, +) -> Optional[LiteLLM_EndUserTable]: + """ + Returns end user object, if in db. + + Do a isolated check for end user in table vs. doing a combined key + team + user + end-user check, as key might come in frequently for different end-users. Larger call will slowdown query time. This way we get to cache the constant (key/team/user info) and only update based on the changing value (end-user). + """ + if prisma_client is None: + raise Exception("No db connected") + + if end_user_id is None: + return None + _key = "end_user_id:{}".format(end_user_id) + + def check_in_budget(end_user_obj: LiteLLM_EndUserTable): + if end_user_obj.litellm_budget_table is None: + return + end_user_budget = end_user_obj.litellm_budget_table.max_budget + if end_user_budget is not None and end_user_obj.spend > end_user_budget: + raise litellm.BudgetExceededError( + current_cost=end_user_obj.spend, max_budget=end_user_budget + ) + + # check if in cache + cached_user_obj = await user_api_key_cache.async_get_cache(key=_key) + if cached_user_obj is not None: + if isinstance(cached_user_obj, dict): + return_obj = LiteLLM_EndUserTable(**cached_user_obj) + check_in_budget(end_user_obj=return_obj) + return return_obj + elif isinstance(cached_user_obj, LiteLLM_EndUserTable): + return_obj = cached_user_obj + check_in_budget(end_user_obj=return_obj) + return return_obj + # else, check db + try: + response = await prisma_client.db.litellm_endusertable.find_unique( + where={"user_id": end_user_id}, + include={"litellm_budget_table": True}, + ) + + if response is None: + raise Exception + + # save the end-user object to cache + await 
def model_in_access_group(
    model: str, team_models: Optional[List[str]], llm_router: Optional["Router"]
) -> bool:
    """
    Return True if `model` is allowed by a team's model list.

    `team_models` entries may be plain model names or router access-group
    names; access groups are expanded via the router when available.
    A `team_models` of None means "no restriction".
    """
    from collections import defaultdict

    if team_models is None:
        return True
    if model in team_models:
        return True

    access_groups: dict = defaultdict(list)
    if llm_router:
        access_groups = llm_router.get_model_access_groups(model_name=model)

    if len(access_groups) > 0:  # check if token contains any model access groups
        # if any entry in team_models is an access group containing the model, allow
        for m in team_models:
            if m in access_groups:
                return True

    # Filter out models that are access_groups
    filtered_models = [m for m in team_models if m not in access_groups]
    if model in filtered_models:
        return True

    return False


def _should_check_db(
    key: str, last_db_access_time: "LimitedSizeOrderedDict", db_cache_expiry: int
) -> bool:
    """
    Decide whether a DB lookup should be made for `key`.

    `last_db_access_time` stores `(value, timestamp)` tuples. Returns True
    when: the key has never been checked, the last check found a non-null
    value (refresh operation), or a negative (None) result has expired.
    Prevents hammering the DB for items that don't exist.
    """
    if key not in last_db_access_time:
        return True

    cached_value, checked_at = last_db_access_time[key]
    if cached_value is not None:
        # check db for non-null values (for refresh operations)
        return True

    # BUG FIX: previously compared `current_time - last_db_access_time[key]`,
    # subtracting the whole (value, timestamp) tuple from a float -> TypeError.
    # Compare against the stored timestamp instead.
    if time.time() - checked_at >= db_cache_expiry:
        return True
    return False


def _update_last_db_access_time(
    key: str, value: Optional[Any], last_db_access_time: "LimitedSizeOrderedDict"
):
    """Record the latest DB result for `key` as a `(value, timestamp)` tuple."""
    last_db_access_time[key] = (value, time.time())


def _get_role_based_permissions(
    rbac_role: "RBAC_ROLES",
    general_settings: dict,
    key: Literal["models", "routes"],
) -> Optional[List[str]]:
    """
    Return the `key` permission list configured for `rbac_role` in
    `general_settings["role_permissions"]`, or None when not configured.
    """
    role_based_permissions = cast(
        Optional[List["RoleBasedPermissions"]],
        general_settings.get("role_permissions", []),
    )
    if role_based_permissions is None:
        return None

    for role_based_permission in role_based_permissions:
        if role_based_permission.role == rbac_role:
            return getattr(role_based_permission, key)

    return None


def get_role_based_models(
    rbac_role: "RBAC_ROLES",
    general_settings: dict,
) -> Optional[List[str]]:
    """
    Get the models allowed for a user role.

    Used by JWT Auth.
    """
    return _get_role_based_permissions(
        rbac_role=rbac_role,
        general_settings=general_settings,
        key="models",
    )


def get_role_based_routes(
    rbac_role: "RBAC_ROLES",
    general_settings: dict,
) -> Optional[List[str]]:
    """
    Get the routes allowed for a user role.
    """
    return _get_role_based_permissions(
        rbac_role=rbac_role,
        general_settings=general_settings,
        key="routes",
    )
async def _get_fuzzy_user_object(
    prisma_client: PrismaClient,
    sso_user_id: Optional[str] = None,
    user_email: Optional[str] = None,
) -> Optional[LiteLLM_UserTable]:
    """
    Fallback SSO-user lookup, used when no exact user_id match exists.

    Tries `sso_user_id` first, then `user_email`. When a row is found and an
    `sso_user_id` was provided, the row's sso_user_id is backfilled in a
    background task. Returns None when no matching user exists.
    """
    user_row = None

    if sso_user_id is not None:
        user_row = await prisma_client.db.litellm_usertable.find_unique(
            where={"sso_user_id": sso_user_id},
            include={"organization_memberships": True},
        )

    if user_row is None and user_email is not None:
        user_row = await prisma_client.db.litellm_usertable.find_first(
            where={"user_email": user_email},
            include={"organization_memberships": True},
        )

    if user_row is not None and sso_user_id is not None:
        # background task to update user with sso id
        asyncio.create_task(
            prisma_client.db.litellm_usertable.update(
                where={"user_id": user_row.user_id},
                data={"sso_user_id": sso_user_id},
            )
        )

    return user_row


@log_db_metrics
async def get_user_object(
    user_id: Optional[str],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    user_id_upsert: bool,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    sso_user_id: Optional[str] = None,
    user_email: Optional[str] = None,
) -> Optional[LiteLLM_UserTable]:
    """
    Resolve `user_id` to a LiteLLM_UserTable via cache, then DB.

    - Cache hit: returned directly.
    - Cache miss: DB lookup (rate-limited via _should_check_db), with an SSO
      fuzzy-match fallback, optionally creating the user when
      `user_id_upsert` is True.
    Raises ValueError when the user cannot be found or created.
    """
    if user_id is None:
        return None

    # cache-first lookup
    cached = await user_api_key_cache.async_get_cache(key=user_id)
    if isinstance(cached, dict):
        return LiteLLM_UserTable(**cached)
    if isinstance(cached, LiteLLM_UserTable):
        return cached

    if prisma_client is None:
        raise Exception("No db connected")

    try:
        db_access_time_key = "user_id:{}".format(user_id)

        db_user = None
        if _should_check_db(
            key=db_access_time_key,
            last_db_access_time=last_db_access_time,
            db_cache_expiry=db_cache_expiry,
        ):
            db_user = await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": user_id}, include={"organization_memberships": True}
            )
            if db_user is None:
                db_user = await _get_fuzzy_user_object(
                    prisma_client=prisma_client,
                    sso_user_id=sso_user_id,
                    user_email=user_email,
                )

        if db_user is None:
            if not user_id_upsert:
                raise Exception
            db_user = await prisma_client.db.litellm_usertable.create(
                data={"user_id": user_id},
                include={"organization_memberships": True},
            )

        if db_user.organization_memberships:
            # dump each organization membership to type LiteLLM_OrganizationMembershipTable
            db_user.organization_memberships = [
                LiteLLM_OrganizationMembershipTable(**membership.model_dump())
                for membership in db_user.organization_memberships
                if membership is not None
            ]

        user_obj = LiteLLM_UserTable(**dict(db_user))
        user_obj_dict = user_obj.model_dump()

        # save the user object to cache
        await user_api_key_cache.async_set_cache(key=user_id, value=user_obj_dict)

        # save to db access time
        _update_last_db_access_time(
            key=db_access_time_key,
            value=user_obj_dict,
            last_db_access_time=last_db_access_time,
        )

        return user_obj
    except Exception as e:  # if user not in db
        raise ValueError(
            f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call. Got error - {e}"
        )
Got error - {e}" + ) + + +async def _cache_management_object( + key: str, + value: BaseModel, + user_api_key_cache: DualCache, + proxy_logging_obj: Optional[ProxyLogging], +): + await user_api_key_cache.async_set_cache(key=key, value=value) + + +async def _cache_team_object( + team_id: str, + team_table: LiteLLM_TeamTableCachedObj, + user_api_key_cache: DualCache, + proxy_logging_obj: Optional[ProxyLogging], +): + key = "team_id:{}".format(team_id) + + ## CACHE REFRESH TIME! + team_table.last_refreshed_at = time.time() + + await _cache_management_object( + key=key, + value=team_table, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + +async def _cache_key_object( + hashed_token: str, + user_api_key_obj: UserAPIKeyAuth, + user_api_key_cache: DualCache, + proxy_logging_obj: Optional[ProxyLogging], +): + key = hashed_token + + ## CACHE REFRESH TIME + user_api_key_obj.last_refreshed_at = time.time() + + await _cache_management_object( + key=key, + value=user_api_key_obj, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + +async def _delete_cache_key_object( + hashed_token: str, + user_api_key_cache: DualCache, + proxy_logging_obj: Optional[ProxyLogging], +): + key = hashed_token + + user_api_key_cache.delete_cache(key=key) + + ## UPDATE REDIS CACHE ## + if proxy_logging_obj is not None: + await proxy_logging_obj.internal_usage_cache.dual_cache.async_delete_cache( + key=key + ) + + +@log_db_metrics +async def _get_team_db_check( + team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None +): + response = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} + ) + + if response is None and team_id_upsert: + response = await prisma_client.db.litellm_teamtable.create( + data={"team_id": team_id} + ) + + return response + + +async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient): + return await 
prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} + ) + + +async def _get_team_object_from_user_api_key_cache( + team_id: str, + prisma_client: PrismaClient, + user_api_key_cache: DualCache, + last_db_access_time: LimitedSizeOrderedDict, + db_cache_expiry: int, + proxy_logging_obj: Optional[ProxyLogging], + key: str, + team_id_upsert: Optional[bool] = None, +) -> LiteLLM_TeamTableCachedObj: + db_access_time_key = key + should_check_db = _should_check_db( + key=db_access_time_key, + last_db_access_time=last_db_access_time, + db_cache_expiry=db_cache_expiry, + ) + if should_check_db: + response = await _get_team_db_check( + team_id=team_id, prisma_client=prisma_client, team_id_upsert=team_id_upsert + ) + else: + response = None + + if response is None: + raise Exception + + _response = LiteLLM_TeamTableCachedObj(**response.dict()) + # save the team object to cache + await _cache_team_object( + team_id=team_id, + team_table=_response, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + # save to db access time + # save to db access time + _update_last_db_access_time( + key=db_access_time_key, + value=_response, + last_db_access_time=last_db_access_time, + ) + + return _response + + +async def _get_team_object_from_cache( + key: str, + proxy_logging_obj: Optional[ProxyLogging], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span], +) -> Optional[LiteLLM_TeamTableCachedObj]: + cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None + + ## CHECK REDIS CACHE ## + if ( + proxy_logging_obj is not None + and proxy_logging_obj.internal_usage_cache.dual_cache + ): + + cached_team_obj = ( + await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache( + key=key, parent_otel_span=parent_otel_span + ) + ) + + if cached_team_obj is None: + cached_team_obj = await user_api_key_cache.async_get_cache(key=key) + + if cached_team_obj is not None: + if isinstance(cached_team_obj, dict): + 
async def get_team_object(
    team_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    check_cache_only: Optional[bool] = None,
    check_db_only: Optional[bool] = None,
    team_id_upsert: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
    """
    - Check if team id in proxy Team Table
    - if valid, return LiteLLM_TeamTable object with defined limits
    - if not, then raise an error

    Raises:
    - Exception: If team doesn't exist in db or cache
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    cache_key = "team_id:{}".format(team_id)

    # cache-first unless the caller forces a DB read
    if not check_db_only:
        cached_team = await _get_team_object_from_cache(
            key=cache_key,
            proxy_logging_obj=proxy_logging_obj,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
        )
        if cached_team is not None:
            return cached_team

        if check_cache_only:
            raise Exception(
                f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
            )

    # else, check db
    try:
        return await _get_team_object_from_user_api_key_cache(
            team_id=team_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
            last_db_access_time=last_db_access_time,
            db_cache_expiry=db_cache_expiry,
            key=cache_key,
            team_id_upsert=team_id_upsert,
        )
    except Exception:
        raise Exception(
            f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
        )


@log_db_metrics
async def get_key_object(
    hashed_token: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    check_cache_only: Optional[bool] = None,
) -> UserAPIKeyAuth:
    """
    Resolve a hashed virtual-key token to a UserAPIKeyAuth object.

    - cache hit: returned directly
    - cache miss: combined-view DB lookup, result cached
    Raises when the key doesn't exist; DB connection failures may be
    tolerated via `allow_requests_on_db_unavailable`.
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    cache_key = hashed_token

    cached_key = await user_api_key_cache.async_get_cache(key=cache_key)
    if isinstance(cached_key, dict):
        return UserAPIKeyAuth(**cached_key)
    if isinstance(cached_key, UserAPIKeyAuth):
        return cached_key

    if check_cache_only:
        raise Exception(
            f"Key doesn't exist in cache + check_cache_only=True. key={cache_key}."
        )

    # else, check db
    try:
        db_token: Optional[BaseModel] = await prisma_client.get_data(
            token=hashed_token,
            table_name="combined_view",
            parent_otel_span=parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )
        if db_token is None:
            raise Exception

        key_obj = UserAPIKeyAuth(**db_token.model_dump(exclude_none=True))

        # save the key object to cache
        await _cache_key_object(
            hashed_token=hashed_token,
            user_api_key_obj=key_obj,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )

        return key_obj
    except DB_CONNECTION_ERROR_TYPES as e:
        return await _handle_failed_db_connection_for_get_key_object(e=e)
    except Exception:
        traceback.print_exc()
        raise Exception(
            f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
        )
async def _handle_failed_db_connection_for_get_key_object(
    e: Exception,
) -> UserAPIKeyAuth:
    """
    Handles DB connection errors when reading a Virtual Key from LiteLLM DB.

    Use this if you don't want failed DB queries to block LLM API requests.

    Returns:
    - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True

    Raises:
    - Original Exception in all other cases
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        litellm_proxy_admin_name,
        proxy_logging_obj,
    )

    # If this flag is on, requests failing to connect to the DB will be allowed
    if general_settings.get("allow_requests_on_db_unavailable", False) is True:
        # log this as a DB failure on prometheus
        # NOTE(review): service_failure_hook may be a coroutine function upstream;
        # it is invoked without await here — confirm intended fire-and-forget semantics.
        proxy_logging_obj.service_logging_obj.service_failure_hook(
            service=ServiceTypes.DB,
            call_type="get_key_object",
            error=e,
            duration=0.0,
        )

        return UserAPIKeyAuth(
            key_name="failed-to-connect-to-db",
            token="failed-to-connect-to-db",
            user_id=litellm_proxy_admin_name,
        )
    else:
        # raise the original exception, the wrapper on `get_key_object` handles logging db failure to prometheus
        raise e


@log_db_metrics
async def get_org_object(
    org_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_OrganizationTable]:
    """
    - Check if org id in proxy Org Table
    - if valid, return LiteLLM_OrganizationTable object
    - if not, then raise an error
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    # BUG FIX: async_get_cache is a coroutine and was previously not awaited,
    # so the cache was never consulted (the truthy coroutine object failed both
    # isinstance checks) and an un-awaited coroutine leaked per call.
    cached_org_obj = await user_api_key_cache.async_get_cache(
        key="org_id:{}".format(org_id)
    )
    if cached_org_obj is not None:
        if isinstance(cached_org_obj, dict):
            return LiteLLM_OrganizationTable(**cached_org_obj)
        elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
            return cached_org_obj
    # else, check db
    try:
        response = await prisma_client.db.litellm_organizationtable.find_unique(
            where={"organization_id": org_id}
        )
        if response is None:
            raise Exception

        return response
    except Exception:
        raise Exception(
            f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
        )


async def _can_object_call_model(
    model: str,
    llm_router: Optional[Router],
    models: List[str],
    team_model_aliases: Optional[Dict[str, str]] = None,
    object_type: Literal["user", "team", "key"] = "user",
) -> Literal[True]:
    """
    Checks if token can call a given model.

    Args:
    - model: the requested model name (aliases resolved via litellm.model_alias_map)
    - llm_router: router used to expand model access groups
    - models: the models the object (user/team/key) is entitled to
    - team_model_aliases: team-level alias map; an aliased model is allowed
    - object_type: used to raise the correct exception type

    Returns:
    - True: if token allowed to call model

    Raises:
    - ProxyException (401): If token not allowed to call model
    """
    if model in litellm.model_alias_map:
        model = litellm.model_alias_map[model]

    ## check if model in allowed model names
    from collections import defaultdict

    access_groups: Dict[str, List[str]] = defaultdict(list)
    if llm_router:
        access_groups = llm_router.get_model_access_groups(model_name=model)

    if (
        len(access_groups) > 0 and llm_router is not None
    ):  # check if token contains any model access groups
        # if any entitled entry is an access group containing the model, allow
        for m in models:
            if m in access_groups:
                return True

    # Filter out models that are access_groups
    filtered_models = [m for m in models if m not in access_groups]

    verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")

    if _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases):
        return True

    if _model_matches_any_wildcard_pattern_in_list(
        model=model, allowed_model_list=filtered_models
    ):
        return True

    # empty entitlement list or explicit wildcard grants access to everything
    all_model_access: bool = False
    if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models:
        all_model_access = True
    if SpecialModelNames.all_proxy_models.value in filtered_models:
        all_model_access = True

    if model is not None and model not in filtered_models and all_model_access is False:
        raise ProxyException(
            message=f"{object_type} not allowed to access model. This {object_type} can only access models={models}. Tried to access {model}",
            type=ProxyErrorTypes.get_model_access_error_type_for_object(
                object_type=object_type
            ),
            param="model",
            code=status.HTTP_401_UNAUTHORIZED,
        )

    verbose_proxy_logger.debug(
        f"filtered allowed_models: {filtered_models}; models: {models}"
    )
    return True
def _model_in_team_aliases(
    model: str, team_model_aliases: Optional[Dict[str, str]] = None
) -> bool:
    """
    Return True when the requested `model` is an alias of a team model.

    Examples:
    - model="gpt-4o", team_model_aliases={"gpt-4o": "gpt-4o-team-1"} -> True
    - model="gp-4o",  team_model_aliases={"o-3": "o3-preview"}       -> False
    - team_model_aliases=None or {}                                   -> False
    """
    if not team_model_aliases:
        return False
    return model in team_model_aliases


async def can_key_call_model(
    model: str,
    llm_model_list: Optional[list],
    valid_token: "UserAPIKeyAuth",
    llm_router: "Optional[litellm.Router]",
) -> Literal[True]:
    """
    Checks if a virtual key's token can call a given model.

    Returns:
    - True: if token allowed to call model

    Raises:
    - Exception: If token not allowed to call model
    """
    return await _can_object_call_model(
        model=model,
        llm_router=llm_router,
        models=valid_token.models,
        team_model_aliases=valid_token.team_model_aliases,
        object_type="key",
    )
async def can_user_call_model(
    model: str,
    llm_router: Optional[Router],
    user_object: Optional[LiteLLM_UserTable],
) -> Literal[True]:
    """
    Check if the user is allowed to call the requested model.

    Args:
        model: the model name requested by the caller.
        llm_router: router used to resolve model access groups / wildcards.
        user_object: the user's DB record; None means no user-level checks apply.

    Returns:
        True if the user may call `model`.

    Raises:
        ProxyException: if the user is restricted to team-only models, or if
            `_can_object_call_model` rejects the model for this user.
    """

    # No user record -> no user-level restrictions to enforce here.
    if user_object is None:
        return True

    # Sentinel value: user has no personal/default model access and may only
    # use models granted through a team.
    if SpecialModelNames.no_default_models.value in user_object.models:
        raise ProxyException(
            message=f"User not allowed to access model. No default model access, only team models allowed. Tried to access {model}",
            type=ProxyErrorTypes.key_model_access_denied,
            param="model",
            code=status.HTTP_401_UNAUTHORIZED,
        )

    # Delegate the shared model-access logic (aliases, wildcards, access groups).
    return await _can_object_call_model(
        model=model,
        llm_router=llm_router,
        models=user_object.models,
        object_type="user",
    )
+ + """ + if valid_token.spend is not None and valid_token.max_budget is not None: + #################################### + # collect information for alerting # + #################################### + + user_email = None + # Check if the token has any user id information + if user_obj is not None: + user_email = user_obj.user_email + + call_info = CallInfo( + token=valid_token.token, + spend=valid_token.spend, + max_budget=valid_token.max_budget, + user_id=valid_token.user_id, + team_id=valid_token.team_id, + user_email=user_email, + key_alias=valid_token.key_alias, + ) + asyncio.create_task( + proxy_logging_obj.budget_alerts( + type="token_budget", + user_info=call_info, + ) + ) + + #################################### + # collect information for alerting # + #################################### + + if valid_token.spend >= valid_token.max_budget: + raise litellm.BudgetExceededError( + current_cost=valid_token.spend, + max_budget=valid_token.max_budget, + ) + + +async def _virtual_key_soft_budget_check( + valid_token: UserAPIKeyAuth, + proxy_logging_obj: ProxyLogging, +): + """ + Triggers a budget alert if the token is over it's soft budget. 
def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
    """
    Check if a model matches an allowed wildcard pattern.

    Args:
        model: the model to check (e.g. "bedrock/anthropic.claude-3-5-sonnet-20240620")
        allowed_model_pattern: the allowed pattern (e.g. "bedrock/*", "*", "openai/*")

    Returns:
        True if `allowed_model_pattern` contains a wildcard and the full model
        string matches it; False otherwise (exact, non-wildcard matches are
        handled by the caller, not here).
    """
    # Only wildcard patterns are evaluated by this helper.
    if "*" not in allowed_model_pattern:
        return False

    # Translate the glob-style pattern into an anchored regex.
    anchored_regex = "^{}$".format(allowed_model_pattern.replace("*", ".*"))
    return re.match(anchored_regex, model) is not None
True if the pattern is a wildcard pattern. + + Checks if `*` is in the pattern. + """ + return "*" in allowed_model_pattern diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks_organization.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks_organization.py new file mode 100644 index 00000000..3da3d8dd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks_organization.py @@ -0,0 +1,161 @@ +""" +Auth Checks for Organizations +""" + +from typing import Dict, List, Optional, Tuple + +from fastapi import status + +from litellm.proxy._types import * + + +def organization_role_based_access_check( + request_body: dict, + user_object: Optional[LiteLLM_UserTable], + route: str, +): + """ + Role based access control checks only run if a user is part of an Organization + + Organization Checks: + ONLY RUN IF user_object.organization_memberships is not None + + 1. Only Proxy Admins can access /organization/new + 2. IF route is a LiteLLMRoutes.org_admin_only_routes, then check if user is an Org Admin for that organization + + """ + + if user_object is None: + return + + passed_organization_id: Optional[str] = request_body.get("organization_id", None) + + if route == "/organization/new": + if user_object.user_role != LitellmUserRoles.PROXY_ADMIN.value: + raise ProxyException( + message=f"Only proxy admins can create new organizations. 
You are {user_object.user_role}", + type=ProxyErrorTypes.auth_error.value, + param="user_role", + code=status.HTTP_401_UNAUTHORIZED, + ) + + if user_object.user_role == LitellmUserRoles.PROXY_ADMIN.value: + return + + # Checks if route is an Org Admin Only Route + if route in LiteLLMRoutes.org_admin_only_routes.value: + _user_organizations, _user_organization_role_mapping = ( + get_user_organization_info(user_object) + ) + + if user_object.organization_memberships is None: + raise ProxyException( + message=f"Tried to access route={route} but you are not a member of any organization. Please contact the proxy admin to request access.", + type=ProxyErrorTypes.auth_error.value, + param="organization_id", + code=status.HTTP_401_UNAUTHORIZED, + ) + + if passed_organization_id is None: + raise ProxyException( + message="Passed organization_id is None, please pass an organization_id in your request", + type=ProxyErrorTypes.auth_error.value, + param="organization_id", + code=status.HTTP_401_UNAUTHORIZED, + ) + + user_role: Optional[LitellmUserRoles] = _user_organization_role_mapping.get( + passed_organization_id + ) + if user_role is None: + raise ProxyException( + message=f"You do not have a role within the selected organization. Passed organization_id: {passed_organization_id}. Please contact the organization admin to request access.", + type=ProxyErrorTypes.auth_error.value, + param="organization_id", + code=status.HTTP_401_UNAUTHORIZED, + ) + + if user_role != LitellmUserRoles.ORG_ADMIN.value: + raise ProxyException( + message=f"You do not have the required role to perform {route} in Organization {passed_organization_id}. 
Your role is {user_role} in Organization {passed_organization_id}", + type=ProxyErrorTypes.auth_error.value, + param="user_role", + code=status.HTTP_401_UNAUTHORIZED, + ) + elif route == "/team/new": + # if user is part of multiple teams, then they need to specify the organization_id + _user_organizations, _user_organization_role_mapping = ( + get_user_organization_info(user_object) + ) + if ( + user_object.organization_memberships is not None + and len(user_object.organization_memberships) > 0 + ): + if passed_organization_id is None: + raise ProxyException( + message=f"Passed organization_id is None, please specify the organization_id in your request. You are part of multiple organizations: {_user_organizations}", + type=ProxyErrorTypes.auth_error.value, + param="organization_id", + code=status.HTTP_401_UNAUTHORIZED, + ) + + _user_role_in_passed_org = _user_organization_role_mapping.get( + passed_organization_id + ) + if _user_role_in_passed_org != LitellmUserRoles.ORG_ADMIN.value: + raise ProxyException( + message=f"You do not have the required role to call {route}. Your role is {_user_role_in_passed_org} in Organization {passed_organization_id}", + type=ProxyErrorTypes.auth_error.value, + param="user_role", + code=status.HTTP_401_UNAUTHORIZED, + ) + + +def get_user_organization_info( + user_object: LiteLLM_UserTable, +) -> Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]: + """ + Helper function to extract user organization information. + + Args: + user_object (LiteLLM_UserTable): The user object containing organization memberships. 
def _user_is_org_admin(
    request_data: dict,
    user_object: Optional[LiteLLM_UserTable] = None,
) -> bool:
    """
    Return True when `user_object` holds the ORG_ADMIN role in the
    organization identified by `request_data["organization_id"]`.
    """
    target_org_id = request_data.get("organization_id", None)

    # Missing org id, missing user, or no memberships -> cannot be an org admin.
    if target_org_id is None or user_object is None:
        return False
    if user_object.organization_memberships is None:
        return False

    return any(
        membership.organization_id == target_org_id
        and membership.user_role == LitellmUserRoles.ORG_ADMIN.value
        for membership in user_object.organization_memberships
    )
def _check_valid_ip(
    allowed_ips: Optional[List[str]],
    request: Request,
    use_x_forwarded_for: Optional[bool] = False,
) -> Tuple[bool, Optional[str]]:
    """
    Return (is_allowed, client_ip) for the incoming request.

    When `allowed_ips` is None (no allowlist configured), every caller is
    permitted and no IP is resolved.
    """
    if allowed_ips is None:
        # No allowlist configured -> treat all callers as allowed.
        return True, None

    # Resolve the caller's IP, honoring x-forwarded-for when configured.
    client_ip = _get_request_ip_address(
        request=request, use_x_forwarded_for=use_x_forwarded_for
    )

    return (client_ip in allowed_ips), client_ip
def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
    """
    Return True when `request_body_value` matches `regex_str` as a regex
    (anchored at the start, per `re.match`) or equals it verbatim.
    """
    regex_hit = re.match(regex_str, request_body_value) is not None
    return regex_hit or regex_str == request_body_value
def is_request_body_safe(
    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
) -> bool:
    """
    Check if the request body is safe.

    A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
    Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997

    Returns:
        True when no banned param is present, or when an explicit opt-in
        (complete client credentials, global setting, or model-level
        clientside-configurable params) allows it.

    Raises:
        ValueError: when a banned param is present and no opt-in applies.
    """
    banned_params = ["api_base", "base_url"]

    for param in banned_params:
        if (
            param in request_body
            and not check_complete_credentials(  # allow client-credentials to be passed to proxy
                request_body=request_body
            )
        ):
            # Opt-in #1: proxy-wide setting explicitly allows client-side credentials.
            if general_settings.get("allow_client_side_credentials") is True:
                return True
            # Opt-in #2: this specific model allows this param from the client.
            elif (
                _allow_model_level_clientside_configurable_parameters(
                    model=model,
                    param=param,
                    request_body_value=request_body[param],
                    llm_router=llm_router,
                )
                is True
            ):
                return True
            raise ValueError(
                f"Rejected Request: {param} is not allowed in request body. "
                "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
                "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
            )

    return True
Check if request route is an allowed route on the proxy (if set) + + Returns: + - True + + Raises: + - HTTPException if request fails initial auth checks + """ + from litellm.proxy.proxy_server import general_settings, llm_router, premium_user + + # Check 1. request size + await check_if_request_size_is_safe(request=request) + + # Check 2. Request body is safe + is_request_body_safe( + request_body=request_data, + general_settings=general_settings, + llm_router=llm_router, + model=request_data.get( + "model", "" + ), # [TODO] use model passed in url as well (azure openai routes) + ) + + # Check 3. Check if IP address is allowed + is_valid_ip, passed_in_ip = _check_valid_ip( + allowed_ips=general_settings.get("allowed_ips", None), + use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False), + request=request, + ) + + if not is_valid_ip: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Access forbidden: IP address {passed_in_ip} not allowed.", + ) + + # Check 4. Check if request route is an allowed route on the proxy + if "allowed_routes" in general_settings: + _allowed_routes = general_settings["allowed_routes"] + if premium_user is not True: + verbose_proxy_logger.error( + f"Trying to set allowed_routes. This is an Enterprise feature. 
def route_in_additonal_public_routes(current_route: str):
    """
    Return True when `current_route` appears in the user-defined
    `public_routes` list under `general_settings` on config.yaml.

    Only active for premium users. Example config:

    ```yaml
    general_settings:
        master_key: sk-1234
        public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"]
    ```
    """
    # Imported at call-time so we read the live proxy state.
    from litellm.proxy.proxy_server import general_settings, premium_user

    try:
        # Feature is gated to premium users.
        if premium_user is not True:
            return False
        if general_settings is None:
            return False
        return current_route in general_settings.get("public_routes", [])
    except Exception as e:
        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
        return False
`/genai/chat/completions` -> `/chat/completions + """ + try: + if hasattr(request, "base_url") and request.url.path.startswith( + request.base_url.path + ): + # remove base_url from path + return request.url.path[len(request.base_url.path) - 1 :] + else: + return request.url.path + except Exception as e: + verbose_proxy_logger.debug( + f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}" + ) + return request.url.path + + +async def check_if_request_size_is_safe(request: Request) -> bool: + """ + Enterprise Only: + - Checks if the request size is within the limit + + Args: + request (Request): The incoming request. + + Returns: + bool: True if the request size is within the limit + + Raises: + ProxyException: If the request size is too large + + """ + from litellm.proxy.proxy_server import general_settings, premium_user + + max_request_size_mb = general_settings.get("max_request_size_mb", None) + if max_request_size_mb is not None: + # Check if premium user + if premium_user is not True: + verbose_proxy_logger.warning( + f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}" + ) + return True + + # Get the request body + content_length = request.headers.get("content-length") + + if content_length: + header_size = int(content_length) + header_size_mb = bytes_to_mb(bytes_value=header_size) + verbose_proxy_logger.debug( + f"content_length request size in MB={header_size_mb}" + ) + + if header_size_mb > max_request_size_mb: + raise ProxyException( + message=f"Request size is too large. Request size is {header_size_mb} MB. 
Max size is {max_request_size_mb} MB", + type=ProxyErrorTypes.bad_request_error.value, + code=400, + param="content-length", + ) + else: + # If Content-Length is not available, read the body + body = await request.body() + body_size = len(body) + request_size_mb = bytes_to_mb(bytes_value=body_size) + + verbose_proxy_logger.debug( + f"request body request size in MB={request_size_mb}" + ) + if request_size_mb > max_request_size_mb: + raise ProxyException( + message=f"Request size is too large. Request size is {request_size_mb} MB. Max size is {max_request_size_mb} MB", + type=ProxyErrorTypes.bad_request_error.value, + code=400, + param="content-length", + ) + + return True + + +async def check_response_size_is_safe(response: Any) -> bool: + """ + Enterprise Only: + - Checks if the response size is within the limit + + Args: + response (Any): The response to check. + + Returns: + bool: True if the response size is within the limit + + Raises: + ProxyException: If the response size is too large + + """ + + from litellm.proxy.proxy_server import general_settings, premium_user + + max_response_size_mb = general_settings.get("max_response_size_mb", None) + if max_response_size_mb is not None: + # Check if premium user + if premium_user is not True: + verbose_proxy_logger.warning( + f"using max_response_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}" + ) + return True + + response_size_mb = bytes_to_mb(bytes_value=sys.getsizeof(response)) + verbose_proxy_logger.debug(f"response size in MB={response_size_mb}") + if response_size_mb > max_response_size_mb: + raise ProxyException( + message=f"Response size is too large. Response size is {response_size_mb} MB. 
def is_pass_through_provider_route(route: str) -> bool:
    """
    Return True when `route` belongs to a provider-specific pass-through
    surface (currently only Vertex AI), matched by substring.
    """
    provider_markers = ("vertex-ai",)
    return any(marker in route for marker in provider_markers)
def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
    """
    Extract the end-user identifier from a request body.

    Checked in order:
      1. OpenAI-style top-level 'user' (any non-None value)
      2. Anthropic-style 'litellm_metadata.user' (truthy values only)
      3. 'metadata.user_id' (any non-None value)

    Returns the id coerced to str, or None when no field is set.
    """
    # openai - top-level 'user' field
    openai_user = request_body.get("user")
    if "user" in request_body and openai_user is not None:
        return str(openai_user)

    # anthropic - nested under 'litellm_metadata'
    anthropic_user = request_body.get("litellm_metadata", {}).get("user", None)
    if anthropic_user:
        return str(anthropic_user)

    # generic metadata fallback
    meta = request_body.get("metadata")
    if meta and "user_id" in meta and meta["user_id"] is not None:
        return str(meta["user_id"])

    return None
+ +Currently only supports admin. + +JWT token must have 'litellm_proxy_admin' in scope. +""" + +import json +import os +from typing import Any, List, Literal, Optional, Set, Tuple, cast + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from fastapi import HTTPException + +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value +from litellm.llms.custom_httpx.httpx_handler import HTTPHandler +from litellm.proxy._types import ( + RBAC_ROLES, + JWKKeyValue, + JWTAuthBuilderResult, + JWTKeyItem, + LiteLLM_EndUserTable, + LiteLLM_JWTAuth, + LiteLLM_OrganizationTable, + LiteLLM_TeamTable, + LiteLLM_UserTable, + LitellmUserRoles, + ScopeMapping, + Span, +) +from litellm.proxy.auth.auth_checks import can_team_access_model +from litellm.proxy.utils import PrismaClient, ProxyLogging + +from .auth_checks import ( + _allowed_routes_check, + allowed_routes_check, + get_actual_routes, + get_end_user_object, + get_org_object, + get_role_based_models, + get_role_based_routes, + get_team_object, + get_user_object, +) + + +class JWTHandler: + """ + - treat the sub id passed in as the user id + - return an error if id making request doesn't exist in proxy user table + - track spend against the user id + - if role="litellm_proxy_user" -> allow making calls + info. 
Can not edit budgets + """ + + prisma_client: Optional[PrismaClient] + user_api_key_cache: DualCache + + def __init__( + self, + ) -> None: + self.http_handler = HTTPHandler() + self.leeway = 0 + + def update_environment( + self, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + litellm_jwtauth: LiteLLM_JWTAuth, + leeway: int = 0, + ) -> None: + self.prisma_client = prisma_client + self.user_api_key_cache = user_api_key_cache + self.litellm_jwtauth = litellm_jwtauth + self.leeway = leeway + + def is_jwt(self, token: str): + parts = token.split(".") + return len(parts) == 3 + + def _rbac_role_from_role_mapping(self, token: dict) -> Optional[RBAC_ROLES]: + """ + Returns the RBAC role the token 'belongs' to based on role mappings. + + Args: + token (dict): The JWT token containing role information + + Returns: + Optional[RBAC_ROLES]: The mapped internal RBAC role if a mapping exists, + None otherwise + + Note: + The function handles both single string roles and lists of roles from the JWT. + If multiple mappings match the JWT roles, the first matching mapping is returned. + """ + if self.litellm_jwtauth.role_mappings is None: + return None + + jwt_role = self.get_jwt_role(token=token, default_value=None) + if not jwt_role: + return None + + jwt_role_set = set(jwt_role) + + for role_mapping in self.litellm_jwtauth.role_mappings: + # Check if the mapping role matches any of the JWT roles + if role_mapping.role in jwt_role_set: + return role_mapping.internal_role + + return None + + def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]: + """ + Returns the RBAC role the token 'belongs' to. 
    def get_end_user_id(
        self, token: dict, default_value: Optional[str]
    ) -> Optional[str]:
        """
        Read the end-user id claim from the decoded JWT payload.

        Args:
            token: decoded JWT claims.
            default_value: returned when the configured claim key is absent.

        Returns:
            The value of the `end_user_id_jwt_field` claim, None when no field
            is configured, or `default_value` when the claim key is missing.
        """
        try:

            if self.litellm_jwtauth.end_user_id_jwt_field is not None:
                user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
            else:
                # NOTE(review): unlike get_user_id/get_team_id, an unconfigured
                # field yields None here rather than default_value — confirm
                # this asymmetry is intended.
                user_id = None
        except KeyError:
            # configured claim not present in this token's payload
            user_id = default_value

        return user_id
self.litellm_jwtauth.team_id_jwt_field is None: + return False + return True + + def is_enforced_email_domain(self) -> bool: + """ + Returns: + - True: if 'user_allowed_email_domain' is set + - False: if 'user_allowed_email_domain' is None + """ + + if self.litellm_jwtauth.user_allowed_email_domain is not None and isinstance( + self.litellm_jwtauth.user_allowed_email_domain, str + ): + return True + return False + + def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: + try: + if self.litellm_jwtauth.team_id_jwt_field is not None: + team_id = token[self.litellm_jwtauth.team_id_jwt_field] + elif self.litellm_jwtauth.team_id_default is not None: + team_id = self.litellm_jwtauth.team_id_default + else: + team_id = None + except KeyError: + team_id = default_value + return team_id + + def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool: + """ + Returns: + - True: if 'user_id_upsert' is set AND valid_user_email is not False + - False: if not + """ + if valid_user_email is False: + return False + return self.litellm_jwtauth.user_id_upsert + + def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: + try: + if self.litellm_jwtauth.user_id_jwt_field is not None: + user_id = token[self.litellm_jwtauth.user_id_jwt_field] + else: + user_id = default_value + except KeyError: + user_id = default_value + return user_id + + def get_user_roles( + self, token: dict, default_value: Optional[List[str]] + ) -> Optional[List[str]]: + """ + Returns the user role from the token. + + Set via 'user_roles_jwt_field' in the config. 
+ """ + try: + if self.litellm_jwtauth.user_roles_jwt_field is not None: + user_roles = get_nested_value( + data=token, + key_path=self.litellm_jwtauth.user_roles_jwt_field, + default=default_value, + ) + else: + user_roles = default_value + except KeyError: + user_roles = default_value + return user_roles + + def get_jwt_role( + self, token: dict, default_value: Optional[List[str]] + ) -> Optional[List[str]]: + """ + Generic implementation of `get_user_roles` that can be used for both user and team roles. + + Returns the jwt role from the token. + + Set via 'roles_jwt_field' in the config. + """ + try: + if self.litellm_jwtauth.roles_jwt_field is not None: + user_roles = get_nested_value( + data=token, + key_path=self.litellm_jwtauth.roles_jwt_field, + default=default_value, + ) + else: + user_roles = default_value + except KeyError: + user_roles = default_value + return user_roles + + def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool: + """ + Returns the user role from the token. + + Set via 'user_allowed_roles' in the config. 
+ """ + if ( + user_roles is not None + and self.litellm_jwtauth.user_allowed_roles is not None + and any( + role in self.litellm_jwtauth.user_allowed_roles for role in user_roles + ) + ): + return True + return False + + def get_user_email( + self, token: dict, default_value: Optional[str] + ) -> Optional[str]: + try: + if self.litellm_jwtauth.user_email_jwt_field is not None: + user_email = token[self.litellm_jwtauth.user_email_jwt_field] + else: + user_email = None + except KeyError: + user_email = default_value + return user_email + + def get_object_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: + try: + if self.litellm_jwtauth.object_id_jwt_field is not None: + object_id = token[self.litellm_jwtauth.object_id_jwt_field] + else: + object_id = default_value + except KeyError: + object_id = default_value + return object_id + + def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: + try: + if self.litellm_jwtauth.org_id_jwt_field is not None: + org_id = token[self.litellm_jwtauth.org_id_jwt_field] + else: + org_id = None + except KeyError: + org_id = default_value + return org_id + + def get_scopes(self, token: dict) -> List[str]: + try: + if isinstance(token["scope"], str): + # Assuming the scopes are stored in 'scope' claim and are space-separated + scopes = token["scope"].split() + elif isinstance(token["scope"], list): + scopes = token["scope"] + else: + raise Exception( + f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str." 
+ ) + except KeyError: + scopes = [] + return scopes + + async def get_public_key(self, kid: Optional[str]) -> dict: + + keys_url = os.getenv("JWT_PUBLIC_KEY_URL") + + if keys_url is None: + raise Exception("Missing JWT Public Key URL from environment.") + + keys_url_list = [url.strip() for url in keys_url.split(",")] + + for key_url in keys_url_list: + + cache_key = f"litellm_jwt_auth_keys_{key_url}" + + cached_keys = await self.user_api_key_cache.async_get_cache(cache_key) + + if cached_keys is None: + response = await self.http_handler.get(key_url) + + response_json = response.json() + if "keys" in response_json: + keys: JWKKeyValue = response.json()["keys"] + else: + keys = response_json + + await self.user_api_key_cache.async_set_cache( + key=cache_key, + value=keys, + ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins + ) + else: + keys = cached_keys + + public_key = self.parse_keys(keys=keys, kid=kid) + if public_key is not None: + return cast(dict, public_key) + + raise Exception( + f"No matching public key found. 
keys={keys_url_list}, kid={kid}" + ) + + def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]: + public_key: Optional[JWTKeyItem] = None + if len(keys) == 1: + if isinstance(keys, dict) and (keys.get("kid", None) == kid or kid is None): + public_key = keys + elif isinstance(keys, list) and ( + keys[0].get("kid", None) == kid or kid is None + ): + public_key = keys[0] + elif len(keys) > 1: + for key in keys: + if isinstance(key, dict): + key_kid = key.get("kid", None) + else: + key_kid = None + if ( + kid is not None + and isinstance(key, dict) + and key_kid is not None + and key_kid == kid + ): + public_key = key + + return public_key + + def is_allowed_domain(self, user_email: str) -> bool: + if self.litellm_jwtauth.user_allowed_email_domain is None: + return True + + email_domain = user_email.split("@")[-1] # Extract domain from email + if email_domain == self.litellm_jwtauth.user_allowed_email_domain: + return True + else: + return False + + async def auth_jwt(self, token: str) -> dict: + # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html + # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret + # the key in different ways (e.g. HS* and RS*)." 
+ algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"] + + audience = os.getenv("JWT_AUDIENCE") + decode_options = None + if audience is None: + decode_options = {"verify_aud": False} + + import jwt + from jwt.algorithms import RSAAlgorithm + + header = jwt.get_unverified_header(token) + + verbose_proxy_logger.debug("header: %s", header) + + kid = header.get("kid", None) + + public_key = await self.get_public_key(kid=kid) + + if public_key is not None and isinstance(public_key, dict): + jwk = {} + if "kty" in public_key: + jwk["kty"] = public_key["kty"] + if "kid" in public_key: + jwk["kid"] = public_key["kid"] + if "n" in public_key: + jwk["n"] = public_key["n"] + if "e" in public_key: + jwk["e"] = public_key["e"] + + public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk)) + + try: + # decode the token using the public key + payload = jwt.decode( + token, + public_key_rsa, # type: ignore + algorithms=algorithms, + options=decode_options, + audience=audience, + leeway=self.leeway, # allow testing of expired tokens + ) + return payload + + except jwt.ExpiredSignatureError: + # the token is expired, do something to refresh it + raise Exception("Token Expired") + except Exception as e: + raise Exception(f"Validation fails: {str(e)}") + elif public_key is not None and isinstance(public_key, str): + try: + cert = x509.load_pem_x509_certificate( + public_key.encode(), default_backend() + ) + + # Extract public key + key = cert.public_key().public_bytes( + serialization.Encoding.PEM, + serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + # decode the token using the public key + payload = jwt.decode( + token, + key, + algorithms=algorithms, + audience=audience, + options=decode_options, + ) + return payload + + except jwt.ExpiredSignatureError: + # the token is expired, do something to refresh it + raise Exception("Token Expired") + except Exception as e: + raise Exception(f"Validation fails: {str(e)}") + + raise Exception("Invalid JWT Submitted") + + 
async def close(self): + await self.http_handler.close() + + +class JWTAuthManager: + """Manages JWT authentication and authorization operations""" + + @staticmethod + def can_rbac_role_call_route( + rbac_role: RBAC_ROLES, + general_settings: dict, + route: str, + ) -> Literal[True]: + """ + Checks if user is allowed to access the route, based on their role. + """ + role_based_routes = get_role_based_routes( + rbac_role=rbac_role, general_settings=general_settings + ) + + if role_based_routes is None or route is None: + return True + + is_allowed = _allowed_routes_check( + user_route=route, + allowed_routes=role_based_routes, + ) + + if not is_allowed: + raise HTTPException( + status_code=403, + detail=f"Role={rbac_role} not allowed to call route={route}. Allowed routes={role_based_routes}", + ) + + return True + + @staticmethod + def can_rbac_role_call_model( + rbac_role: RBAC_ROLES, + general_settings: dict, + model: Optional[str], + ) -> Literal[True]: + """ + Checks if user is allowed to access the model, based on their role. + """ + role_based_models = get_role_based_models( + rbac_role=rbac_role, general_settings=general_settings + ) + if role_based_models is None or model is None: + return True + + if model not in role_based_models: + raise HTTPException( + status_code=403, + detail=f"Role={rbac_role} not allowed to call model={model}. 
Allowed models={role_based_models}", + ) + + return True + + @staticmethod + def check_scope_based_access( + scope_mappings: List[ScopeMapping], + scopes: List[str], + request_data: dict, + general_settings: dict, + ) -> None: + """ + Check if scope allows access to the requested model + """ + if not scope_mappings: + return None + + allowed_models = [] + for sm in scope_mappings: + if sm.scope in scopes and sm.models: + allowed_models.extend(sm.models) + + requested_model = request_data.get("model") + + if not requested_model: + return None + + if requested_model not in allowed_models: + raise HTTPException( + status_code=403, + detail={ + "error": "model={} not allowed. Allowed_models={}".format( + requested_model, allowed_models + ) + }, + ) + return None + + @staticmethod + async def check_rbac_role( + jwt_handler: JWTHandler, + jwt_valid_token: dict, + general_settings: dict, + request_data: dict, + route: str, + rbac_role: Optional[RBAC_ROLES], + ) -> None: + """Validate RBAC role and model access permissions""" + if jwt_handler.litellm_jwtauth.enforce_rbac is True: + if rbac_role is None: + raise HTTPException( + status_code=403, + detail="Unmatched token passed in. enforce_rbac is set to True. 
Token must belong to a proxy admin, team, or user.", + ) + JWTAuthManager.can_rbac_role_call_model( + rbac_role=rbac_role, + general_settings=general_settings, + model=request_data.get("model"), + ) + JWTAuthManager.can_rbac_role_call_route( + rbac_role=rbac_role, + general_settings=general_settings, + route=route, + ) + + @staticmethod + async def check_admin_access( + jwt_handler: JWTHandler, + scopes: list, + route: str, + user_id: Optional[str], + org_id: Optional[str], + api_key: str, + ) -> Optional[JWTAuthBuilderResult]: + """Check admin status and route access permissions""" + if not jwt_handler.is_admin(scopes=scopes): + return None + + is_allowed = allowed_routes_check( + user_role=LitellmUserRoles.PROXY_ADMIN, + user_route=route, + litellm_proxy_roles=jwt_handler.litellm_jwtauth, + ) + if not is_allowed: + allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes + actual_routes = get_actual_routes(allowed_routes=allowed_routes) + raise Exception( + f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}" + ) + + return JWTAuthBuilderResult( + is_proxy_admin=True, + team_object=None, + user_object=None, + end_user_object=None, + org_object=None, + token=api_key, + team_id=None, + user_id=user_id, + end_user_id=None, + org_id=org_id, + ) + + @staticmethod + async def find_and_validate_specific_team_id( + jwt_handler: JWTHandler, + jwt_valid_token: dict, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span], + proxy_logging_obj: ProxyLogging, + ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]: + """Find and validate specific team ID""" + individual_team_id = jwt_handler.get_team_id( + token=jwt_valid_token, default_value=None + ) + + if not individual_team_id and jwt_handler.is_required_team_id() is True: + raise Exception( + f"No team id found in token. 
Checked team_id field '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'" + ) + + ## VALIDATE TEAM OBJECT ### + team_object: Optional[LiteLLM_TeamTable] = None + if individual_team_id: + team_object = await get_team_object( + team_id=individual_team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert, + ) + + return individual_team_id, team_object + + @staticmethod + def get_all_team_ids(jwt_handler: JWTHandler, jwt_valid_token: dict) -> Set[str]: + """Get combined team IDs from groups and individual team_id""" + team_ids_from_groups = jwt_handler.get_team_ids_from_jwt(token=jwt_valid_token) + + all_team_ids = set(team_ids_from_groups) + + return all_team_ids + + @staticmethod + async def find_team_with_model_access( + team_ids: Set[str], + requested_model: Optional[str], + route: str, + jwt_handler: JWTHandler, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span], + proxy_logging_obj: ProxyLogging, + ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]: + """Find first team with access to the requested model""" + + if not team_ids: + if jwt_handler.litellm_jwtauth.enforce_team_based_model_access: + raise HTTPException( + status_code=403, + detail="No teams found in token. `enforce_team_based_model_access` is set to True. 
Token must belong to a team.", + ) + return None, None + + for team_id in team_ids: + try: + team_object = await get_team_object( + team_id=team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + + if team_object and team_object.models is not None: + team_models = team_object.models + if isinstance(team_models, list) and ( + not requested_model + or can_team_access_model( + model=requested_model, + team_object=team_object, + llm_router=None, + team_model_aliases=None, + ) + ): + is_allowed = allowed_routes_check( + user_role=LitellmUserRoles.TEAM, + user_route=route, + litellm_proxy_roles=jwt_handler.litellm_jwtauth, + ) + if is_allowed: + return team_id, team_object + except Exception: + continue + + if requested_model: + raise HTTPException( + status_code=403, + detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}. Check `/models` to see all available models.", + ) + + return None, None + + @staticmethod + async def get_user_info( + jwt_handler: JWTHandler, + jwt_valid_token: dict, + ) -> Tuple[Optional[str], Optional[str], Optional[bool]]: + """Get user email and validation status""" + user_email = jwt_handler.get_user_email( + token=jwt_valid_token, default_value=None + ) + valid_user_email = None + if jwt_handler.is_enforced_email_domain(): + valid_user_email = ( + False + if user_email is None + else jwt_handler.is_allowed_domain(user_email=user_email) + ) + user_id = jwt_handler.get_user_id( + token=jwt_valid_token, default_value=user_email + ) + return user_id, user_email, valid_user_email + + @staticmethod + async def get_objects( + user_id: Optional[str], + user_email: Optional[str], + org_id: Optional[str], + end_user_id: Optional[str], + valid_user_email: Optional[bool], + jwt_handler: JWTHandler, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span], 
+ proxy_logging_obj: ProxyLogging, + ) -> Tuple[ + Optional[LiteLLM_UserTable], + Optional[LiteLLM_OrganizationTable], + Optional[LiteLLM_EndUserTable], + ]: + """Get user, org, and end user objects""" + org_object: Optional[LiteLLM_OrganizationTable] = None + if org_id: + org_object = ( + await get_org_object( + org_id=org_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + if org_id + else None + ) + + user_object: Optional[LiteLLM_UserTable] = None + if user_id: + user_object = ( + await get_user_object( + user_id=user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + user_id_upsert=jwt_handler.is_upsert_user_id( + valid_user_email=valid_user_email + ), + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + user_email=user_email, + sso_user_id=user_id, + ) + if user_id + else None + ) + + end_user_object: Optional[LiteLLM_EndUserTable] = None + if end_user_id: + end_user_object = ( + await get_end_user_object( + end_user_id=end_user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + if end_user_id + else None + ) + + return user_object, org_object, end_user_object + + @staticmethod + def validate_object_id( + user_id: Optional[str], + team_id: Optional[str], + enforce_rbac: bool, + is_proxy_admin: bool, + ) -> Literal[True]: + """If enforce_rbac is true, validate that a valid rbac id is returned for spend tracking""" + if enforce_rbac and not is_proxy_admin and not user_id and not team_id: + raise HTTPException( + status_code=403, + detail="No user or team id found in token. enforce_rbac is set to True. 
Token must belong to a proxy admin, team, or user.", + ) + return True + + @staticmethod + async def auth_builder( + api_key: str, + jwt_handler: JWTHandler, + request_data: dict, + general_settings: dict, + route: str, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span], + proxy_logging_obj: ProxyLogging, + ) -> JWTAuthBuilderResult: + """Main authentication and authorization builder""" + jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key) + + # Check custom validate + if jwt_handler.litellm_jwtauth.custom_validate: + if not jwt_handler.litellm_jwtauth.custom_validate(jwt_valid_token): + raise HTTPException( + status_code=403, + detail="Invalid JWT token", + ) + + # Check RBAC + rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token) + await JWTAuthManager.check_rbac_role( + jwt_handler, + jwt_valid_token, + general_settings, + request_data, + route, + rbac_role, + ) + + # Check Scope Based Access + scopes = jwt_handler.get_scopes(token=jwt_valid_token) + if ( + jwt_handler.litellm_jwtauth.enforce_scope_based_access + and jwt_handler.litellm_jwtauth.scope_mappings + ): + JWTAuthManager.check_scope_based_access( + scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings, + scopes=scopes, + request_data=request_data, + general_settings=general_settings, + ) + + object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None) + + # Get basic user info + scopes = jwt_handler.get_scopes(token=jwt_valid_token) + user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info( + jwt_handler, jwt_valid_token + ) + + # Get IDs + org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None) + end_user_id = jwt_handler.get_end_user_id( + token=jwt_valid_token, default_value=None + ) + team_id: Optional[str] = None + team_object: Optional[LiteLLM_TeamTable] = None + object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None) + + if 
rbac_role and object_id: + + if rbac_role == LitellmUserRoles.TEAM: + team_id = object_id + elif rbac_role == LitellmUserRoles.INTERNAL_USER: + user_id = object_id + + # Check admin access + admin_result = await JWTAuthManager.check_admin_access( + jwt_handler, scopes, route, user_id, org_id, api_key + ) + if admin_result: + return admin_result + + # Get team with model access + ## SPECIFIC TEAM ID + + if not team_id: + team_id, team_object = ( + await JWTAuthManager.find_and_validate_specific_team_id( + jwt_handler, + jwt_valid_token, + prisma_client, + user_api_key_cache, + parent_otel_span, + proxy_logging_obj, + ) + ) + + if not team_object and not team_id: + ## CHECK USER GROUP ACCESS + all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token) + team_id, team_object = await JWTAuthManager.find_team_with_model_access( + team_ids=all_team_ids, + requested_model=request_data.get("model"), + route=route, + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + + # Get other objects + user_object, org_object, end_user_object = await JWTAuthManager.get_objects( + user_id=user_id, + user_email=user_email, + org_id=org_id, + end_user_id=end_user_id, + valid_user_email=valid_user_email, + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + + # Validate that a valid rbac id is returned for spend tracking + JWTAuthManager.validate_object_id( + user_id=user_id, + team_id=team_id, + enforce_rbac=general_settings.get("enforce_rbac", False), + is_proxy_admin=False, + ) + + return JWTAuthBuilderResult( + is_proxy_admin=False, + team_id=team_id, + team_object=team_object, + user_id=user_id, + user_object=user_object, + org_id=org_id, + org_object=org_object, + end_user_id=end_user_id, + 
end_user_object=end_user_object, + token=api_key, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/litellm_license.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/litellm_license.py new file mode 100644 index 00000000..67ec91f5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/litellm_license.py @@ -0,0 +1,169 @@ +# What is this? +## If litellm license in env, checks if it's valid +import base64 +import json +import os +from datetime import datetime +from typing import Optional + +import httpx + +from litellm._logging import verbose_proxy_logger +from litellm.llms.custom_httpx.http_handler import HTTPHandler + + +class LicenseCheck: + """ + - Check if license in env + - Returns if license is valid + """ + + base_url = "https://license.litellm.ai" + + def __init__(self) -> None: + self.license_str = os.getenv("LITELLM_LICENSE", None) + verbose_proxy_logger.debug("License Str value - {}".format(self.license_str)) + self.http_handler = HTTPHandler(timeout=15) + self.public_key = None + self.read_public_key() + + def read_public_key(self): + try: + from cryptography.hazmat.primitives import serialization + + # current dir + current_dir = os.path.dirname(os.path.realpath(__file__)) + + # check if public_key.pem exists + _path_to_public_key = os.path.join(current_dir, "public_key.pem") + if os.path.exists(_path_to_public_key): + with open(_path_to_public_key, "rb") as key_file: + self.public_key = serialization.load_pem_public_key(key_file.read()) + else: + self.public_key = None + except Exception as e: + verbose_proxy_logger.error(f"Error reading public key: {str(e)}") + + def _verify(self, license_str: str) -> bool: + + verbose_proxy_logger.debug( + "litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format( + self.base_url, license_str + ) + ) + url = "{}/verify_license/{}".format(self.base_url, license_str) + + response: Optional[httpx.Response] = None + try: # 
don't impact user, if call fails + num_retries = 3 + for i in range(num_retries): + try: + response = self.http_handler.get(url=url) + if response is None: + raise Exception("No response from license server") + response.raise_for_status() + except httpx.HTTPStatusError: + if i == num_retries - 1: + raise + + if response is None: + raise Exception("No response from license server") + + response_json = response.json() + + premium = response_json["verify"] + + assert isinstance(premium, bool) + + verbose_proxy_logger.debug( + "litellm.proxy.auth.litellm_license.py::_verify - License={} is premium={}".format( + license_str, premium + ) + ) + return premium + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License={} via api. - {}".format( + license_str, str(e) + ) + ) + return False + + def is_premium(self) -> bool: + """ + 1. verify_license_without_api_request: checks if license was generate using private / public key pair + 2. _verify: checks if license is valid calling litellm API. 
This is the old way we were generating/validating license + """ + try: + verbose_proxy_logger.debug( + "litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - LiteLLM License={}".format( + self.license_str + ) + ) + + if self.license_str is None: + self.license_str = os.getenv("LITELLM_LICENSE", None) + + verbose_proxy_logger.debug( + "litellm.proxy.auth.litellm_license.py::is_premium() - Updated 'self.license_str' - {}".format( + self.license_str + ) + ) + + if self.license_str is None: + return False + elif ( + self.verify_license_without_api_request( + public_key=self.public_key, license_key=self.license_str + ) + is True + ): + return True + elif self._verify(license_str=self.license_str) is True: + return True + return False + except Exception: + return False + + def verify_license_without_api_request(self, public_key, license_key): + try: + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.asymmetric import padding + + # Decode the license key + decoded = base64.b64decode(license_key) + message, signature = decoded.split(b".", 1) + + # Verify the signature + public_key.verify( + signature, + message, + padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), + salt_length=padding.PSS.MAX_LENGTH, + ), + hashes.SHA256(), + ) + + # Decode and parse the data + license_data = json.loads(message.decode()) + + # debug information provided in license data + verbose_proxy_logger.debug("License data: %s", license_data) + + # Check expiration date + expiration_date = datetime.strptime( + license_data["expiration_date"], "%Y-%m-%d" + ) + if expiration_date < datetime.now(): + return False, "License has expired" + + return True + + except Exception as e: + verbose_proxy_logger.debug( + "litellm.proxy.auth.litellm_license.py::verify_license_without_api_request - Unable to verify License locally. 
- {}".format( + str(e) + ) + ) + return False diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py new file mode 100644 index 00000000..a48ef6ae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py @@ -0,0 +1,197 @@ +# What is this? +## Common checks for /v1/models and `/model/info` +import copy +from typing import Dict, List, Optional, Set + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth +from litellm.utils import get_valid_models + + +def _check_wildcard_routing(model: str) -> bool: + """ + Returns True if a model is a provider wildcard. + + eg: + - anthropic/* + - openai/* + - * + """ + if "*" in model: + return True + return False + + +def get_provider_models(provider: str) -> Optional[List[str]]: + """ + Returns the list of known models by provider + """ + if provider == "*": + return get_valid_models() + + if provider in litellm.models_by_provider: + provider_models = copy.deepcopy(litellm.models_by_provider[provider]) + for idx, _model in enumerate(provider_models): + if provider not in _model: + provider_models[idx] = f"{provider}/{_model}" + return provider_models + return None + + +def _get_models_from_access_groups( + model_access_groups: Dict[str, List[str]], + all_models: List[str], +) -> List[str]: + idx_to_remove = [] + new_models = [] + for idx, model in enumerate(all_models): + if model in model_access_groups: + idx_to_remove.append(idx) + new_models.extend(model_access_groups[model]) + + for idx in sorted(idx_to_remove, reverse=True): + all_models.pop(idx) + + all_models.extend(new_models) + return all_models + + +def get_key_models( + user_api_key_dict: UserAPIKeyAuth, + proxy_model_list: List[str], + model_access_groups: Dict[str, List[str]], +) -> List[str]: + """ + Returns: + - List of model name strings + - Empty list 
if no models set + - If model_access_groups is provided, only return models that are in the access groups + """ + all_models: List[str] = [] + if len(user_api_key_dict.models) > 0: + all_models = user_api_key_dict.models + if SpecialModelNames.all_team_models.value in all_models: + all_models = user_api_key_dict.team_models + if SpecialModelNames.all_proxy_models.value in all_models: + all_models = proxy_model_list + + all_models = _get_models_from_access_groups( + model_access_groups=model_access_groups, all_models=all_models + ) + + verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models))) + return all_models + + +def get_team_models( + team_models: List[str], + proxy_model_list: List[str], + model_access_groups: Dict[str, List[str]], +) -> List[str]: + """ + Returns: + - List of model name strings + - Empty list if no models set + - If model_access_groups is provided, only return models that are in the access groups + """ + all_models = [] + if len(team_models) > 0: + all_models = team_models + if SpecialModelNames.all_team_models.value in all_models: + all_models = team_models + if SpecialModelNames.all_proxy_models.value in all_models: + all_models = proxy_model_list + + all_models = _get_models_from_access_groups( + model_access_groups=model_access_groups, all_models=all_models + ) + + verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models))) + return all_models + + +def get_complete_model_list( + key_models: List[str], + team_models: List[str], + proxy_model_list: List[str], + user_model: Optional[str], + infer_model_from_keys: Optional[bool], + return_wildcard_routes: Optional[bool] = False, +) -> List[str]: + """Logic for returning complete model list for a given key + team pair""" + + """ + - If key list is empty -> defer to team list + - If team list is empty -> defer to proxy model list + + If list contains wildcard -> return known provider models + """ + unique_models: Set[str] = set() + if key_models: + 
unique_models.update(key_models) + elif team_models: + unique_models.update(team_models) + else: + unique_models.update(proxy_model_list) + + if user_model: + unique_models.add(user_model) + + if infer_model_from_keys: + valid_models = get_valid_models() + unique_models.update(valid_models) + + all_wildcard_models = _get_wildcard_models( + unique_models=unique_models, return_wildcard_routes=return_wildcard_routes + ) + + return list(unique_models) + all_wildcard_models + + +def get_known_models_from_wildcard(wildcard_model: str) -> List[str]: + try: + provider, model = wildcard_model.split("/", 1) + except ValueError: # safely fail + return [] + # get all known provider models + wildcard_models = get_provider_models(provider=provider) + if wildcard_models is None: + return [] + if model == "*": + return wildcard_models or [] + else: + model_prefix = model.replace("*", "") + filtered_wildcard_models = [ + wc_model + for wc_model in wildcard_models + if wc_model.split("/")[1].startswith(model_prefix) + ] + + return filtered_wildcard_models + + +def _get_wildcard_models( + unique_models: Set[str], return_wildcard_routes: Optional[bool] = False +) -> List[str]: + models_to_remove = set() + all_wildcard_models = [] + for model in unique_models: + if _check_wildcard_routing(model=model): + + if ( + return_wildcard_routes + ): # will add the wildcard route to the list eg: anthropic/*. 
async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
    """
    Validate an OAuth2 token against the configured token-info endpoint.

    Args:
        token (str): The OAuth2 token to validate.

    Returns:
        UserAPIKeyAuth: built from the token-info response fields.

    Raises:
        ValueError: If the token is invalid, the request fails, or the token
            info endpoint is not set.
    """
    import os

    import httpx

    from litellm._logging import verbose_proxy_logger
    from litellm.llms.custom_httpx.http_handler import (
        get_async_httpx_client,
        httpxSpecialProvider,
    )
    from litellm.proxy._types import CommonProxyErrors
    from litellm.proxy.proxy_server import premium_user

    # enterprise-only feature gate
    if premium_user is not True:
        raise ValueError(
            "Oauth2 token validation is only available for premium users"
            + CommonProxyErrors.not_premium_user.value
        )

    verbose_proxy_logger.debug("Oauth2 token validation for token=%s", token)

    # Get the token info endpoint from environment variable
    token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
    user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
    user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
    user_team_id_field_name = os.environ.get("OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id")

    if not token_info_endpoint:
        raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")

    http_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }

    try:
        resp = await http_client.get(token_info_endpoint, headers=request_headers)

        # a bad token is expected to surface as an HTTPStatusError here
        resp.raise_for_status()

        token_info = resp.json()

        verbose_proxy_logger.debug(
            "Oauth2 token validation for token=%s, response from /token/info=%s",
            token,
            token_info,
        )

        # You might want to add additional checks here based on the response
        # For example, checking if the token is expired or has the correct scope
        return UserAPIKeyAuth(
            api_key=token,
            team_id=token_info.get(user_team_id_field_name),
            user_id=token_info.get(user_id_field_name),
            user_role=token_info.get(user_role_field_name),
        )
    except httpx.HTTPStatusError as e:
        # This will catch any 4xx or 5xx errors
        raise ValueError(f"Oauth 2.0 Token validation failed: {e}")
    except Exception as e:
        # This will catch any other errors (like network issues)
        raise ValueError(f"An error occurred during token validation: {e}")
async def handle_oauth2_proxy_request(request: Request) -> UserAPIKeyAuth:
    """
    Handle request from oauth2 proxy.

    Maps trusted headers (set by the oauth2 proxy sitting in front of
    LiteLLM) onto UserAPIKeyAuth fields via the
    `general_settings.oauth2_config_mappings` table.
    """
    from litellm.proxy.proxy_server import general_settings

    verbose_proxy_logger.debug("Handling oauth2 proxy request")
    # Define the OAuth2 config mappings
    oauth2_config_mappings: Dict[str, str] = general_settings.get(
        "oauth2_config_mappings", None
    )
    verbose_proxy_logger.debug(f"Oauth2 config mappings: {oauth2_config_mappings}")

    if not oauth2_config_mappings:
        raise ValueError("Oauth2 config mappings not found in general_settings")

    # Extract values from headers based on the mappings
    auth_data: Dict[str, Any] = {}
    for field_name, header_name in oauth2_config_mappings.items():
        raw_value = request.headers.get(header_name)
        if not raw_value:
            continue  # missing/empty headers are simply skipped
        if field_name == "max_budget":
            # Convert max_budget to float if present
            auth_data[field_name] = float(raw_value)
        elif field_name == "models":
            # Convert models to list if present
            auth_data[field_name] = [m.strip() for m in raw_value.split(",")]
        else:
            auth_data[field_name] = raw_value

    verbose_proxy_logger.debug(
        f"Auth data before creating UserAPIKeyAuth object: {auth_data}"
    )
    # Create and return UserAPIKeyAuth object
    user_api_key_auth = UserAPIKeyAuth(**auth_data)
    verbose_proxy_logger.debug(f"UserAPIKeyAuth object created: {user_api_key_auth}")
    return user_api_key_auth
def init_rds_client(
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):
    """
    Build a boto3 RDS client, resolving credentials in priority order:
    web-identity STS, role-assumption STS, explicit access keys, named AWS
    profile, then boto3's default credential chain.
    """
    from litellm.secret_managers.main import get_secret

    # check for custom AWS_REGION_NAME and use it if not passed explicitly
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)

    def _resolve(value):
        # "os.environ/<VAR>" values are indirections through the secret manager
        if value and value.startswith("os.environ/"):
            return get_secret(value)  # type: ignore
        return value

    aws_access_key_id = _resolve(aws_access_key_id)
    aws_secret_access_key = _resolve(aws_secret_access_key)
    aws_region_name = _resolve(aws_region_name)
    aws_session_name = _resolve(aws_session_name)
    aws_profile_name = _resolve(aws_profile_name)
    aws_role_name = _resolve(aws_role_name)
    aws_web_identity_token = _resolve(aws_web_identity_token)

    ### SET REGION NAME: explicit arg > AWS_REGION_NAME > AWS_REGION
    region_name = (
        aws_region_name or litellm_aws_region_name or standard_aws_region_name
    )
    if not region_name:
        raise Exception(
            "AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
        )

    import boto3

    # translate the timeout argument into boto3 Config kwargs
    timeout_kwargs = {}
    if isinstance(timeout, float):
        timeout_kwargs = {"connect_timeout": timeout, "read_timeout": timeout}
    elif isinstance(timeout, httpx.Timeout):
        timeout_kwargs = {
            "connect_timeout": timeout.connect,
            "read_timeout": timeout.read,
        }
    config = boto3.session.Config(**timeout_kwargs)  # type: ignore

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        try:
            oidc_token = open(aws_web_identity_token).read()  # treat as filepath first
        except Exception:
            oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise Exception(
                "OIDC token could not be retrieved from secret manager.",
            )

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        assumed = boto3.client("sts").assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )
        creds = assumed["Credentials"]
        return boto3.client(
            service_name="rds",
            aws_access_key_id=creds["AccessKeyId"],
            aws_secret_access_key=creds["SecretAccessKey"],
            aws_session_token=creds["SessionToken"],
            region_name=region_name,
            config=config,
        )

    if aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )
        assumed = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )
        creds = assumed["Credentials"]
        return boto3.client(
            service_name="rds",
            aws_access_key_id=creds["AccessKeyId"],
            aws_secret_access_key=creds["SecretAccessKey"],
            aws_session_token=creds["SessionToken"],
            region_name=region_name,
            config=config,
        )

    if aws_access_key_id is not None:
        # explicit credentials passed by the caller
        return boto3.client(
            service_name="rds",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
            config=config,
        )

    if aws_profile_name is not None:
        # auth values from an AWS profile, usually stored in ~/.aws/credentials
        return boto3.Session(profile_name=aws_profile_name).client(
            service_name="rds",
            region_name=region_name,
            config=config,
        )

    # fall back to boto3's default credential chain (env vars, instance role, ...)
    return boto3.client(
        service_name="rds",
        region_name=region_name,
        config=config,
    )


def generate_iam_auth_token(
    db_host, db_port, db_user, client: Optional[Any] = None
) -> str:
    """
    Generate a URL-encoded RDS IAM auth token, suitable for use as a
    database password in a connection string.

    If `client` is None, an RDS client is built from AWS_* env variables.
    """
    from urllib.parse import quote

    if client is None:
        rds_client = init_rds_client(
            aws_region_name=os.getenv("AWS_REGION_NAME"),
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_name=os.getenv("AWS_SESSION_NAME"),
            aws_profile_name=os.getenv("AWS_PROFILE_NAME"),
            aws_role_name=os.getenv("AWS_ROLE_NAME", os.getenv("AWS_ROLE_ARN")),
            aws_web_identity_token=os.getenv(
                "AWS_WEB_IDENTITY_TOKEN", os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
            ),
        )
    else:
        rds_client = client

    raw_token = rds_client.generate_db_auth_token(
        DBHostname=db_host, Port=db_port, DBUsername=db_user
    )
    # quote everything (safe="") so the token can live inside a DSN
    return quote(raw_token, safe="")
class RouteChecks:
    # Static route-authorization helpers shared by key-based and JWT auth.
    # The central entry point is non_proxy_admin_allowed_routes_check, an
    # allow-list dispatch: the first matching branch permits the call and
    # the final else raises.

    @staticmethod
    def non_proxy_admin_allowed_routes_check(
        user_obj: Optional[LiteLLM_UserTable],
        _user_role: Optional[LitellmUserRoles],
        route: str,
        request: Request,
        valid_token: UserAPIKeyAuth,
        api_key: str,
        request_data: dict,
    ):
        """
        Checks if Non Proxy Admin User is allowed to access the route
        """

        # Check user has defined custom admin routes
        RouteChecks.custom_admin_only_route_check(
            route=route,
        )

        # allow-list dispatch - order matters: LLM API routes are always
        # allowed, then per-route info checks, then role-specific branches
        if RouteChecks.is_llm_api_route(route=route):
            pass
        elif (
            route in LiteLLMRoutes.info_routes.value
        ):  # check if user allowed to call an info route
            if route == "/key/info":
                # handled by function itself
                pass
            elif route == "/user/info":
                # check if user can access this route
                # a key may only read info for its own user_id
                query_params = request.query_params
                user_id = query_params.get("user_id")
                verbose_proxy_logger.debug(
                    f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}"
                )
                if user_id and user_id != valid_token.user_id:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="key not allowed to access this user's info. user_id={}, key's user_id={}".format(
                            user_id, valid_token.user_id
                        ),
                    )
            elif route == "/model/info":
                # /model/info just shows models user has access to
                pass
            elif route == "/team/info":
                pass  # handled by function itself
        elif (
            route in LiteLLMRoutes.global_spend_tracking_routes.value
            and getattr(valid_token, "permissions", None) is not None
            and "get_spend_routes" in getattr(valid_token, "permissions", [])
        ):
            # key explicitly granted the spend-tracking permission
            pass
        elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value:
            # view-only admin: no LLM calls, and on management routes may
            # only update their own email/password via /user/update
            if RouteChecks.is_llm_api_route(route=route):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"user not allowed to access this OpenAI routes, role= {_user_role}",
                )
            if RouteChecks.check_route_access(
                route=route, allowed_routes=LiteLLMRoutes.management_routes.value
            ):
                # the Admin Viewer is only allowed to call /user/update for their own user_id and can only update
                if route == "/user/update":

                    # Check the Request params are valid for PROXY_ADMIN_VIEW_ONLY
                    if request_data is not None and isinstance(request_data, dict):
                        _params_updated = request_data.keys()
                        for param in _params_updated:
                            if param not in ["user_email", "password"]:
                                raise HTTPException(
                                    status_code=status.HTTP_403_FORBIDDEN,
                                    detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route} and updating invalid param: {param}. only user_email and password can be updated",
                                )
                else:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
                    )

        elif (
            _user_role == LitellmUserRoles.INTERNAL_USER.value
            and RouteChecks.check_route_access(
                route=route, allowed_routes=LiteLLMRoutes.internal_user_routes.value
            )
        ):
            pass
        elif _user_is_org_admin(
            request_data=request_data, user_object=user_obj
        ) and RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.org_admin_allowed_routes.value
        ):
            pass
        elif (
            _user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
            and RouteChecks.check_route_access(
                route=route,
                allowed_routes=LiteLLMRoutes.internal_user_view_only_routes.value,
            )
        ):
            pass
        elif RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.self_managed_routes.value
        ):  # routes that manage their own allowed/disallowed logic
            pass
        else:
            # no branch matched -> deny, with as much identity info as we have
            user_role = "unknown"
            user_id = "unknown"
            if user_obj is not None:
                user_role = user_obj.user_role or "unknown"
                user_id = user_obj.user_id or "unknown"
            raise Exception(
                f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}"
            )

    @staticmethod
    def custom_admin_only_route_check(route: str):
        # Enforce user-configured `admin_only_routes` (enterprise feature);
        # silently no-ops (with an error log) for non-premium deployments.
        from litellm.proxy.proxy_server import general_settings, premium_user

        if "admin_only_routes" in general_settings:
            if premium_user is not True:
                verbose_proxy_logger.error(
                    f"Trying to use 'admin_only_routes' this is an Enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
                )
                return
            if route in general_settings["admin_only_routes"]:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"user not allowed to access this route. Route={route} is an admin only route",
                )
        pass

    @staticmethod
    def is_llm_api_route(route: str) -> bool:
        """
        Helper to checks if provided route is an OpenAI route


        Returns:
            - True: if route is an OpenAI route
            - False: if route is not an OpenAI route
        """

        if route in LiteLLMRoutes.openai_routes.value:
            return True

        if route in LiteLLMRoutes.anthropic_routes.value:
            return True

        # fuzzy match routes like "/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
        # Check for routes with placeholders
        for openai_route in LiteLLMRoutes.openai_routes.value:
            # Replace placeholders with regex pattern
            # placeholders are written as "/threads/{thread_id}"
            if "{" in openai_route:
                if RouteChecks._route_matches_pattern(
                    route=route, pattern=openai_route
                ):
                    return True

        if RouteChecks._is_azure_openai_route(route=route):
            return True

        # pass-through provider prefixes match by substring, not full path
        for _llm_passthrough_route in LiteLLMRoutes.mapped_pass_through_routes.value:
            if _llm_passthrough_route in route:
                return True

        return False

    @staticmethod
    def _is_azure_openai_route(route: str) -> bool:
        """
        Check if route is a route from AzureOpenAI SDK client

        eg.
        route='/openai/deployments/vertex_ai/gemini-1.5-flash/chat/completions'
        """
        # Add support for deployment and engine model paths
        deployment_pattern = r"^/openai/deployments/[^/]+/[^/]+/chat/completions$"
        engine_pattern = r"^/engines/[^/]+/chat/completions$"

        if re.match(deployment_pattern, route) or re.match(engine_pattern, route):
            return True
        return False

    @staticmethod
    def _route_matches_pattern(route: str, pattern: str) -> bool:
        """
        Check if route matches the pattern placed in proxy/_types.py

        Example:
        - pattern: "/threads/{thread_id}"
        - route: "/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
        - returns: True


        - pattern: "/key/{token_id}/regenerate"
        - route: "/key/regenerate/82akk800000000jjsk"
        - returns: False, pattern is "/key/{token_id}/regenerate"
        """
        # each "{placeholder}" becomes a single path-segment wildcard
        pattern = re.sub(r"\{[^}]+\}", r"[^/]+", pattern)
        # Anchor the pattern to match the entire string
        pattern = f"^{pattern}$"
        if re.match(pattern, route):
            return True
        return False

    @staticmethod
    def check_route_access(route: str, allowed_routes: List[str]) -> bool:
        """
        Check if a route has access by checking both exact matches and patterns

        Args:
            route (str): The route to check
            allowed_routes (list): List of allowed routes/patterns

        Returns:
            bool: True if route is allowed, False otherwise
        """
        return route in allowed_routes or any(  # Check exact match
            RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
            for allowed_route in allowed_routes
        )  # Check pattern match

    @staticmethod
    def _is_assistants_api_request(request: Request) -> bool:
        """
        Returns True if `thread` or `assistant` is in the request path

        Args:
            request (Request): The request object

        Returns:
            bool: True if `thread` or `assistant` is in the request path, False otherwise
        """
        if "thread" in request.url.path or "assistant" in request.url.path:
            return True
        return False
def check_if_token_is_service_account(valid_token: UserAPIKeyAuth) -> bool:
    """
    Checks if the token is a service account

    A token is a service account when its metadata carries a
    "service_account_id" entry.

    Returns:
        bool: True if token is a service account
    """
    metadata = valid_token.metadata
    return bool(metadata) and "service_account_id" in metadata


async def service_account_checks(
    valid_token: UserAPIKeyAuth, request_data: dict
) -> bool:
    """
    If a virtual key is a service account, checks it's a valid service account

    A token is a service account if it has a service_account_id in its metadata

    Service Account Specific Checks:
    - Check if required_params is set
    """
    if not check_if_token_is_service_account(valid_token):
        # regular virtual key - nothing to enforce
        return True

    from litellm.proxy.proxy_server import general_settings

    if "service_account_settings" in general_settings:
        sa_settings = general_settings["service_account_settings"]
        if "enforced_params" in sa_settings:
            # every enforced param must appear in the request body
            for required_param in sa_settings["enforced_params"]:
                if required_param not in request_data:
                    raise ProxyException(
                        type=ProxyErrorTypes.bad_request_error.value,
                        code=400,
                        param=required_param,
                        message=f"BadRequest please pass param={required_param} in request body. This is a required param for service account",
                    )

    return True
litellm.proxy.auth.service_account_checks import service_account_checks +from litellm.proxy.common_utils.http_parsing_utils import _read_request_body +from litellm.proxy.utils import PrismaClient, ProxyLogging +from litellm.types.services import ServiceTypes + +user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL + + +api_key_header = APIKeyHeader( + name=SpecialHeaders.openai_authorization.value, + auto_error=False, + description="Bearer token", +) +azure_api_key_header = APIKeyHeader( + name=SpecialHeaders.azure_authorization.value, + auto_error=False, + description="Some older versions of the openai Python package will send an API-Key header with just the API key ", +) +anthropic_api_key_header = APIKeyHeader( + name=SpecialHeaders.anthropic_authorization.value, + auto_error=False, + description="If anthropic client used.", +) +google_ai_studio_api_key_header = APIKeyHeader( + name=SpecialHeaders.google_ai_studio_authorization.value, + auto_error=False, + description="If google ai studio client used.", +) +azure_apim_header = APIKeyHeader( + name=SpecialHeaders.azure_apim_authorization.value, + auto_error=False, + description="The default name of the subscription key header of Azure", +) + + +def _get_bearer_token( + api_key: str, +): + if api_key.startswith("Bearer "): # ensure Bearer token passed in + api_key = api_key.replace("Bearer ", "") # extract the token + elif api_key.startswith("Basic "): + api_key = api_key.replace("Basic ", "") # handle langfuse input + elif api_key.startswith("bearer "): + api_key = api_key.replace("bearer ", "") + else: + api_key = "" + return api_key + + +def _is_ui_route( + route: str, + user_obj: Optional[LiteLLM_UserTable] = None, +) -> bool: + """ + - Check if the route is a UI used route + """ + # this token is only used for managing the ui + allowed_routes = LiteLLMRoutes.ui_routes.value + # check if the current route startswith any of the allowed routes + if ( + route is not None + and 
isinstance(route, str) + and any(route.startswith(allowed_route) for allowed_route in allowed_routes) + ): + # Do something if the current route starts with any of the allowed routes + return True + elif any( + RouteChecks._route_matches_pattern(route=route, pattern=allowed_route) + for allowed_route in allowed_routes + ): + return True + return False + + +def _is_api_route_allowed( + route: str, + request: Request, + request_data: dict, + api_key: str, + valid_token: Optional[UserAPIKeyAuth], + user_obj: Optional[LiteLLM_UserTable] = None, +) -> bool: + """ + - Route b/w api token check and normal token check + """ + _user_role = _get_user_role(user_obj=user_obj) + + if valid_token is None: + raise Exception("Invalid proxy server token passed. valid_token=None.") + + if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin + RouteChecks.non_proxy_admin_allowed_routes_check( + user_obj=user_obj, + _user_role=_user_role, + route=route, + request=request, + request_data=request_data, + api_key=api_key, + valid_token=valid_token, + ) + return True + + +def _is_allowed_route( + route: str, + token_type: Literal["ui", "api"], + request: Request, + request_data: dict, + api_key: str, + valid_token: Optional[UserAPIKeyAuth], + user_obj: Optional[LiteLLM_UserTable] = None, +) -> bool: + """ + - Route b/w ui token check and normal token check + """ + + if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj): + return True + else: + return _is_api_route_allowed( + route=route, + request=request, + request_data=request_data, + api_key=api_key, + valid_token=valid_token, + user_obj=user_obj, + ) + + +async def user_api_key_auth_websocket(websocket: WebSocket): + # Accept the WebSocket connection + + request = Request(scope={"type": "http"}) + request._url = websocket.url + + query_params = websocket.query_params + + model = query_params.get("model") + + async def return_body(): + return_string = f'{{"model": "{model}"}}' + # return string as bytes + return 
return_string.encode() + + request.body = return_body # type: ignore + + # Extract the Authorization header + authorization = websocket.headers.get("authorization") + + # If no Authorization header, try the api-key header + if not authorization: + api_key = websocket.headers.get("api-key") + if not api_key: + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + raise HTTPException(status_code=403, detail="No API key provided") + else: + # Extract the API key from the Bearer token + if not authorization.startswith("Bearer "): + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + raise HTTPException( + status_code=403, detail="Invalid Authorization header format" + ) + + api_key = authorization[len("Bearer ") :].strip() + + # Call user_api_key_auth with the extracted API key + # Note: You'll need to modify this to work with WebSocket context if needed + try: + return await user_api_key_auth(request=request, api_key=f"Bearer {api_key}") + except Exception as e: + verbose_proxy_logger.exception(e) + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + raise HTTPException(status_code=403, detail=str(e)) + + +def update_valid_token_with_end_user_params( + valid_token: UserAPIKeyAuth, end_user_params: dict +) -> UserAPIKeyAuth: + valid_token.end_user_id = end_user_params.get("end_user_id") + valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit") + valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit") + valid_token.allowed_model_region = end_user_params.get("allowed_model_region") + return valid_token + + +async def get_global_proxy_spend( + litellm_proxy_admin_name: str, + user_api_key_cache: DualCache, + prisma_client: Optional[PrismaClient], + token: str, + proxy_logging_obj: ProxyLogging, +) -> Optional[float]: + global_proxy_spend = None + if litellm.max_budget > 0: # user set proxy max budget + # check cache + global_proxy_spend = await user_api_key_cache.async_get_cache( + 
key="{}:spend".format(litellm_proxy_admin_name) + ) + if global_proxy_spend is None and prisma_client is not None: + # get from db + sql_query = ( + """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";""" + ) + + response = await prisma_client.db.query_raw(query=sql_query) + + global_proxy_spend = response[0]["total_spend"] + + await user_api_key_cache.async_set_cache( + key="{}:spend".format(litellm_proxy_admin_name), + value=global_proxy_spend, + ) + if global_proxy_spend is not None: + user_info = CallInfo( + user_id=litellm_proxy_admin_name, + max_budget=litellm.max_budget, + spend=global_proxy_spend, + token=token, + ) + asyncio.create_task( + proxy_logging_obj.budget_alerts( + type="proxy_budget", + user_info=user_info, + ) + ) + return global_proxy_spend + + +def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str: + is_admin = jwt_handler.is_admin(scopes=scopes) + if is_admin: + return LitellmUserRoles.PROXY_ADMIN + else: + return LitellmUserRoles.TEAM + + +def get_model_from_request(request_data: dict, route: str) -> Optional[str]: + + # First try to get model from request_data + model = request_data.get("model") + + # If model not in request_data, try to extract from route + if model is None: + # Parse model from route that follows the pattern /openai/deployments/{model}/* + match = re.match(r"/openai/deployments/([^/]+)", route) + if match: + model = match.group(1) + + return model + + +async def _user_api_key_auth_builder( # noqa: PLR0915 + request: Request, + api_key: str, + azure_api_key_header: str, + anthropic_api_key_header: Optional[str], + google_ai_studio_api_key_header: Optional[str], + azure_apim_header: Optional[str], + request_data: dict, +) -> UserAPIKeyAuth: + + from litellm.proxy.proxy_server import ( + general_settings, + jwt_handler, + litellm_proxy_admin_name, + llm_model_list, + llm_router, + master_key, + model_max_budget_limiter, + open_telemetry_logger, + prisma_client, + proxy_logging_obj, + 
user_api_key_cache, + user_custom_auth, + ) + + parent_otel_span: Optional[Span] = None + start_time = datetime.now() + route: str = get_request_route(request=request) + try: + + # get the request body + + await pre_db_read_auth_checks( + request_data=request_data, + request=request, + route=route, + ) + pass_through_endpoints: Optional[List[dict]] = general_settings.get( + "pass_through_endpoints", None + ) + passed_in_key: Optional[str] = None + if isinstance(api_key, str): + passed_in_key = api_key + api_key = _get_bearer_token(api_key=api_key) + elif isinstance(azure_api_key_header, str): + api_key = azure_api_key_header + elif isinstance(anthropic_api_key_header, str): + api_key = anthropic_api_key_header + elif isinstance(google_ai_studio_api_key_header, str): + api_key = google_ai_studio_api_key_header + elif isinstance(azure_apim_header, str): + api_key = azure_apim_header + elif pass_through_endpoints is not None: + for endpoint in pass_through_endpoints: + if endpoint.get("path", "") == route: + headers: Optional[dict] = endpoint.get("headers", None) + if headers is not None: + header_key: str = headers.get("litellm_user_api_key", "") + if request.headers.get(key=header_key) is not None: + api_key = request.headers.get(key=header_key) + + # if user wants to pass LiteLLM_Master_Key as a custom header, example pass litellm keys as X-LiteLLM-Key: Bearer sk-1234 + custom_litellm_key_header_name = general_settings.get("litellm_key_header_name") + if custom_litellm_key_header_name is not None: + api_key = get_api_key_from_custom_header( + request=request, + custom_litellm_key_header_name=custom_litellm_key_header_name, + ) + + if open_telemetry_logger is not None: + parent_otel_span = ( + open_telemetry_logger.create_litellm_proxy_request_started_span( + start_time=start_time, + headers=dict(request.headers), + ) + ) + + ### USER-DEFINED AUTH FUNCTION ### + if user_custom_auth is not None: + response = await user_custom_auth(request=request, api_key=api_key) # 
type: ignore + return UserAPIKeyAuth.model_validate(response) + + ### LITELLM-DEFINED AUTH FUNCTION ### + #### IF JWT #### + """ + LiteLLM supports using JWTs. + + Enable this in proxy config, by setting + ``` + general_settings: + enable_jwt_auth: true + ``` + """ + + ######## Route Checks Before Reading DB / Cache for "token" ################ + if ( + route in LiteLLMRoutes.public_routes.value # type: ignore + or route_in_additonal_public_routes(current_route=route) + ): + # check if public endpoint + return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY) + elif is_pass_through_provider_route(route=route): + if should_run_auth_on_pass_through_provider_route(route=route) is False: + return UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY + ) + + ########## End of Route Checks Before Reading DB / Cache for "token" ######## + + if general_settings.get("enable_oauth2_auth", False) is True: + # return UserAPIKeyAuth object + # helper to check if the api_key is a valid oauth2 token + from litellm.proxy.proxy_server import premium_user + + if premium_user is not True: + raise ValueError( + "Oauth2 token validation is only available for premium users" + + CommonProxyErrors.not_premium_user.value + ) + + return await check_oauth2_token(token=api_key) + + if general_settings.get("enable_oauth2_proxy_auth", False) is True: + return await handle_oauth2_proxy_request(request=request) + + if general_settings.get("enable_jwt_auth", False) is True: + from litellm.proxy.proxy_server import premium_user + + if premium_user is not True: + raise ValueError( + f"JWT Auth is an enterprise only feature. 
{CommonProxyErrors.not_premium_user.value}" + ) + is_jwt = jwt_handler.is_jwt(token=api_key) + verbose_proxy_logger.debug("is_jwt: %s", is_jwt) + if is_jwt: + result = await JWTAuthManager.auth_builder( + request_data=request_data, + general_settings=general_settings, + api_key=api_key, + jwt_handler=jwt_handler, + route=route, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + parent_otel_span=parent_otel_span, + ) + + is_proxy_admin = result["is_proxy_admin"] + team_id = result["team_id"] + team_object = result["team_object"] + user_id = result["user_id"] + user_object = result["user_object"] + end_user_id = result["end_user_id"] + end_user_object = result["end_user_object"] + org_id = result["org_id"] + token = result["token"] + + global_proxy_spend = await get_global_proxy_spend( + litellm_proxy_admin_name=litellm_proxy_admin_name, + user_api_key_cache=user_api_key_cache, + prisma_client=prisma_client, + token=token, + proxy_logging_obj=proxy_logging_obj, + ) + + if is_proxy_admin: + return UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, + ) + # run through common checks + _ = await common_checks( + request_body=request_data, + team_object=team_object, + user_object=user_object, + end_user_object=end_user_object, + general_settings=general_settings, + global_proxy_spend=global_proxy_spend, + route=route, + llm_router=llm_router, + proxy_logging_obj=proxy_logging_obj, + valid_token=None, + ) + + # return UserAPIKeyAuth object + return UserAPIKeyAuth( + api_key=None, + team_id=team_id, + team_tpm_limit=( + team_object.tpm_limit if team_object is not None else None + ), + team_rpm_limit=( + team_object.rpm_limit if team_object is not None else None + ), + team_models=team_object.models if team_object is not None else [], + user_role=LitellmUserRoles.INTERNAL_USER, + user_id=user_id, + org_id=org_id, + parent_otel_span=parent_otel_span, + 
end_user_id=end_user_id, + ) + + #### ELSE #### + ## CHECK PASS-THROUGH ENDPOINTS ## + is_mapped_pass_through_route: bool = False + for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: # type: ignore + if route.startswith(mapped_route): + is_mapped_pass_through_route = True + if is_mapped_pass_through_route: + if request.headers.get("litellm_user_api_key") is not None: + api_key = request.headers.get("litellm_user_api_key") or "" + if pass_through_endpoints is not None: + for endpoint in pass_through_endpoints: + if isinstance(endpoint, dict) and endpoint.get("path", "") == route: + ## IF AUTH DISABLED + if endpoint.get("auth") is not True: + return UserAPIKeyAuth() + ## IF AUTH ENABLED + ### IF CUSTOM PARSER REQUIRED + if ( + endpoint.get("custom_auth_parser") is not None + and endpoint.get("custom_auth_parser") == "langfuse" + ): + """ + - langfuse returns {'Authorization': 'Basic YW55dGhpbmc6YW55dGhpbmc'} + - check the langfuse public key if it contains the litellm api key + """ + import base64 + + api_key = api_key.replace("Basic ", "").strip() + decoded_bytes = base64.b64decode(api_key) + decoded_str = decoded_bytes.decode("utf-8") + api_key = decoded_str.split(":")[0] + else: + headers = endpoint.get("headers", None) + if headers is not None: + header_key = headers.get("litellm_user_api_key", "") + if ( + isinstance(request.headers, dict) + and request.headers.get(key=header_key) is not None # type: ignore + ): + api_key = request.headers.get(key=header_key) # type: ignore + if master_key is None: + if isinstance(api_key, str): + return UserAPIKeyAuth( + api_key=api_key, + user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, + ) + else: + return UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, + ) + elif api_key is None: # only require api key if master key is set + raise Exception("No api key passed in.") + elif api_key == "": + # missing 'Bearer ' prefix + raise Exception( 
+ f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}" + ) + + if route == "/user/auth": + if general_settings.get("allow_user_auth", False) is True: + return UserAPIKeyAuth() + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="'allow_user_auth' not set or set to False", + ) + + ## Check END-USER OBJECT + _end_user_object = None + end_user_params = {} + + end_user_id = get_end_user_id_from_request_body(request_data) + if end_user_id: + try: + end_user_params["end_user_id"] = end_user_id + + # get end-user object + _end_user_object = await get_end_user_object( + end_user_id=end_user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + if _end_user_object is not None: + end_user_params["allowed_model_region"] = ( + _end_user_object.allowed_model_region + ) + if _end_user_object.litellm_budget_table is not None: + budget_info = _end_user_object.litellm_budget_table + if budget_info.tpm_limit is not None: + end_user_params["end_user_tpm_limit"] = ( + budget_info.tpm_limit + ) + if budget_info.rpm_limit is not None: + end_user_params["end_user_rpm_limit"] = ( + budget_info.rpm_limit + ) + if budget_info.max_budget is not None: + end_user_params["end_user_max_budget"] = ( + budget_info.max_budget + ) + except Exception as e: + if isinstance(e, litellm.BudgetExceededError): + raise e + verbose_proxy_logger.debug( + "Unable to find user in db. Error - {}".format(str(e)) + ) + pass + + ### CHECK IF ADMIN ### + # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead + ### CHECK IF ADMIN ### + # note: never string compare api keys, this is vulenerable to a time attack. 
Use secrets.compare_digest instead + ## Check CACHE + try: + valid_token = await get_key_object( + hashed_token=hash_token(api_key), + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + check_cache_only=True, + ) + except Exception: + verbose_logger.debug("api key not found in cache.") + valid_token = None + + if ( + valid_token is not None + and isinstance(valid_token, UserAPIKeyAuth) + and valid_token.user_role == LitellmUserRoles.PROXY_ADMIN + ): + # update end-user params on valid token + valid_token = update_valid_token_with_end_user_params( + valid_token=valid_token, end_user_params=end_user_params + ) + valid_token.parent_otel_span = parent_otel_span + + return valid_token + + if ( + valid_token is not None + and isinstance(valid_token, UserAPIKeyAuth) + and valid_token.team_id is not None + ): + ## UPDATE TEAM VALUES BASED ON CACHED TEAM OBJECT - allows `/team/update` values to work for cached token + try: + team_obj: LiteLLM_TeamTableCachedObj = await get_team_object( + team_id=valid_token.team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + check_cache_only=True, + ) + + if ( + team_obj.last_refreshed_at is not None + and valid_token.last_refreshed_at is not None + and team_obj.last_refreshed_at > valid_token.last_refreshed_at + ): + team_obj_dict = team_obj.__dict__ + + for k, v in team_obj_dict.items(): + field_name = f"team_{k}" + if field_name in valid_token.__fields__: + setattr(valid_token, field_name, v) + except Exception as e: + verbose_logger.debug( + e + ) # moving from .warning to .debug as it spams logs when team missing from cache. 
+ + try: + is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore + except Exception: + is_master_key_valid = False + + ## VALIDATE MASTER KEY ## + try: + assert isinstance(master_key, str) + except Exception: + raise HTTPException( + status_code=500, + detail={ + "Master key must be a valid string. Current type={}".format( + type(master_key) + ) + }, + ) + + if is_master_key_valid: + _user_api_key_obj = await _return_user_api_key_auth_obj( + user_obj=None, + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key=master_key, + parent_otel_span=parent_otel_span, + valid_token_dict={ + **end_user_params, + "user_id": litellm_proxy_admin_name, + }, + route=route, + start_time=start_time, + ) + asyncio.create_task( + _cache_key_object( + hashed_token=hash_token(master_key), + user_api_key_obj=_user_api_key_obj, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + ) + + _user_api_key_obj = update_valid_token_with_end_user_params( + valid_token=_user_api_key_obj, end_user_params=end_user_params + ) + + return _user_api_key_obj + + ## IF it's not a master key + ## Route should not be in master_key_only_routes + if route in LiteLLMRoutes.master_key_only_routes.value: # type: ignore + raise Exception( + f"Tried to access route={route}, which is only for MASTER KEY" + ) + + ## Check DB + if isinstance( + api_key, str + ): # if generated token, make sure it starts with sk-. + assert api_key.startswith( + "sk-" + ), "LiteLLM Virtual Key expected. 
Received={}, expected to start with 'sk-'.".format( + api_key + ) # prevent token hashes from being used + else: + verbose_logger.warning( + "litellm.proxy.proxy_server.user_api_key_auth(): Warning - Key={} is not a string.".format( + api_key + ) + ) + + if ( + prisma_client is None + ): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error + return await _handle_failed_db_connection_for_get_key_object( + e=Exception("No connected db.") + ) + + ## check for cache hit (In-Memory Cache) + _user_role = None + if api_key.startswith("sk-"): + api_key = hash_token(token=api_key) + + if valid_token is None: + try: + valid_token = await get_key_object( + hashed_token=api_key, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + # update end-user params on valid token + # These can change per request - it's important to update them here + valid_token.end_user_id = end_user_params.get("end_user_id") + valid_token.end_user_tpm_limit = end_user_params.get( + "end_user_tpm_limit" + ) + valid_token.end_user_rpm_limit = end_user_params.get( + "end_user_rpm_limit" + ) + valid_token.allowed_model_region = end_user_params.get( + "allowed_model_region" + ) + # update key budget with temp budget increase + valid_token = _update_key_budget_with_temp_budget_increase( + valid_token + ) # updating it here, allows all downstream reporting / checks to use the updated budget + except Exception: + verbose_logger.info( + "litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to find token={} in cache or `LiteLLM_VerificationTokenTable`. Defaulting 'valid_token' to None'".format( + api_key + ) + ) + valid_token = None + + if valid_token is None: + raise Exception( + "Invalid proxy server token passed. Received API Key (hashed)={}. 
Unable to find token in cache or `LiteLLM_VerificationTokenTable`".format( + api_key + ) + ) + + user_obj: Optional[LiteLLM_UserTable] = None + valid_token_dict: dict = {} + if valid_token is not None: + # Got Valid Token from Cache, DB + # Run checks for + # 1. If token can call model + ## 1a. If token can call fallback models (if client-side fallbacks given) + # 2. If user_id for this token is in budget + # 3. If the user spend within their own team is within budget + # 4. If 'user' passed to /chat/completions, /embeddings endpoint is in budget + # 5. If token is expired + # 6. If token spend is under Budget for the token + # 7. If token spend per model is under budget per model + # 8. If token spend is under team budget + # 9. If team spend is under team budget + + ## base case ## key is disabled + if valid_token.blocked is True: + raise Exception( + "Key is blocked. Update via `/key/unblock` if you're admin." + ) + config = valid_token.config + + if config != {}: + model_list = config.get("model_list", []) + new_model_list = model_list + verbose_proxy_logger.debug( + f"\n new llm router model list {new_model_list}" + ) + elif ( + isinstance(valid_token.models, list) + and "all-team-models" in valid_token.models + ): + # Do not do any validation at this step + # the validation will occur when checking the team has access to this model + pass + else: + model = get_model_from_request(request_data, route) + fallback_models = cast( + Optional[List[ALL_FALLBACK_MODEL_VALUES]], + request_data.get("fallbacks", None), + ) + + if model is not None: + await can_key_call_model( + model=model, + llm_model_list=llm_model_list, + valid_token=valid_token, + llm_router=llm_router, + ) + + if fallback_models is not None: + for m in fallback_models: + await can_key_call_model( + model=m["model"] if isinstance(m, dict) else m, + llm_model_list=llm_model_list, + valid_token=valid_token, + llm_router=llm_router, + ) + await is_valid_fallback_model( + model=m["model"] if 
isinstance(m, dict) else m, + llm_router=llm_router, + user_model=None, + ) + + # Check 2. If user_id for this token is in budget - done in common_checks() + if valid_token.user_id is not None: + try: + user_obj = await get_user_object( + user_id=valid_token.user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + user_id_upsert=False, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + except Exception as e: + verbose_logger.debug( + "litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to get user from db/cache. Setting user_obj to None. Exception received - {}".format( + str(e) + ) + ) + user_obj = None + + # Check 3. Check if user is in their team budget + if valid_token.team_member_spend is not None: + if prisma_client is not None: + + _cache_key = f"{valid_token.team_id}_{valid_token.user_id}" + + team_member_info = await user_api_key_cache.async_get_cache( + key=_cache_key + ) + if team_member_info is None: + # read from DB + _user_id = valid_token.user_id + _team_id = valid_token.team_id + + if _user_id is not None and _team_id is not None: + team_member_info = await prisma_client.db.litellm_teammembership.find_first( + where={ + "user_id": _user_id, + "team_id": _team_id, + }, # type: ignore + include={"litellm_budget_table": True}, + ) + await user_api_key_cache.async_set_cache( + key=_cache_key, + value=team_member_info, + ) + + if ( + team_member_info is not None + and team_member_info.litellm_budget_table is not None + ): + team_member_budget = ( + team_member_info.litellm_budget_table.max_budget + ) + if team_member_budget is not None and team_member_budget > 0: + if valid_token.team_member_spend > team_member_budget: + raise litellm.BudgetExceededError( + current_cost=valid_token.team_member_spend, + max_budget=team_member_budget, + ) + + # Check 3. 
If token is expired + if valid_token.expires is not None: + current_time = datetime.now(timezone.utc) + if isinstance(valid_token.expires, datetime): + expiry_time = valid_token.expires + else: + expiry_time = datetime.fromisoformat(valid_token.expires) + if ( + expiry_time.tzinfo is None + or expiry_time.tzinfo.utcoffset(expiry_time) is None + ): + expiry_time = expiry_time.replace(tzinfo=timezone.utc) + verbose_proxy_logger.debug( + f"Checking if token expired, expiry time {expiry_time} and current time {current_time}" + ) + if expiry_time < current_time: + # Token exists but is expired. + raise ProxyException( + message=f"Authentication Error - Expired Key. Key Expiry time {expiry_time} and current time {current_time}", + type=ProxyErrorTypes.expired_key, + code=400, + param=api_key, + ) + + # Check 4. Token Spend is under budget + await _virtual_key_max_budget_check( + valid_token=valid_token, + proxy_logging_obj=proxy_logging_obj, + user_obj=user_obj, + ) + + # Check 5. Soft Budget Check + await _virtual_key_soft_budget_check( + valid_token=valid_token, + proxy_logging_obj=proxy_logging_obj, + ) + + # Check 5. 
Token Model Spend is under Model budget + max_budget_per_model = valid_token.model_max_budget + current_model = request_data.get("model", None) + + if ( + max_budget_per_model is not None + and isinstance(max_budget_per_model, dict) + and len(max_budget_per_model) > 0 + and prisma_client is not None + and current_model is not None + and valid_token.token is not None + ): + ## GET THE SPEND FOR THIS MODEL + await model_max_budget_limiter.is_key_within_model_budget( + user_api_key_dict=valid_token, + model=current_model, + ) + + # Check 6: Additional Common Checks across jwt + key auth + if valid_token.team_id is not None: + _team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable( + team_id=valid_token.team_id, + max_budget=valid_token.team_max_budget, + spend=valid_token.team_spend, + tpm_limit=valid_token.team_tpm_limit, + rpm_limit=valid_token.team_rpm_limit, + blocked=valid_token.team_blocked, + models=valid_token.team_models, + metadata=valid_token.team_metadata, + ) + else: + _team_obj = None + + # Check 7: Check if key is a service account key + await service_account_checks( + valid_token=valid_token, + request_data=request_data, + ) + + user_api_key_cache.set_cache( + key=valid_token.team_id, value=_team_obj + ) # save team table in cache - used for tpm/rpm limiting - tpm_rpm_limiter.py + + global_proxy_spend = None + if ( + litellm.max_budget > 0 and prisma_client is not None + ): # user set proxy max budget + # check cache + global_proxy_spend = await user_api_key_cache.async_get_cache( + key="{}:spend".format(litellm_proxy_admin_name) + ) + if global_proxy_spend is None: + # get from db + sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";""" + + response = await prisma_client.db.query_raw(query=sql_query) + + global_proxy_spend = response[0]["total_spend"] + await user_api_key_cache.async_set_cache( + key="{}:spend".format(litellm_proxy_admin_name), + value=global_proxy_spend, + ) + + if global_proxy_spend is not None: + 
call_info = CallInfo( + token=valid_token.token, + spend=global_proxy_spend, + max_budget=litellm.max_budget, + user_id=litellm_proxy_admin_name, + team_id=valid_token.team_id, + ) + asyncio.create_task( + proxy_logging_obj.budget_alerts( + type="proxy_budget", + user_info=call_info, + ) + ) + _ = await common_checks( + request_body=request_data, + team_object=_team_obj, + user_object=user_obj, + end_user_object=_end_user_object, + general_settings=general_settings, + global_proxy_spend=global_proxy_spend, + route=route, + llm_router=llm_router, + proxy_logging_obj=proxy_logging_obj, + valid_token=valid_token, + ) + # Token passed all checks + if valid_token is None: + raise HTTPException(401, detail="Invalid API key") + if valid_token.token is None: + raise HTTPException(401, detail="Invalid API key, no token associated") + api_key = valid_token.token + + # Add hashed token to cache + asyncio.create_task( + _cache_key_object( + hashed_token=api_key, + user_api_key_obj=valid_token, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + ) + + valid_token_dict = valid_token.model_dump(exclude_none=True) + valid_token_dict.pop("token", None) + + if _end_user_object is not None: + valid_token_dict.update(end_user_params) + + # check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. 
These keys can only be used for LiteLLM UI functions + # sso/login, ui/login, /key functions and /user functions + # this will never be allowed to call /chat/completions + token_team = getattr(valid_token, "team_id", None) + token_type: Literal["ui", "api"] = ( + "ui" + if token_team is not None and token_team == "litellm-dashboard" + else "api" + ) + _is_route_allowed = _is_allowed_route( + route=route, + token_type=token_type, + user_obj=user_obj, + request=request, + request_data=request_data, + api_key=api_key, + valid_token=valid_token, + ) + if not _is_route_allowed: + raise HTTPException(401, detail="Invalid route for UI token") + + if valid_token is None: + # No token was found when looking up in the DB + raise Exception("Invalid proxy server token passed") + if valid_token_dict is not None: + return await _return_user_api_key_auth_obj( + user_obj=user_obj, + api_key=api_key, + parent_otel_span=parent_otel_span, + valid_token_dict=valid_token_dict, + route=route, + start_time=start_time, + ) + else: + raise Exception() + except Exception as e: + requester_ip = _get_request_ip_address( + request=request, + use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False), + ) + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format( + str(e), + requester_ip, + ), + extra={"requester_ip": requester_ip}, + ) + + # Log this exception to OTEL, Datadog etc + user_api_key_dict = UserAPIKeyAuth( + parent_otel_span=parent_otel_span, + api_key=api_key, + ) + asyncio.create_task( + proxy_logging_obj.post_call_failure_hook( + request_data=request_data, + original_exception=e, + user_api_key_dict=user_api_key_dict, + error_type=ProxyErrorTypes.auth_error, + route=route, + ) + ) + + if isinstance(e, litellm.BudgetExceededError): + raise ProxyException( + message=e.message, + type=ProxyErrorTypes.budget_exceeded, + param=None, + code=400, + ) + if isinstance(e, HTTPException): + raise 
ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + code=status.HTTP_401_UNAUTHORIZED, + ) + + +@tracer.wrap() +async def user_api_key_auth( + request: Request, + api_key: str = fastapi.Security(api_key_header), + azure_api_key_header: str = fastapi.Security(azure_api_key_header), + anthropic_api_key_header: Optional[str] = fastapi.Security( + anthropic_api_key_header + ), + google_ai_studio_api_key_header: Optional[str] = fastapi.Security( + google_ai_studio_api_key_header + ), + azure_apim_header: Optional[str] = fastapi.Security(azure_apim_header), +) -> UserAPIKeyAuth: + """ + Parent function to authenticate user api key / jwt token. 
+ """ + + request_data = await _read_request_body(request=request) + + user_api_key_auth_obj = await _user_api_key_auth_builder( + request=request, + api_key=api_key, + azure_api_key_header=azure_api_key_header, + anthropic_api_key_header=anthropic_api_key_header, + google_ai_studio_api_key_header=google_ai_studio_api_key_header, + azure_apim_header=azure_apim_header, + request_data=request_data, + ) + + end_user_id = get_end_user_id_from_request_body(request_data) + if end_user_id is not None: + user_api_key_auth_obj.end_user_id = end_user_id + + return user_api_key_auth_obj + + +async def _return_user_api_key_auth_obj( + user_obj: Optional[LiteLLM_UserTable], + api_key: str, + parent_otel_span: Optional[Span], + valid_token_dict: dict, + route: str, + start_time: datetime, + user_role: Optional[LitellmUserRoles] = None, +) -> UserAPIKeyAuth: + end_time = datetime.now() + + asyncio.create_task( + user_api_key_service_logger_obj.async_service_success_hook( + service=ServiceTypes.AUTH, + call_type=route, + start_time=start_time, + end_time=end_time, + duration=end_time.timestamp() - start_time.timestamp(), + parent_otel_span=parent_otel_span, + ) + ) + + retrieved_user_role = ( + user_role or _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER + ) + + user_api_key_kwargs = { + "api_key": api_key, + "parent_otel_span": parent_otel_span, + "user_role": retrieved_user_role, + **valid_token_dict, + } + if user_obj is not None: + user_api_key_kwargs.update( + user_tpm_limit=user_obj.tpm_limit, + user_rpm_limit=user_obj.rpm_limit, + user_email=user_obj.user_email, + ) + if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj): + user_api_key_kwargs.update( + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + return UserAPIKeyAuth(**user_api_key_kwargs) + else: + return UserAPIKeyAuth(**user_api_key_kwargs) + + +def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]): + if user_obj is None: + return False + + if ( + user_obj.user_role is not 
None + and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value + ): + return True + + if ( + user_obj.user_role is not None + and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value + ): + return True + + return False + + +def _get_user_role( + user_obj: Optional[LiteLLM_UserTable], +) -> Optional[LitellmUserRoles]: + if user_obj is None: + return None + + _user = user_obj + + _user_role = _user.user_role + try: + role = LitellmUserRoles(_user_role) + except ValueError: + return LitellmUserRoles.INTERNAL_USER + + return role + + +def get_api_key_from_custom_header( + request: Request, custom_litellm_key_header_name: str +) -> str: + """ + Get API key from custom header + + Args: + request (Request): Request object + custom_litellm_key_header_name (str): Custom header name + + Returns: + Optional[str]: API key + """ + api_key: str = "" + # use this as the virtual key passed to litellm proxy + custom_litellm_key_header_name = custom_litellm_key_header_name.lower() + _headers = {k.lower(): v for k, v in request.headers.items()} + verbose_proxy_logger.debug( + "searching for custom_litellm_key_header_name= %s, in headers=%s", + custom_litellm_key_header_name, + _headers, + ) + custom_api_key = _headers.get(custom_litellm_key_header_name) + if custom_api_key: + api_key = _get_bearer_token(api_key=custom_api_key) + verbose_proxy_logger.debug( + "Found custom API key using header: {}, setting api_key={}".format( + custom_litellm_key_header_name, api_key + ) + ) + else: + verbose_proxy_logger.exception( + f"No LiteLLM Virtual Key pass. 
Please set header={custom_litellm_key_header_name}: Bearer <api_key>" + ) + return api_key + + +def _get_temp_budget_increase(valid_token: UserAPIKeyAuth): + valid_token_metadata = valid_token.metadata + if ( + "temp_budget_increase" in valid_token_metadata + and "temp_budget_expiry" in valid_token_metadata + ): + expiry = datetime.fromisoformat(valid_token_metadata["temp_budget_expiry"]) + if expiry > datetime.now(): + return valid_token_metadata["temp_budget_increase"] + return None + + +def _update_key_budget_with_temp_budget_increase( + valid_token: UserAPIKeyAuth, +) -> UserAPIKeyAuth: + if valid_token.max_budget is None: + return valid_token + temp_budget_increase = _get_temp_budget_increase(valid_token) or 0.0 + valid_token.max_budget = valid_token.max_budget + temp_budget_increase + return valid_token |