Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/auth')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py               1373
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks_organization.py   161
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py                 514
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/handle_jwt.py                1001
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/litellm_license.py            169
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py               197
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_check.py                80
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_proxy_hook.py           45
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/public_key.pem                  9
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/rds_iam_token.py              187
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/route_checks.py               257
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/service_account_checks.py      53
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/auth/user_api_key_auth.py          1337
13 files changed, 5383 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py
new file mode 100644
index 00000000..f029511d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py
@@ -0,0 +1,1373 @@
+# What is this?
+## Common auth checks between jwt + key based auth
+"""
+Got Valid Token from Cache, DB
+Run checks for:
+
+1. If user can call model
+2. If user is in budget
+3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
+"""
+import asyncio
+import re
+import time
+import traceback
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast
+
+from fastapi import status
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.caching.dual_cache import LimitedSizeOrderedDict
+from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
+from litellm.proxy._types import (
+ DB_CONNECTION_ERROR_TYPES,
+ RBAC_ROLES,
+ CallInfo,
+ LiteLLM_EndUserTable,
+ LiteLLM_JWTAuth,
+ LiteLLM_OrganizationMembershipTable,
+ LiteLLM_OrganizationTable,
+ LiteLLM_TeamTable,
+ LiteLLM_TeamTableCachedObj,
+ LiteLLM_UserTable,
+ LiteLLMRoutes,
+ LitellmUserRoles,
+ ProxyErrorTypes,
+ ProxyException,
+ RoleBasedPermissions,
+ SpecialModelNames,
+ UserAPIKeyAuth,
+)
+from litellm.proxy.auth.route_checks import RouteChecks
+from litellm.proxy.route_llm_request import route_request
+from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
+from litellm.router import Router
+from litellm.types.services import ServiceTypes
+
+from .auth_checks_organization import organization_role_based_access_check
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ Span = _Span
+else:
+ Span = Any
+
+
+last_db_access_time = LimitedSizeOrderedDict(max_size=100)
+db_cache_expiry = 5 # refresh every 5s
+
+all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
+
+
+async def common_checks(
+ request_body: dict,
+ team_object: Optional[LiteLLM_TeamTable],
+ user_object: Optional[LiteLLM_UserTable],
+ end_user_object: Optional[LiteLLM_EndUserTable],
+ global_proxy_spend: Optional[float],
+ general_settings: dict,
+ route: str,
+ llm_router: Optional[Router],
+ proxy_logging_obj: ProxyLogging,
+ valid_token: Optional[UserAPIKeyAuth],
+) -> bool:
+ """
+ Common checks across jwt + key-based auth.
+
+ 1. If team is blocked
+ 2. If team can call model
+ 3. If team is in budget
+ 4. If user passed in (JWT or key.user_id) - is in budget
+ 5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
+    6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
+ 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
+ 8. [OPTIONAL] If guardrails modified - is request allowed to change this
+ 9. Check if request body is safe
+    10. [OPTIONAL] Organization checks - if user_object.organization_id is set, run these checks
+ """
+ _model = request_body.get("model", None)
+
+ # 1. If team is blocked
+ if team_object is not None and team_object.blocked is True:
+ raise Exception(
+ f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
+ )
+
+ # 2. If team can call model
+ if _model and team_object:
+ if not await can_team_access_model(
+ model=_model,
+ team_object=team_object,
+ llm_router=llm_router,
+ team_model_aliases=valid_token.team_model_aliases if valid_token else None,
+ ):
+ raise ProxyException(
+ message=f"Team not allowed to access model. Team={team_object.team_id}, Model={_model}. Allowed team models = {team_object.models}",
+ type=ProxyErrorTypes.team_model_access_denied,
+ param="model",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+    ## 2.1 If user can call model (if personal key)
+    if _model and team_object is None and user_object is not None:
+ await can_user_call_model(
+ model=_model,
+ llm_router=llm_router,
+ user_object=user_object,
+ )
+
+ # 3. If team is in budget
+ await _team_max_budget_check(
+ team_object=team_object,
+ proxy_logging_obj=proxy_logging_obj,
+ valid_token=valid_token,
+ )
+
+ # 4. If user is in budget
+ ## 4.1 check personal budget, if personal key
+ if (
+ (team_object is None or team_object.team_id is None)
+ and user_object is not None
+ and user_object.max_budget is not None
+ ):
+ user_budget = user_object.max_budget
+ if user_budget < user_object.spend:
+ raise litellm.BudgetExceededError(
+ current_cost=user_object.spend,
+ max_budget=user_budget,
+ message=f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}",
+ )
+
+ ## 4.2 check team member budget, if team key
+ # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
+ if end_user_object is not None and end_user_object.litellm_budget_table is not None:
+ end_user_budget = end_user_object.litellm_budget_table.max_budget
+ if end_user_budget is not None and end_user_object.spend > end_user_budget:
+ raise litellm.BudgetExceededError(
+ current_cost=end_user_object.spend,
+ max_budget=end_user_budget,
+ message=f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}",
+ )
+
+ # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
+ if (
+ general_settings.get("enforce_user_param", None) is not None
+ and general_settings["enforce_user_param"] is True
+ ):
+ if RouteChecks.is_llm_api_route(route=route) and "user" not in request_body:
+ raise Exception(
+ f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
+ )
+ # 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
+ if (
+ litellm.max_budget > 0
+ and global_proxy_spend is not None
+ # only run global budget checks for OpenAI routes
+        # Reason - the Admin UI should continue working if the proxy crosses its global budget
+ and RouteChecks.is_llm_api_route(route=route)
+ and route != "/v1/models"
+ and route != "/models"
+ ):
+ if global_proxy_spend > litellm.max_budget:
+ raise litellm.BudgetExceededError(
+ current_cost=global_proxy_spend, max_budget=litellm.max_budget
+ )
+
+ _request_metadata: dict = request_body.get("metadata", {}) or {}
+ if _request_metadata.get("guardrails"):
+ # check if team allowed to modify guardrails
+ from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails
+
+ can_modify: bool = can_modify_guardrails(team_object)
+ if can_modify is False:
+ from fastapi import HTTPException
+
+ raise HTTPException(
+ status_code=403,
+ detail={
+ "error": "Your team does not have permission to modify guardrails."
+ },
+ )
+
+ # 10 [OPTIONAL] Organization RBAC checks
+ organization_role_based_access_check(
+ user_object=user_object, route=route, request_body=request_body
+ )
+
+ return True
+
+
+def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
+ """
+ Return if a user is allowed to access route. Helper function for `allowed_routes_check`.
+
+ Parameters:
+ - user_route: str - the route the user is trying to call
+ - allowed_routes: List[str|LiteLLMRoutes] - the list of allowed routes for the user.
+ """
+
+ for allowed_route in allowed_routes:
+ if (
+ allowed_route in LiteLLMRoutes.__members__
+ and user_route in LiteLLMRoutes[allowed_route].value
+ ):
+ return True
+ elif allowed_route == user_route:
+ return True
+ return False
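+
+# Editor's note (illustrative sketch, not part of the original file):
+# `allowed_routes` entries may be LiteLLMRoutes group names or literal paths,
+# so, assuming "/key/generate" is part of LiteLLMRoutes.management_routes:
+#
+#   _allowed_routes_check("/key/generate", ["management_routes"])  # True (group match)
+#   _allowed_routes_check("/key/generate", ["/key/generate"])      # True (literal match)
+#   _allowed_routes_check("/key/generate", ["info_routes"])        # False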
+
+
+def allowed_routes_check(
+ user_role: Literal[
+ LitellmUserRoles.PROXY_ADMIN,
+ LitellmUserRoles.TEAM,
+ LitellmUserRoles.INTERNAL_USER,
+ ],
+ user_route: str,
+ litellm_proxy_roles: LiteLLM_JWTAuth,
+) -> bool:
+ """
+    Check if the user's role allows them to access the requested route
+ """
+
+ if user_role == LitellmUserRoles.PROXY_ADMIN:
+ is_allowed = _allowed_routes_check(
+ user_route=user_route,
+ allowed_routes=litellm_proxy_roles.admin_allowed_routes,
+ )
+ return is_allowed
+
+ elif user_role == LitellmUserRoles.TEAM:
+ if litellm_proxy_roles.team_allowed_routes is None:
+ """
+ By default allow a team to call openai + info routes
+ """
+ is_allowed = _allowed_routes_check(
+ user_route=user_route, allowed_routes=["openai_routes", "info_routes"]
+ )
+ return is_allowed
+ elif litellm_proxy_roles.team_allowed_routes is not None:
+ is_allowed = _allowed_routes_check(
+ user_route=user_route,
+ allowed_routes=litellm_proxy_roles.team_allowed_routes,
+ )
+ return is_allowed
+ return False
+
+
+def allowed_route_check_inside_route(
+ user_api_key_dict: UserAPIKeyAuth,
+ requested_user_id: Optional[str],
+) -> bool:
+ ret_val = True
+ if (
+ user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
+ and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY
+ ):
+ ret_val = False
+ if requested_user_id is not None and user_api_key_dict.user_id is not None:
+ if user_api_key_dict.user_id == requested_user_id:
+ ret_val = True
+ return ret_val
+
+
+def get_actual_routes(allowed_routes: list) -> list:
+ actual_routes: list = []
+ for route_name in allowed_routes:
+ try:
+ route_value = LiteLLMRoutes[route_name].value
+ if isinstance(route_value, set):
+ actual_routes.extend(list(route_value))
+ else:
+ actual_routes.extend(route_value)
+
+ except KeyError:
+ actual_routes.append(route_name)
+ return actual_routes
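+
+# Editor's note (illustrative sketch, not part of the original file):
+# group names expand to their member routes, unknown entries pass through
+# unchanged - e.g., assuming LiteLLMRoutes.info_routes contains "/key/info":
+#
+#   get_actual_routes(["info_routes", "/custom/route"])
+#   # -> ["/key/info", ..., "/custom/route"]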
+
+
+@log_db_metrics
+async def get_end_user_object(
+ end_user_id: Optional[str],
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span] = None,
+ proxy_logging_obj: Optional[ProxyLogging] = None,
+) -> Optional[LiteLLM_EndUserTable]:
+ """
+ Returns end user object, if in db.
+
+    Do an isolated check for the end user in the table vs. a combined key + team + user + end-user check, as the same key might come in frequently for different end-users. The larger combined call would slow down query time. This way we cache the constant data (key/team/user info) and only update based on the changing value (end-user).
+ """
+ if prisma_client is None:
+ raise Exception("No db connected")
+
+ if end_user_id is None:
+ return None
+ _key = "end_user_id:{}".format(end_user_id)
+
+ def check_in_budget(end_user_obj: LiteLLM_EndUserTable):
+ if end_user_obj.litellm_budget_table is None:
+ return
+ end_user_budget = end_user_obj.litellm_budget_table.max_budget
+ if end_user_budget is not None and end_user_obj.spend > end_user_budget:
+ raise litellm.BudgetExceededError(
+ current_cost=end_user_obj.spend, max_budget=end_user_budget
+ )
+
+ # check if in cache
+ cached_user_obj = await user_api_key_cache.async_get_cache(key=_key)
+ if cached_user_obj is not None:
+ if isinstance(cached_user_obj, dict):
+ return_obj = LiteLLM_EndUserTable(**cached_user_obj)
+ check_in_budget(end_user_obj=return_obj)
+ return return_obj
+ elif isinstance(cached_user_obj, LiteLLM_EndUserTable):
+ return_obj = cached_user_obj
+ check_in_budget(end_user_obj=return_obj)
+ return return_obj
+ # else, check db
+ try:
+ response = await prisma_client.db.litellm_endusertable.find_unique(
+ where={"user_id": end_user_id},
+ include={"litellm_budget_table": True},
+ )
+
+ if response is None:
+ raise Exception
+
+ # save the end-user object to cache
+ await user_api_key_cache.async_set_cache(
+ key="end_user_id:{}".format(end_user_id), value=response
+ )
+
+ _response = LiteLLM_EndUserTable(**response.dict())
+
+ check_in_budget(end_user_obj=_response)
+
+ return _response
+ except Exception as e: # if end-user not in db
+ if isinstance(e, litellm.BudgetExceededError):
+ raise e
+ return None
+
+
+def model_in_access_group(
+ model: str, team_models: Optional[List[str]], llm_router: Optional[Router]
+) -> bool:
+ from collections import defaultdict
+
+ if team_models is None:
+ return True
+ if model in team_models:
+ return True
+
+ access_groups: dict[str, list[str]] = defaultdict(list)
+ if llm_router:
+ access_groups = llm_router.get_model_access_groups(model_name=model)
+
+ if len(access_groups) > 0: # check if token contains any model access groups
+        # loop token models; if any of them is an access group, allow access
+        for m in team_models:
+ if m in access_groups:
+ return True
+
+ # Filter out models that are access_groups
+ filtered_models = [m for m in team_models if m not in access_groups]
+
+ if model in filtered_models:
+ return True
+
+ return False
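+
+# Editor's note (illustrative sketch, not part of the original file):
+# `team_models` may name router access groups instead of models. Assuming the
+# router defines an access group "beta-models" that contains "gpt-4o":
+#
+#   model_in_access_group("gpt-4o", ["beta-models"], llm_router)     # True
+#   model_in_access_group("gpt-4o", ["gpt-3.5-turbo"], llm_router)   # False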
+
+
+def _should_check_db(
+ key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int
+) -> bool:
+ """
+ Prevent calling db repeatedly for items that don't exist in the db.
+ """
+ current_time = time.time()
+ # if key doesn't exist in last_db_access_time -> check db
+ if key not in last_db_access_time:
+ return True
+ elif (
+ last_db_access_time[key][0] is not None
+ ): # check db for non-null values (for refresh operations)
+ return True
+ elif last_db_access_time[key][0] is None:
+        if current_time - last_db_access_time[key][1] >= db_cache_expiry:
+ return True
+ return False
+
+
+def _update_last_db_access_time(
+ key: str, value: Optional[Any], last_db_access_time: LimitedSizeOrderedDict
+):
+ last_db_access_time[key] = (value, time.time())
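+
+# Editor's note (illustrative sketch, not part of the original file):
+# entries in `last_db_access_time` are (value, timestamp) tuples, so a key
+# whose last DB lookup returned None is only re-checked after
+# `db_cache_expiry` seconds:
+#
+#   _update_last_db_access_time("user_id:abc", None, last_db_access_time)
+#   _should_check_db("user_id:abc", last_db_access_time, db_cache_expiry)
+#   # -> False until ~5s pass, True afterwards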
+
+
+def _get_role_based_permissions(
+ rbac_role: RBAC_ROLES,
+ general_settings: dict,
+ key: Literal["models", "routes"],
+) -> Optional[List[str]]:
+ """
+ Get the role based permissions from the general settings.
+ """
+ role_based_permissions = cast(
+ Optional[List[RoleBasedPermissions]],
+ general_settings.get("role_permissions", []),
+ )
+ if role_based_permissions is None:
+ return None
+
+ for role_based_permission in role_based_permissions:
+
+ if role_based_permission.role == rbac_role:
+ return getattr(role_based_permission, key)
+
+ return None
+
+
+def get_role_based_models(
+ rbac_role: RBAC_ROLES,
+ general_settings: dict,
+) -> Optional[List[str]]:
+ """
+ Get the models allowed for a user role.
+
+ Used by JWT Auth.
+ """
+
+ return _get_role_based_permissions(
+ rbac_role=rbac_role,
+ general_settings=general_settings,
+ key="models",
+ )
+
+
+def get_role_based_routes(
+ rbac_role: RBAC_ROLES,
+ general_settings: dict,
+) -> Optional[List[str]]:
+ """
+ Get the routes allowed for a user role.
+ """
+
+ return _get_role_based_permissions(
+ rbac_role=rbac_role,
+ general_settings=general_settings,
+ key="routes",
+ )
+
+
+async def _get_fuzzy_user_object(
+ prisma_client: PrismaClient,
+ sso_user_id: Optional[str] = None,
+ user_email: Optional[str] = None,
+) -> Optional[LiteLLM_UserTable]:
+ """
+ Checks if sso user is in db.
+
+ Called when user id match is not found in db.
+
+    - Check if sso_user_id matches an sso_user_id in db
+    - If not found, check if user_email matches a user_email in db
+    - If a user is found and sso_user_id was passed, update the user's sso_user_id in the background
+ """
+ response = None
+ if sso_user_id is not None:
+ response = await prisma_client.db.litellm_usertable.find_unique(
+ where={"sso_user_id": sso_user_id},
+ include={"organization_memberships": True},
+ )
+
+ if response is None and user_email is not None:
+ response = await prisma_client.db.litellm_usertable.find_first(
+ where={"user_email": user_email},
+ include={"organization_memberships": True},
+ )
+
+ if response is not None and sso_user_id is not None: # update sso_user_id
+ asyncio.create_task( # background task to update user with sso id
+ prisma_client.db.litellm_usertable.update(
+ where={"user_id": response.user_id},
+ data={"sso_user_id": sso_user_id},
+ )
+ )
+
+ return response
+
+
+@log_db_metrics
+async def get_user_object(
+ user_id: Optional[str],
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ user_id_upsert: bool,
+ parent_otel_span: Optional[Span] = None,
+ proxy_logging_obj: Optional[ProxyLogging] = None,
+ sso_user_id: Optional[str] = None,
+ user_email: Optional[str] = None,
+) -> Optional[LiteLLM_UserTable]:
+ """
+ - Check if user id in proxy User Table
+ - if valid, return LiteLLM_UserTable object with defined limits
+ - if not, then raise an error
+ """
+
+ if user_id is None:
+ return None
+
+ # check if in cache
+ cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
+ if cached_user_obj is not None:
+ if isinstance(cached_user_obj, dict):
+ return LiteLLM_UserTable(**cached_user_obj)
+ elif isinstance(cached_user_obj, LiteLLM_UserTable):
+ return cached_user_obj
+ # else, check db
+ if prisma_client is None:
+ raise Exception("No db connected")
+ try:
+ db_access_time_key = "user_id:{}".format(user_id)
+ should_check_db = _should_check_db(
+ key=db_access_time_key,
+ last_db_access_time=last_db_access_time,
+ db_cache_expiry=db_cache_expiry,
+ )
+
+ if should_check_db:
+ response = await prisma_client.db.litellm_usertable.find_unique(
+ where={"user_id": user_id}, include={"organization_memberships": True}
+ )
+
+ if response is None:
+ response = await _get_fuzzy_user_object(
+ prisma_client=prisma_client,
+ sso_user_id=sso_user_id,
+ user_email=user_email,
+ )
+
+ else:
+ response = None
+
+ if response is None:
+ if user_id_upsert:
+ response = await prisma_client.db.litellm_usertable.create(
+ data={"user_id": user_id},
+ include={"organization_memberships": True},
+ )
+ else:
+ raise Exception
+
+ if (
+ response.organization_memberships is not None
+ and len(response.organization_memberships) > 0
+ ):
+ # dump each organization membership to type LiteLLM_OrganizationMembershipTable
+ _dumped_memberships = [
+ LiteLLM_OrganizationMembershipTable(**membership.model_dump())
+ for membership in response.organization_memberships
+ if membership is not None
+ ]
+ response.organization_memberships = _dumped_memberships
+
+ _response = LiteLLM_UserTable(**dict(response))
+ response_dict = _response.model_dump()
+
+ # save the user object to cache
+ await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)
+
+ # save to db access time
+ _update_last_db_access_time(
+ key=db_access_time_key,
+ value=response_dict,
+ last_db_access_time=last_db_access_time,
+ )
+
+ return _response
+ except Exception as e: # if user not in db
+ raise ValueError(
+ f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call. Got error - {e}"
+ )
+
+
+async def _cache_management_object(
+ key: str,
+ value: BaseModel,
+ user_api_key_cache: DualCache,
+ proxy_logging_obj: Optional[ProxyLogging],
+):
+ await user_api_key_cache.async_set_cache(key=key, value=value)
+
+
+async def _cache_team_object(
+ team_id: str,
+ team_table: LiteLLM_TeamTableCachedObj,
+ user_api_key_cache: DualCache,
+ proxy_logging_obj: Optional[ProxyLogging],
+):
+ key = "team_id:{}".format(team_id)
+
+ ## CACHE REFRESH TIME!
+ team_table.last_refreshed_at = time.time()
+
+ await _cache_management_object(
+ key=key,
+ value=team_table,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+
+async def _cache_key_object(
+ hashed_token: str,
+ user_api_key_obj: UserAPIKeyAuth,
+ user_api_key_cache: DualCache,
+ proxy_logging_obj: Optional[ProxyLogging],
+):
+ key = hashed_token
+
+ ## CACHE REFRESH TIME
+ user_api_key_obj.last_refreshed_at = time.time()
+
+ await _cache_management_object(
+ key=key,
+ value=user_api_key_obj,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+
+async def _delete_cache_key_object(
+ hashed_token: str,
+ user_api_key_cache: DualCache,
+ proxy_logging_obj: Optional[ProxyLogging],
+):
+ key = hashed_token
+
+ user_api_key_cache.delete_cache(key=key)
+
+ ## UPDATE REDIS CACHE ##
+ if proxy_logging_obj is not None:
+ await proxy_logging_obj.internal_usage_cache.dual_cache.async_delete_cache(
+ key=key
+ )
+
+
+@log_db_metrics
+async def _get_team_db_check(
+ team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None
+):
+ response = await prisma_client.db.litellm_teamtable.find_unique(
+ where={"team_id": team_id}
+ )
+
+ if response is None and team_id_upsert:
+ response = await prisma_client.db.litellm_teamtable.create(
+ data={"team_id": team_id}
+ )
+
+ return response
+
+
+async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
+ return await prisma_client.db.litellm_teamtable.find_unique(
+ where={"team_id": team_id}
+ )
+
+
+async def _get_team_object_from_user_api_key_cache(
+ team_id: str,
+ prisma_client: PrismaClient,
+ user_api_key_cache: DualCache,
+ last_db_access_time: LimitedSizeOrderedDict,
+ db_cache_expiry: int,
+ proxy_logging_obj: Optional[ProxyLogging],
+ key: str,
+ team_id_upsert: Optional[bool] = None,
+) -> LiteLLM_TeamTableCachedObj:
+ db_access_time_key = key
+ should_check_db = _should_check_db(
+ key=db_access_time_key,
+ last_db_access_time=last_db_access_time,
+ db_cache_expiry=db_cache_expiry,
+ )
+ if should_check_db:
+ response = await _get_team_db_check(
+ team_id=team_id, prisma_client=prisma_client, team_id_upsert=team_id_upsert
+ )
+ else:
+ response = None
+
+ if response is None:
+ raise Exception
+
+ _response = LiteLLM_TeamTableCachedObj(**response.dict())
+ # save the team object to cache
+ await _cache_team_object(
+ team_id=team_id,
+ team_table=_response,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+    # save to db access time
+ _update_last_db_access_time(
+ key=db_access_time_key,
+ value=_response,
+ last_db_access_time=last_db_access_time,
+ )
+
+ return _response
+
+
+async def _get_team_object_from_cache(
+ key: str,
+ proxy_logging_obj: Optional[ProxyLogging],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span],
+) -> Optional[LiteLLM_TeamTableCachedObj]:
+ cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None
+
+ ## CHECK REDIS CACHE ##
+ if (
+ proxy_logging_obj is not None
+ and proxy_logging_obj.internal_usage_cache.dual_cache
+ ):
+
+ cached_team_obj = (
+ await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
+ key=key, parent_otel_span=parent_otel_span
+ )
+ )
+
+ if cached_team_obj is None:
+ cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
+
+ if cached_team_obj is not None:
+ if isinstance(cached_team_obj, dict):
+ return LiteLLM_TeamTableCachedObj(**cached_team_obj)
+ elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj):
+ return cached_team_obj
+
+ return None
+
+
+async def get_team_object(
+ team_id: str,
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span] = None,
+ proxy_logging_obj: Optional[ProxyLogging] = None,
+ check_cache_only: Optional[bool] = None,
+ check_db_only: Optional[bool] = None,
+ team_id_upsert: Optional[bool] = None,
+) -> LiteLLM_TeamTableCachedObj:
+ """
+ - Check if team id in proxy Team Table
+ - if valid, return LiteLLM_TeamTable object with defined limits
+ - if not, then raise an error
+
+ Raises:
+ - Exception: If team doesn't exist in db or cache
+ """
+ if prisma_client is None:
+ raise Exception(
+ "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
+ )
+
+ # check if in cache
+ key = "team_id:{}".format(team_id)
+
+ if not check_db_only:
+ cached_team_obj = await _get_team_object_from_cache(
+ key=key,
+ proxy_logging_obj=proxy_logging_obj,
+ user_api_key_cache=user_api_key_cache,
+ parent_otel_span=parent_otel_span,
+ )
+
+ if cached_team_obj is not None:
+ return cached_team_obj
+
+ if check_cache_only:
+ raise Exception(
+ f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
+ )
+
+ # else, check db
+ try:
+ return await _get_team_object_from_user_api_key_cache(
+ team_id=team_id,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ last_db_access_time=last_db_access_time,
+ db_cache_expiry=db_cache_expiry,
+ key=key,
+ team_id_upsert=team_id_upsert,
+ )
+ except Exception:
+ raise Exception(
+ f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
+ )
+
+
+@log_db_metrics
+async def get_key_object(
+ hashed_token: str,
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span] = None,
+ proxy_logging_obj: Optional[ProxyLogging] = None,
+ check_cache_only: Optional[bool] = None,
+) -> UserAPIKeyAuth:
+ """
+    - Check if hashed token in proxy Key Table
+    - if valid, return UserAPIKeyAuth object with defined limits
+    - if not, then raise an error
+ """
+ if prisma_client is None:
+ raise Exception(
+ "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
+ )
+
+ # check if in cache
+ key = hashed_token
+
+ cached_key_obj: Optional[UserAPIKeyAuth] = await user_api_key_cache.async_get_cache(
+ key=key
+ )
+
+ if cached_key_obj is not None:
+ if isinstance(cached_key_obj, dict):
+ return UserAPIKeyAuth(**cached_key_obj)
+ elif isinstance(cached_key_obj, UserAPIKeyAuth):
+ return cached_key_obj
+
+ if check_cache_only:
+ raise Exception(
+ f"Key doesn't exist in cache + check_cache_only=True. key={key}."
+ )
+
+ # else, check db
+ try:
+ _valid_token: Optional[BaseModel] = await prisma_client.get_data(
+ token=hashed_token,
+ table_name="combined_view",
+ parent_otel_span=parent_otel_span,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+ if _valid_token is None:
+ raise Exception
+
+ _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
+
+ # save the key object to cache
+ await _cache_key_object(
+ hashed_token=hashed_token,
+ user_api_key_obj=_response,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+ return _response
+ except DB_CONNECTION_ERROR_TYPES as e:
+ return await _handle_failed_db_connection_for_get_key_object(e=e)
+ except Exception:
+ traceback.print_exc()
+ raise Exception(
+ f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
+ )
+
+
+async def _handle_failed_db_connection_for_get_key_object(
+ e: Exception,
+) -> UserAPIKeyAuth:
+ """
+ Handles httpx.ConnectError when reading a Virtual Key from LiteLLM DB
+
+    Use this if you don't want failed DB queries to block LLM API requests
+
+ Returns:
+ - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True
+
+ Raises:
+    - Original Exception in all other cases
+ """
+ from litellm.proxy.proxy_server import (
+ general_settings,
+ litellm_proxy_admin_name,
+ proxy_logging_obj,
+ )
+
+ # If this flag is on, requests failing to connect to the DB will be allowed
+ if general_settings.get("allow_requests_on_db_unavailable", False) is True:
+ # log this as a DB failure on prometheus
+ proxy_logging_obj.service_logging_obj.service_failure_hook(
+ service=ServiceTypes.DB,
+ call_type="get_key_object",
+ error=e,
+ duration=0.0,
+ )
+
+ return UserAPIKeyAuth(
+ key_name="failed-to-connect-to-db",
+ token="failed-to-connect-to-db",
+ user_id=litellm_proxy_admin_name,
+ )
+ else:
+ # raise the original exception, the wrapper on `get_key_object` handles logging db failure to prometheus
+ raise e
+
+
+@log_db_metrics
+async def get_org_object(
+ org_id: str,
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span] = None,
+ proxy_logging_obj: Optional[ProxyLogging] = None,
+) -> Optional[LiteLLM_OrganizationTable]:
+ """
+ - Check if org id in proxy Org Table
+ - if valid, return LiteLLM_OrganizationTable object
+ - if not, then raise an error
+ """
+ if prisma_client is None:
+ raise Exception(
+ "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
+ )
+
+ # check if in cache
+    cached_org_obj = await user_api_key_cache.async_get_cache(key="org_id:{}".format(org_id))
+ if cached_org_obj is not None:
+ if isinstance(cached_org_obj, dict):
+ return LiteLLM_OrganizationTable(**cached_org_obj)
+ elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
+ return cached_org_obj
+ # else, check db
+ try:
+ response = await prisma_client.db.litellm_organizationtable.find_unique(
+ where={"organization_id": org_id}
+ )
+
+ if response is None:
+ raise Exception
+
+ return response
+ except Exception:
+ raise Exception(
+ f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
+ )
+
+
+async def _can_object_call_model(
+ model: str,
+ llm_router: Optional[Router],
+ models: List[str],
+ team_model_aliases: Optional[Dict[str, str]] = None,
+ object_type: Literal["user", "team", "key"] = "user",
+) -> Literal[True]:
+ """
+ Checks if token can call a given model
+
+ Args:
+ - model: str
+ - llm_router: Optional[Router]
+ - models: List[str]
+ - team_model_aliases: Optional[Dict[str, str]]
+ - object_type: Literal["user", "team", "key"]. We use the object type to raise the correct exception type
+
+ Returns:
+ - True: if token allowed to call model
+
+ Raises:
+ - Exception: If token not allowed to call model
+ """
+ if model in litellm.model_alias_map:
+ model = litellm.model_alias_map[model]
+
+ ## check if model in allowed model names
+ from collections import defaultdict
+
+ access_groups: Dict[str, List[str]] = defaultdict(list)
+
+ if llm_router:
+ access_groups = llm_router.get_model_access_groups(model_name=model)
+ if (
+ len(access_groups) > 0 and llm_router is not None
+ ): # check if token contains any model access groups
+        # loop token models; if any of them is an access group, allow access
+        for m in models:
+ if m in access_groups:
+ return True
+
+ # Filter out models that are access_groups
+ filtered_models = [m for m in models if m not in access_groups]
+
+ verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
+
+ if _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases):
+ return True
+
+ if _model_matches_any_wildcard_pattern_in_list(
+ model=model, allowed_model_list=filtered_models
+ ):
+ return True
+
+ all_model_access: bool = False
+
+ if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models:
+ all_model_access = True
+
+ if SpecialModelNames.all_proxy_models.value in filtered_models:
+ all_model_access = True
+
+ if model is not None and model not in filtered_models and all_model_access is False:
+ raise ProxyException(
+ message=f"{object_type} not allowed to access model. This {object_type} can only access models={models}. Tried to access {model}",
+ type=ProxyErrorTypes.get_model_access_error_type_for_object(
+ object_type=object_type
+ ),
+ param="model",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+ verbose_proxy_logger.debug(
+ f"filtered allowed_models: {filtered_models}; models: {models}"
+ )
+ return True
+
+
+def _model_in_team_aliases(
+ model: str, team_model_aliases: Optional[Dict[str, str]] = None
+) -> bool:
+ """
+ Returns True if `model` being accessed is an alias of a team model
+
+ - `model=gpt-4o`
+ - `team_model_aliases={"gpt-4o": "gpt-4o-team-1"}`
+ - returns True
+
+    - `model=gpt-4o`
+ - `team_model_aliases={"o-3": "o3-preview"}`
+ - returns False
+ """
+ if team_model_aliases:
+ if model in team_model_aliases:
+ return True
+ return False
+
+
+async def can_key_call_model(
+ model: str,
+ llm_model_list: Optional[list],
+ valid_token: UserAPIKeyAuth,
+ llm_router: Optional[litellm.Router],
+) -> Literal[True]:
+ """
+ Checks if token can call a given model
+
+ Returns:
+ - True: if token allowed to call model
+
+ Raises:
+ - Exception: If token not allowed to call model
+ """
+ return await _can_object_call_model(
+ model=model,
+ llm_router=llm_router,
+ models=valid_token.models,
+ team_model_aliases=valid_token.team_model_aliases,
+ object_type="key",
+ )
+
+
+async def can_team_access_model(
+ model: str,
+ team_object: Optional[LiteLLM_TeamTable],
+ llm_router: Optional[Router],
+ team_model_aliases: Optional[Dict[str, str]] = None,
+) -> Literal[True]:
+ """
+ Returns True if the team can access a specific model.
+
+ """
+ return await _can_object_call_model(
+ model=model,
+ llm_router=llm_router,
+ models=team_object.models if team_object else [],
+ team_model_aliases=team_model_aliases,
+ object_type="team",
+ )
+
+
+async def can_user_call_model(
+ model: str,
+ llm_router: Optional[Router],
+ user_object: Optional[LiteLLM_UserTable],
+) -> Literal[True]:
+
+ if user_object is None:
+ return True
+
+ if SpecialModelNames.no_default_models.value in user_object.models:
+ raise ProxyException(
+ message=f"User not allowed to access model. No default model access, only team models allowed. Tried to access {model}",
+ type=ProxyErrorTypes.key_model_access_denied,
+ param="model",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+ return await _can_object_call_model(
+ model=model,
+ llm_router=llm_router,
+ models=user_object.models,
+ object_type="user",
+ )
+
+
+async def is_valid_fallback_model(
+ model: str,
+ llm_router: Optional[Router],
+ user_model: Optional[str],
+) -> Literal[True]:
+ """
+ Try to route the fallback model request.
+
+ Validate if it can't be routed.
+
+ Help catch invalid fallback models.
+ """
+ await route_request(
+ data={
+ "model": model,
+ "messages": [{"role": "user", "content": "Who was Alexander?"}],
+ },
+ llm_router=llm_router,
+ user_model=user_model,
+ route_type="acompletion", # route type shouldn't affect the fallback model check
+ )
+
+ return True
+
+
+async def _virtual_key_max_budget_check(
+ valid_token: UserAPIKeyAuth,
+ proxy_logging_obj: ProxyLogging,
+ user_obj: Optional[LiteLLM_UserTable] = None,
+):
+ """
+ Raises:
+        BudgetExceededError if the token is over its max budget.
+
+    Triggers a budget alert if the token is over its max budget.
+
+ """
+ if valid_token.spend is not None and valid_token.max_budget is not None:
+ ####################################
+ # collect information for alerting #
+ ####################################
+
+ user_email = None
+ # Check if the token has any user id information
+ if user_obj is not None:
+ user_email = user_obj.user_email
+
+ call_info = CallInfo(
+ token=valid_token.token,
+ spend=valid_token.spend,
+ max_budget=valid_token.max_budget,
+ user_id=valid_token.user_id,
+ team_id=valid_token.team_id,
+ user_email=user_email,
+ key_alias=valid_token.key_alias,
+ )
+ asyncio.create_task(
+ proxy_logging_obj.budget_alerts(
+ type="token_budget",
+ user_info=call_info,
+ )
+ )
+
+ if valid_token.spend >= valid_token.max_budget:
+ raise litellm.BudgetExceededError(
+ current_cost=valid_token.spend,
+ max_budget=valid_token.max_budget,
+ )
+
+
+async def _virtual_key_soft_budget_check(
+ valid_token: UserAPIKeyAuth,
+ proxy_logging_obj: ProxyLogging,
+):
+ """
+    Triggers a budget alert if the token is over its soft budget.
+
+ """
+
+ if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
+ verbose_proxy_logger.debug(
+ "Crossed Soft Budget for token %s, spend %s, soft_budget %s",
+ valid_token.token,
+ valid_token.spend,
+ valid_token.soft_budget,
+ )
+ call_info = CallInfo(
+ token=valid_token.token,
+ spend=valid_token.spend,
+ max_budget=valid_token.max_budget,
+ soft_budget=valid_token.soft_budget,
+ user_id=valid_token.user_id,
+ team_id=valid_token.team_id,
+ team_alias=valid_token.team_alias,
+ user_email=None,
+ key_alias=valid_token.key_alias,
+ )
+ asyncio.create_task(
+ proxy_logging_obj.budget_alerts(
+ type="soft_budget",
+ user_info=call_info,
+ )
+ )
+
+
+async def _team_max_budget_check(
+ team_object: Optional[LiteLLM_TeamTable],
+ valid_token: Optional[UserAPIKeyAuth],
+ proxy_logging_obj: ProxyLogging,
+):
+ """
+    Check if the team is over its max budget.
+
+    Raises:
+        BudgetExceededError if the team is over its max budget.
+
+    Triggers a budget alert if the team is over its max budget.
+ """
+ if (
+ team_object is not None
+ and team_object.max_budget is not None
+ and team_object.spend is not None
+ and team_object.spend > team_object.max_budget
+ ):
+ if valid_token:
+ call_info = CallInfo(
+ token=valid_token.token,
+ spend=team_object.spend,
+ max_budget=team_object.max_budget,
+ user_id=valid_token.user_id,
+ team_id=valid_token.team_id,
+ team_alias=valid_token.team_alias,
+ )
+ asyncio.create_task(
+ proxy_logging_obj.budget_alerts(
+ type="team_budget",
+ user_info=call_info,
+ )
+ )
+
+ raise litellm.BudgetExceededError(
+ current_cost=team_object.spend,
+ max_budget=team_object.max_budget,
+ message=f"Budget has been exceeded! Team={team_object.team_id} Current cost: {team_object.spend}, Max budget: {team_object.max_budget}",
+ )
+
+
+def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
+ """
+ Check if a model matches an allowed pattern.
+ Handles exact matches and wildcard patterns.
+
+ Args:
+ model (str): The model to check (e.g., "bedrock/anthropic.claude-3-5-sonnet-20240620")
+ allowed_model_pattern (str): The allowed pattern (e.g., "bedrock/*", "*", "openai/*")
+
+ Returns:
+ bool: True if model matches the pattern, False otherwise
+ """
+ if "*" in allowed_model_pattern:
+ pattern = f"^{allowed_model_pattern.replace('*', '.*')}$"
+ return bool(re.match(pattern, model))
+
+ return False
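+
+# Editor's note (illustrative sketch, not part of the original file):
+# the wildcard is translated to an anchored regex ("bedrock/*" -> r"^bedrock/.*$"):
+#
+#   is_model_allowed_by_pattern("bedrock/anthropic.claude-3-5-sonnet-20240620", "bedrock/*")  # True
+#   is_model_allowed_by_pattern("openai/gpt-4o", "bedrock/*")  # False
+#   is_model_allowed_by_pattern("gpt-4o", "gpt-4o")  # False - no "*"; exact matches are handled by the caller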
+
+
+def _model_matches_any_wildcard_pattern_in_list(
+ model: str, allowed_model_list: list
+) -> bool:
+ """
+ Returns True if a model matches any wildcard pattern in a list.
+
+ eg.
+ - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns True
+ - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
+ - model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
+ """
+
+ if any(
+ _is_wildcard_pattern(allowed_model_pattern)
+ and is_model_allowed_by_pattern(
+ model=model, allowed_model_pattern=allowed_model_pattern
+ )
+ for allowed_model_pattern in allowed_model_list
+ ):
+ return True
+
+ if any(
+ _is_wildcard_pattern(allowed_model_pattern)
+ and _model_custom_llm_provider_matches_wildcard_pattern(
+ model=model, allowed_model_pattern=allowed_model_pattern
+ )
+ for allowed_model_pattern in allowed_model_list
+ ):
+ return True
+
+ return False
+
+
+def _model_custom_llm_provider_matches_wildcard_pattern(
+ model: str, allowed_model_pattern: str
+) -> bool:
+ """
+ Returns True for this scenario:
+ - `model=gpt-4o`
+ - `allowed_model_pattern=openai/*`
+
+ or
+ - `model=claude-3-5-sonnet-20240620`
+ - `allowed_model_pattern=anthropic/*`
+ """
+ try:
+ model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+ except Exception:
+ return False
+
+ return is_model_allowed_by_pattern(
+ model=f"{custom_llm_provider}/{model}",
+ allowed_model_pattern=allowed_model_pattern,
+ )
+
+
+def _is_wildcard_pattern(allowed_model_pattern: str) -> bool:
+ """
+ Returns True if the pattern is a wildcard pattern.
+
+ Checks if `*` is in the pattern.
+ """
+ return "*" in allowed_model_pattern
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks_organization.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks_organization.py
new file mode 100644
index 00000000..3da3d8dd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks_organization.py
@@ -0,0 +1,161 @@
+"""
+Auth Checks for Organizations
+"""
+
+from typing import Dict, List, Optional, Tuple
+
+from fastapi import status
+
+from litellm.proxy._types import *
+
+
+def organization_role_based_access_check(
+ request_body: dict,
+ user_object: Optional[LiteLLM_UserTable],
+ route: str,
+):
+ """
+ Role based access control checks only run if a user is part of an Organization
+
+ Organization Checks:
+ ONLY RUN IF user_object.organization_memberships is not None
+
+ 1. Only Proxy Admins can access /organization/new
+ 2. IF route is a LiteLLMRoutes.org_admin_only_routes, then check if user is an Org Admin for that organization
+
+ """
+
+ if user_object is None:
+ return
+
+ passed_organization_id: Optional[str] = request_body.get("organization_id", None)
+
+ if route == "/organization/new":
+ if user_object.user_role != LitellmUserRoles.PROXY_ADMIN.value:
+ raise ProxyException(
+ message=f"Only proxy admins can create new organizations. You are {user_object.user_role}",
+ type=ProxyErrorTypes.auth_error.value,
+ param="user_role",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+ if user_object.user_role == LitellmUserRoles.PROXY_ADMIN.value:
+ return
+
+ # Checks if route is an Org Admin Only Route
+ if route in LiteLLMRoutes.org_admin_only_routes.value:
+ _user_organizations, _user_organization_role_mapping = (
+ get_user_organization_info(user_object)
+ )
+
+ if user_object.organization_memberships is None:
+ raise ProxyException(
+ message=f"Tried to access route={route} but you are not a member of any organization. Please contact the proxy admin to request access.",
+ type=ProxyErrorTypes.auth_error.value,
+ param="organization_id",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+ if passed_organization_id is None:
+ raise ProxyException(
+ message="Passed organization_id is None, please pass an organization_id in your request",
+ type=ProxyErrorTypes.auth_error.value,
+ param="organization_id",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+ user_role: Optional[LitellmUserRoles] = _user_organization_role_mapping.get(
+ passed_organization_id
+ )
+ if user_role is None:
+ raise ProxyException(
+ message=f"You do not have a role within the selected organization. Passed organization_id: {passed_organization_id}. Please contact the organization admin to request access.",
+ type=ProxyErrorTypes.auth_error.value,
+ param="organization_id",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+ if user_role != LitellmUserRoles.ORG_ADMIN.value:
+ raise ProxyException(
+ message=f"You do not have the required role to perform {route} in Organization {passed_organization_id}. Your role is {user_role} in Organization {passed_organization_id}",
+ type=ProxyErrorTypes.auth_error.value,
+ param="user_role",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+ elif route == "/team/new":
+        # if the user is part of multiple organizations, they need to specify the organization_id
+ _user_organizations, _user_organization_role_mapping = (
+ get_user_organization_info(user_object)
+ )
+ if (
+ user_object.organization_memberships is not None
+ and len(user_object.organization_memberships) > 0
+ ):
+ if passed_organization_id is None:
+ raise ProxyException(
+ message=f"Passed organization_id is None, please specify the organization_id in your request. You are part of multiple organizations: {_user_organizations}",
+ type=ProxyErrorTypes.auth_error.value,
+ param="organization_id",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+ _user_role_in_passed_org = _user_organization_role_mapping.get(
+ passed_organization_id
+ )
+ if _user_role_in_passed_org != LitellmUserRoles.ORG_ADMIN.value:
+ raise ProxyException(
+ message=f"You do not have the required role to call {route}. Your role is {_user_role_in_passed_org} in Organization {passed_organization_id}",
+ type=ProxyErrorTypes.auth_error.value,
+ param="user_role",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+
+def get_user_organization_info(
+ user_object: LiteLLM_UserTable,
+) -> Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]:
+ """
+ Helper function to extract user organization information.
+
+ Args:
+ user_object (LiteLLM_UserTable): The user object containing organization memberships.
+
+ Returns:
+ Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]: A tuple containing:
+ - List of organization IDs the user is a member of
+ - Dictionary mapping organization IDs to user roles
+ """
+ _user_organizations: List[str] = []
+ _user_organization_role_mapping: Dict[str, Optional[LitellmUserRoles]] = {}
+
+ if user_object.organization_memberships is not None:
+ for _membership in user_object.organization_memberships:
+ if _membership.organization_id is not None:
+ _user_organizations.append(_membership.organization_id)
+ _user_organization_role_mapping[_membership.organization_id] = _membership.user_role # type: ignore
+
+ return _user_organizations, _user_organization_role_mapping
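+
+# Editor's note (illustrative sketch, not part of the original file):
+# for a user with a single membership (organization_id="org-1",
+# user_role=ORG_ADMIN), this returns roughly:
+#
+#   (["org-1"], {"org-1": <ORG_ADMIN role>})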
+
+
+def _user_is_org_admin(
+ request_data: dict,
+ user_object: Optional[LiteLLM_UserTable] = None,
+) -> bool:
+ """
+ Helper function to check if user is an org admin for the passed organization_id
+ """
+ if request_data.get("organization_id", None) is None:
+ return False
+
+ if user_object is None:
+ return False
+
+ if user_object.organization_memberships is None:
+ return False
+
+ for _membership in user_object.organization_memberships:
+ if _membership.organization_id == request_data.get("organization_id", None):
+ if _membership.user_role == LitellmUserRoles.ORG_ADMIN.value:
+ return True
+
+ return False
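+
+# Editor's note (illustrative sketch, not part of the original file):
+#
+#   _user_is_org_admin({"organization_id": "org-1"}, user_object)
+#   # -> True only if user_object has an ORG_ADMIN membership for "org-1"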
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
new file mode 100644
index 00000000..91fcaf7e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
@@ -0,0 +1,514 @@
+import os
+import re
+import sys
+from typing import Any, List, Optional, Tuple
+
+from fastapi import HTTPException, Request, status
+
+from litellm import Router, provider_list
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import *
+from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS
+
+
+def _get_request_ip_address(
+ request: Request, use_x_forwarded_for: Optional[bool] = False
+) -> Optional[str]:
+
+ client_ip = None
+ if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
+ client_ip = request.headers["x-forwarded-for"]
+ elif request.client is not None:
+ client_ip = request.client.host
+ else:
+ client_ip = ""
+
+ return client_ip
+
+
+def _check_valid_ip(
+ allowed_ips: Optional[List[str]],
+ request: Request,
+ use_x_forwarded_for: Optional[bool] = False,
+) -> Tuple[bool, Optional[str]]:
+ """
+ Returns if ip is allowed or not
+ """
+ if allowed_ips is None: # if not set, assume true
+ return True, None
+
+ # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
+ client_ip = _get_request_ip_address(
+ request=request, use_x_forwarded_for=use_x_forwarded_for
+ )
+
+ # Check if IP address is allowed
+ if client_ip not in allowed_ips:
+ return False, client_ip
+
+ return True, client_ip
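+
+# Editor's note (illustrative sketch, not part of the original file):
+# with general_settings.allowed_ips = ["127.0.0.1"], a request from 10.0.0.5
+# yields (False, "10.0.0.5") and the caller returns a 403;
+# allowed_ips=None always yields (True, None).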
+
+
+def check_complete_credentials(request_body: dict) -> bool:
+ """
+    If 'api_base' is in the request body, check that complete credentials were given. This prevents malicious requests.
+ """
+    given_model: Optional[str] = request_body.get("model")
+ if given_model is None:
+ return False
+
+ if (
+ "sagemaker" in given_model
+ or "bedrock" in given_model
+ or "vertex_ai" in given_model
+ or "vertex_ai_beta" in given_model
+ ):
+ # complex credentials - easier to make a malicious request
+ return False
+
+ if "api_key" in request_body:
+ return True
+
+ return False
+
+
+def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
+ """
+ Check if request_body_value matches the regex_str or is equal to param
+ """
+ if re.match(regex_str, request_body_value) or regex_str == request_body_value:
+ return True
+ return False
+
+
+def _is_param_allowed(
+ param: str,
+ request_body_value: Any,
+ configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
+) -> bool:
+ """
+ Check if param is a str or dict and if request_body_value is in the list of allowed values
+ """
+ if configurable_clientside_auth_params is None:
+ return False
+
+ for item in configurable_clientside_auth_params:
+ if isinstance(item, str) and param == item:
+ return True
+ elif isinstance(item, Dict):
+ if param == "api_base" and check_regex_or_str_match(
+ request_body_value=request_body_value,
+ regex_str=item["api_base"],
+ ): # assume param is a regex
+ return True
+
+ return False
+
+
+def _allow_model_level_clientside_configurable_parameters(
+ model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
+) -> bool:
+ """
+ Check if model is allowed to use configurable client-side params
+ - get matching model
+ - check if 'clientside_configurable_parameters' is set for model
+ """
+ if llm_router is None:
+ return False
+ # check if model is set
+ model_info = llm_router.get_model_group_info(model_group=model)
+ if model_info is None:
+ # check if wildcard model is set
+ if model.split("/", 1)[0] in provider_list:
+ model_info = llm_router.get_model_group_info(
+ model_group=model.split("/", 1)[0]
+ )
+
+ if model_info is None:
+ return False
+
+ if model_info is None or model_info.configurable_clientside_auth_params is None:
+ return False
+
+ return _is_param_allowed(
+ param=param,
+ request_body_value=request_body_value,
+ configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
+ )
+
+
+def is_request_body_safe(
+ request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
+) -> bool:
+ """
+ Check if the request body is safe.
+
+ A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
+ Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997
+ """
+ banned_params = ["api_base", "base_url"]
+
+ for param in banned_params:
+ if (
+ param in request_body
+ and not check_complete_credentials( # allow client-credentials to be passed to proxy
+ request_body=request_body
+ )
+ ):
+ if general_settings.get("allow_client_side_credentials") is True:
+ return True
+ elif (
+ _allow_model_level_clientside_configurable_parameters(
+ model=model,
+ param=param,
+ request_body_value=request_body[param],
+ llm_router=llm_router,
+ )
+ is True
+ ):
+ return True
+ raise ValueError(
+ f"Rejected Request: {param} is not allowed in request body. "
+ "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
+ "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
+ )
+
+ return True
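+
+# Editor's note (illustrative sketch, not part of the original file):
+# a body that passes `api_base` without a complete credential set is rejected
+# unless one of the opt-ins above is enabled:
+#
+#   is_request_body_safe(
+#       request_body={"model": "gpt-4o", "api_base": "https://attacker.example"},
+#       general_settings={},  # allow_client_side_credentials not set
+#       llm_router=None,
+#       model="gpt-4o",
+#   )  # raises ValueError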
+
+
+async def pre_db_read_auth_checks(
+ request: Request,
+ request_data: dict,
+ route: str,
+):
+ """
+ 1. Checks if request size is under max_request_size_mb (if set)
+ 2. Check if request body is safe (example user has not set api_base in request body)
+ 3. Check if IP address is allowed (if set)
+ 4. Check if request route is an allowed route on the proxy (if set)
+
+ Returns:
+ - True
+
+ Raises:
+ - HTTPException if request fails initial auth checks
+ """
+ from litellm.proxy.proxy_server import general_settings, llm_router, premium_user
+
+ # Check 1. request size
+ await check_if_request_size_is_safe(request=request)
+
+ # Check 2. Request body is safe
+ is_request_body_safe(
+ request_body=request_data,
+ general_settings=general_settings,
+ llm_router=llm_router,
+ model=request_data.get(
+ "model", ""
+ ), # [TODO] use model passed in url as well (azure openai routes)
+ )
+
+ # Check 3. Check if IP address is allowed
+ is_valid_ip, passed_in_ip = _check_valid_ip(
+ allowed_ips=general_settings.get("allowed_ips", None),
+ use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
+ request=request,
+ )
+
+ if not is_valid_ip:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
+ )
+
+ # Check 4. Check if request route is an allowed route on the proxy
+ if "allowed_routes" in general_settings:
+ _allowed_routes = general_settings["allowed_routes"]
+ if premium_user is not True:
+ verbose_proxy_logger.error(
+ f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
+ )
+ if route not in _allowed_routes:
+ verbose_proxy_logger.error(
+ f"Route {route} not in allowed_routes={_allowed_routes}"
+ )
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=f"Access forbidden: Route {route} not allowed",
+ )
+
+
+def route_in_additonal_public_routes(current_route: str):
+ """
+ Helper to check if the user defined public_routes on config.yaml
+
+ Parameters:
+ - current_route: str - the route the user is trying to call
+
+ Returns:
+ - bool - True if the route is defined in public_routes
+ - bool - False if the route is not defined in public_routes
+
+
+ In order to use this the litellm config.yaml should have the following in general_settings:
+
+ ```yaml
+ general_settings:
+ master_key: sk-1234
+ public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"]
+ ```
+ """
+
+ # check if user is premium_user - if not do nothing
+ from litellm.proxy.proxy_server import general_settings, premium_user
+
+ try:
+ if premium_user is not True:
+ return False
+ # check if this is defined on the config
+ if general_settings is None:
+ return False
+
+ routes_defined = general_settings.get("public_routes", [])
+ if current_route in routes_defined:
+ return True
+
+ return False
+ except Exception as e:
+ verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
+ return False
+
+
+def get_request_route(request: Request) -> str:
+ """
+ Helper to get the route from the request
+
+    remove base url from path if set e.g. `/genai/chat/completions` -> `/chat/completions`
+ """
+ try:
+ if hasattr(request, "base_url") and request.url.path.startswith(
+ request.base_url.path
+ ):
+ # remove base_url from path
+ return request.url.path[len(request.base_url.path) - 1 :]
+ else:
+ return request.url.path
+ except Exception as e:
+ verbose_proxy_logger.debug(
+ f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
+ )
+ return request.url.path
+
+
+async def check_if_request_size_is_safe(request: Request) -> bool:
+ """
+ Enterprise Only:
+ - Checks if the request size is within the limit
+
+ Args:
+ request (Request): The incoming request.
+
+ Returns:
+ bool: True if the request size is within the limit
+
+ Raises:
+ ProxyException: If the request size is too large
+
+ """
+ from litellm.proxy.proxy_server import general_settings, premium_user
+
+ max_request_size_mb = general_settings.get("max_request_size_mb", None)
+ if max_request_size_mb is not None:
+ # Check if premium user
+ if premium_user is not True:
+ verbose_proxy_logger.warning(
+ f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
+ )
+ return True
+
+ # Get the request body
+ content_length = request.headers.get("content-length")
+
+ if content_length:
+ header_size = int(content_length)
+ header_size_mb = bytes_to_mb(bytes_value=header_size)
+ verbose_proxy_logger.debug(
+ f"content_length request size in MB={header_size_mb}"
+ )
+
+ if header_size_mb > max_request_size_mb:
+ raise ProxyException(
+ message=f"Request size is too large. Request size is {header_size_mb} MB. Max size is {max_request_size_mb} MB",
+ type=ProxyErrorTypes.bad_request_error.value,
+ code=400,
+ param="content-length",
+ )
+ else:
+ # If Content-Length is not available, read the body
+ body = await request.body()
+ body_size = len(body)
+ request_size_mb = bytes_to_mb(bytes_value=body_size)
+
+ verbose_proxy_logger.debug(
+ f"request body request size in MB={request_size_mb}"
+ )
+ if request_size_mb > max_request_size_mb:
+ raise ProxyException(
+ message=f"Request size is too large. Request size is {request_size_mb} MB. Max size is {max_request_size_mb} MB",
+ type=ProxyErrorTypes.bad_request_error.value,
+ code=400,
+ param="content-length",
+ )
+
+ return True
+
+
+async def check_response_size_is_safe(response: Any) -> bool:
+ """
+ Enterprise Only:
+ - Checks if the response size is within the limit
+
+ Args:
+ response (Any): The response to check.
+
+ Returns:
+ bool: True if the response size is within the limit
+
+ Raises:
+ ProxyException: If the response size is too large
+
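+    Enabled via config.yaml (illustrative):
+
+    ```yaml
+    general_settings:
+      max_response_size_mb: 10
+    ```
+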
+ """
+
+ from litellm.proxy.proxy_server import general_settings, premium_user
+
+ max_response_size_mb = general_settings.get("max_response_size_mb", None)
+ if max_response_size_mb is not None:
+ # Check if premium user
+ if premium_user is not True:
+            verbose_proxy_logger.warning(
+                f"`max_response_size_mb` is set but not enforced - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
+            )
+ return True
+
+        # NOTE: sys.getsizeof measures the in-memory size of the Python object,
+        # not the size of the serialized payload
+        response_size_mb = bytes_to_mb(bytes_value=sys.getsizeof(response))
+ verbose_proxy_logger.debug(f"response size in MB={response_size_mb}")
+ if response_size_mb > max_response_size_mb:
+ raise ProxyException(
+ message=f"Response size is too large. Response size is {response_size_mb} MB. Max size is {max_response_size_mb} MB",
+ type=ProxyErrorTypes.bad_request_error.value,
+ code=400,
+ param="content-length",
+ )
+
+ return True
+
+
+def bytes_to_mb(bytes_value: int):
+ """
+ Helper to convert bytes to MB
+ """
+ return bytes_value / (1024 * 1024)
+
+
+# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
+def get_key_model_rpm_limit(
+ user_api_key_dict: UserAPIKeyAuth,
+) -> Optional[Dict[str, int]]:
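+    """
+    Return a mapping of model name -> RPM limit for this key, e.g. {"gpt-4": 10}
+    (values illustrative), read from key metadata or per-model budgets.
+    """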
+ if user_api_key_dict.metadata:
+ if "model_rpm_limit" in user_api_key_dict.metadata:
+ return user_api_key_dict.metadata["model_rpm_limit"]
+ elif user_api_key_dict.model_max_budget:
+ model_rpm_limit: Dict[str, Any] = {}
+ for model, budget in user_api_key_dict.model_max_budget.items():
+ if "rpm_limit" in budget and budget["rpm_limit"] is not None:
+ model_rpm_limit[model] = budget["rpm_limit"]
+ return model_rpm_limit
+
+ return None
+
+
+def get_key_model_tpm_limit(
+ user_api_key_dict: UserAPIKeyAuth,
+) -> Optional[Dict[str, int]]:
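+    """
+    Return a mapping of model name -> TPM limit for this key, e.g. {"gpt-4": 100000}
+    (values illustrative), read from key metadata or per-model budgets.
+    """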
+ if user_api_key_dict.metadata:
+ if "model_tpm_limit" in user_api_key_dict.metadata:
+ return user_api_key_dict.metadata["model_tpm_limit"]
+    elif user_api_key_dict.model_max_budget:
+        # mirror get_key_model_rpm_limit: model_max_budget maps model -> budget dict,
+        # so collect per-model tpm limits rather than reading a top-level key
+        model_tpm_limit: Dict[str, Any] = {}
+        for model, budget in user_api_key_dict.model_max_budget.items():
+            if "tpm_limit" in budget and budget["tpm_limit"] is not None:
+                model_tpm_limit[model] = budget["tpm_limit"]
+        return model_tpm_limit
+
+ return None
+
+
+def is_pass_through_provider_route(route: str) -> bool:
+ PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES = [
+ "vertex-ai",
+ ]
+
+ # check if any of the prefixes are in the route
+ for prefix in PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES:
+ if prefix in route:
+ return True
+
+ return False
+
+
+def should_run_auth_on_pass_through_provider_route(route: str) -> bool:
+ """
+    Use this to decide if the rest of the LiteLLM Virtual Key auth checks should run on
+    provider pass-through routes, e.g. /vertex-ai/{endpoint}.
+
+    Virtual key auth runs only if all of the following are true:
+    - User is a premium_user (LiteLLM Enterprise)
+    - User has NOT enabled general_settings.use_client_credentials_pass_through_routes
+ """
+ from litellm.proxy.proxy_server import general_settings, premium_user
+
+    if premium_user is not True:
+        return False
+
+    # premium user has opted into using client credentials - skip LiteLLM virtual key auth
+ if (
+ general_settings.get("use_client_credentials_pass_through_routes", False)
+ is True
+ ):
+ return False
+
+ # only enabled for LiteLLM Enterprise
+ return True
+
+
+def _has_user_setup_sso() -> bool:
+    """
+    Check if the user has set up single sign-on (SSO) by verifying the presence of the
+    Microsoft client ID, Google client ID, or generic client ID environment variables.
+ Returns a boolean indicating whether SSO has been set up.
+ """
+ microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
+ google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
+ generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
+
+ sso_setup = (
+ (microsoft_client_id is not None)
+ or (google_client_id is not None)
+ or (generic_client_id is not None)
+ )
+
+ return sso_setup
+
+
+def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
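+    """
+    Extract the end-user id from the request body, checking in order
+    (example values illustrative):
+
+    1. OpenAI-style: {"user": "user-123"}
+    2. Anthropic-style: {"litellm_metadata": {"user": "user-123"}}
+    3. Generic metadata: {"metadata": {"user_id": "user-123"}}
+    """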
+ # openai - check 'user'
+ if "user" in request_body and request_body["user"] is not None:
+ return str(request_body["user"])
+ # anthropic - check 'litellm_metadata'
+ end_user_id = request_body.get("litellm_metadata", {}).get("user", None)
+ if end_user_id:
+ return str(end_user_id)
+ metadata = request_body.get("metadata")
+ if metadata and "user_id" in metadata and metadata["user_id"] is not None:
+ return str(metadata["user_id"])
+ return None
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/handle_jwt.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/handle_jwt.py
new file mode 100644
index 00000000..cc410501
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/handle_jwt.py
@@ -0,0 +1,1001 @@
+"""
+Supports using JWTs for authenticating into the proxy.
+
+Tokens with 'litellm_proxy_admin' in scope are treated as proxy admins;
+team / user / org claims are resolved via the `litellm_jwtauth` settings.
+"""
+
+import json
+import os
+from typing import Any, List, Literal, Optional, Set, Tuple, cast
+
+from cryptography import x509
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from fastapi import HTTPException
+
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
+from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
+from litellm.proxy._types import (
+ RBAC_ROLES,
+ JWKKeyValue,
+ JWTAuthBuilderResult,
+ JWTKeyItem,
+ LiteLLM_EndUserTable,
+ LiteLLM_JWTAuth,
+ LiteLLM_OrganizationTable,
+ LiteLLM_TeamTable,
+ LiteLLM_UserTable,
+ LitellmUserRoles,
+ ScopeMapping,
+ Span,
+)
+from litellm.proxy.auth.auth_checks import can_team_access_model
+from litellm.proxy.utils import PrismaClient, ProxyLogging
+
+from .auth_checks import (
+ _allowed_routes_check,
+ allowed_routes_check,
+ get_actual_routes,
+ get_end_user_object,
+ get_org_object,
+ get_role_based_models,
+ get_role_based_routes,
+ get_team_object,
+ get_user_object,
+)
+
+
+class JWTHandler:
+ """
+ - treat the sub id passed in as the user id
+ - return an error if id making request doesn't exist in proxy user table
+ - track spend against the user id
+    - if role="litellm_proxy_user" -> allow making calls + info. Cannot edit budgets
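+
+    Example config.yaml (illustrative - these are the `litellm_jwtauth` field
+    names referenced in this module; the JWKS endpoint itself is read from the
+    `JWT_PUBLIC_KEY_URL` env var):
+
+    ```yaml
+    general_settings:
+      litellm_jwtauth:
+        admin_jwt_scope: "litellm_proxy_admin"
+        team_id_jwt_field: "client_id"
+        user_id_jwt_field: "sub"
+    ```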
+ """
+
+ prisma_client: Optional[PrismaClient]
+ user_api_key_cache: DualCache
+
+ def __init__(
+ self,
+ ) -> None:
+ self.http_handler = HTTPHandler()
+ self.leeway = 0
+
+ def update_environment(
+ self,
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ litellm_jwtauth: LiteLLM_JWTAuth,
+ leeway: int = 0,
+ ) -> None:
+ self.prisma_client = prisma_client
+ self.user_api_key_cache = user_api_key_cache
+ self.litellm_jwtauth = litellm_jwtauth
+ self.leeway = leeway
+
+    def is_jwt(self, token: str) -> bool:
+ parts = token.split(".")
+ return len(parts) == 3
+
+ def _rbac_role_from_role_mapping(self, token: dict) -> Optional[RBAC_ROLES]:
+ """
+ Returns the RBAC role the token 'belongs' to based on role mappings.
+
+ Args:
+ token (dict): The JWT token containing role information
+
+ Returns:
+ Optional[RBAC_ROLES]: The mapped internal RBAC role if a mapping exists,
+ None otherwise
+
+ Note:
+ The function handles both single string roles and lists of roles from the JWT.
+ If multiple mappings match the JWT roles, the first matching mapping is returned.
+ """
+ if self.litellm_jwtauth.role_mappings is None:
+ return None
+
+        jwt_role = self.get_jwt_role(token=token, default_value=None)
+        if not jwt_role:
+            return None
+
+        # the claim may be a single string or a list of roles
+        if isinstance(jwt_role, str):
+            jwt_role = [jwt_role]
+        jwt_role_set = set(jwt_role)
+
+ for role_mapping in self.litellm_jwtauth.role_mappings:
+ # Check if the mapping role matches any of the JWT roles
+ if role_mapping.role in jwt_role_set:
+ return role_mapping.internal_role
+
+ return None
+
+ def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
+ """
+ Returns the RBAC role the token 'belongs' to.
+
+ RBAC roles allowed to make requests:
+ - PROXY_ADMIN: can make requests to all routes
+ - TEAM: can make requests to routes associated with a team
+ - INTERNAL_USER: can make requests to routes associated with a user
+
+ Resolves: https://github.com/BerriAI/litellm/issues/6793
+
+ Returns:
+ - PROXY_ADMIN: if token is admin
+ - TEAM: if token is associated with a team
+ - INTERNAL_USER: if token is associated with a user
+ - None: if token is not associated with a team or user
+ """
+ scopes = self.get_scopes(token=token)
+ is_admin = self.is_admin(scopes=scopes)
+ user_roles = self.get_user_roles(token=token, default_value=None)
+
+ if is_admin:
+ return LitellmUserRoles.PROXY_ADMIN
+ elif self.get_team_id(token=token, default_value=None) is not None:
+ return LitellmUserRoles.TEAM
+ elif self.get_user_id(token=token, default_value=None) is not None:
+ return LitellmUserRoles.INTERNAL_USER
+ elif user_roles is not None and self.is_allowed_user_role(
+ user_roles=user_roles
+ ):
+ return LitellmUserRoles.INTERNAL_USER
+ elif rbac_role := self._rbac_role_from_role_mapping(token=token):
+ return rbac_role
+
+ return None
+
+ def is_admin(self, scopes: list) -> bool:
+ if self.litellm_jwtauth.admin_jwt_scope in scopes:
+ return True
+ return False
+
+ def get_team_ids_from_jwt(self, token: dict) -> List[str]:
+ if (
+ self.litellm_jwtauth.team_ids_jwt_field is not None
+ and token.get(self.litellm_jwtauth.team_ids_jwt_field) is not None
+ ):
+ return token[self.litellm_jwtauth.team_ids_jwt_field]
+ return []
+
+ def get_end_user_id(
+ self, token: dict, default_value: Optional[str]
+ ) -> Optional[str]:
+        try:
+ if self.litellm_jwtauth.end_user_id_jwt_field is not None:
+ user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
+ else:
+ user_id = None
+ except KeyError:
+ user_id = default_value
+
+ return user_id
+
+ def is_required_team_id(self) -> bool:
+ """
+ Returns:
+ - True: if 'team_id_jwt_field' is set
+ - False: if not
+ """
+ if self.litellm_jwtauth.team_id_jwt_field is None:
+ return False
+ return True
+
+ def is_enforced_email_domain(self) -> bool:
+ """
+ Returns:
+ - True: if 'user_allowed_email_domain' is set
+ - False: if 'user_allowed_email_domain' is None
+ """
+
+ if self.litellm_jwtauth.user_allowed_email_domain is not None and isinstance(
+ self.litellm_jwtauth.user_allowed_email_domain, str
+ ):
+ return True
+ return False
+
+ def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
+ try:
+ if self.litellm_jwtauth.team_id_jwt_field is not None:
+ team_id = token[self.litellm_jwtauth.team_id_jwt_field]
+ elif self.litellm_jwtauth.team_id_default is not None:
+ team_id = self.litellm_jwtauth.team_id_default
+ else:
+ team_id = None
+ except KeyError:
+ team_id = default_value
+ return team_id
+
+ def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool:
+ """
+ Returns:
+ - True: if 'user_id_upsert' is set AND valid_user_email is not False
+ - False: if not
+ """
+ if valid_user_email is False:
+ return False
+ return self.litellm_jwtauth.user_id_upsert
+
+ def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
+ try:
+ if self.litellm_jwtauth.user_id_jwt_field is not None:
+ user_id = token[self.litellm_jwtauth.user_id_jwt_field]
+ else:
+ user_id = default_value
+ except KeyError:
+ user_id = default_value
+ return user_id
+
+ def get_user_roles(
+ self, token: dict, default_value: Optional[List[str]]
+ ) -> Optional[List[str]]:
+ """
+ Returns the user role from the token.
+
+ Set via 'user_roles_jwt_field' in the config.
+ """
+ try:
+ if self.litellm_jwtauth.user_roles_jwt_field is not None:
+ user_roles = get_nested_value(
+ data=token,
+ key_path=self.litellm_jwtauth.user_roles_jwt_field,
+ default=default_value,
+ )
+ else:
+ user_roles = default_value
+ except KeyError:
+ user_roles = default_value
+ return user_roles
+
+ def get_jwt_role(
+ self, token: dict, default_value: Optional[List[str]]
+ ) -> Optional[List[str]]:
+ """
+ Generic implementation of `get_user_roles` that can be used for both user and team roles.
+
+ Returns the jwt role from the token.
+
+ Set via 'roles_jwt_field' in the config.
+ """
+ try:
+ if self.litellm_jwtauth.roles_jwt_field is not None:
+ user_roles = get_nested_value(
+ data=token,
+ key_path=self.litellm_jwtauth.roles_jwt_field,
+ default=default_value,
+ )
+ else:
+ user_roles = default_value
+ except KeyError:
+ user_roles = default_value
+ return user_roles
+
+ def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
+ """
+ Returns the user role from the token.
+
+ Set via 'user_allowed_roles' in the config.
+ """
+ if (
+ user_roles is not None
+ and self.litellm_jwtauth.user_allowed_roles is not None
+ and any(
+ role in self.litellm_jwtauth.user_allowed_roles for role in user_roles
+ )
+ ):
+ return True
+ return False
+
+ def get_user_email(
+ self, token: dict, default_value: Optional[str]
+ ) -> Optional[str]:
+ try:
+ if self.litellm_jwtauth.user_email_jwt_field is not None:
+ user_email = token[self.litellm_jwtauth.user_email_jwt_field]
+ else:
+ user_email = None
+ except KeyError:
+ user_email = default_value
+ return user_email
+
+ def get_object_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
+ try:
+ if self.litellm_jwtauth.object_id_jwt_field is not None:
+ object_id = token[self.litellm_jwtauth.object_id_jwt_field]
+ else:
+ object_id = default_value
+ except KeyError:
+ object_id = default_value
+ return object_id
+
+ def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
+ try:
+ if self.litellm_jwtauth.org_id_jwt_field is not None:
+ org_id = token[self.litellm_jwtauth.org_id_jwt_field]
+ else:
+ org_id = None
+ except KeyError:
+ org_id = default_value
+ return org_id
+
+ def get_scopes(self, token: dict) -> List[str]:
+ try:
+ if isinstance(token["scope"], str):
+ # Assuming the scopes are stored in 'scope' claim and are space-separated
+ scopes = token["scope"].split()
+ elif isinstance(token["scope"], list):
+ scopes = token["scope"]
+ else:
+ raise Exception(
+ f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str."
+ )
+ except KeyError:
+ scopes = []
+ return scopes
+
+    async def get_public_key(self, kid: Optional[str]) -> dict:
+ keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
+
+ if keys_url is None:
+ raise Exception("Missing JWT Public Key URL from environment.")
+
+ keys_url_list = [url.strip() for url in keys_url.split(",")]
+
+        for key_url in keys_url_list:
+ cache_key = f"litellm_jwt_auth_keys_{key_url}"
+
+ cached_keys = await self.user_api_key_cache.async_get_cache(cache_key)
+
+ if cached_keys is None:
+ response = await self.http_handler.get(key_url)
+
+ response_json = response.json()
+ if "keys" in response_json:
+                    keys: JWKKeyValue = response_json["keys"]
+ else:
+ keys = response_json
+
+ await self.user_api_key_cache.async_set_cache(
+ key=cache_key,
+ value=keys,
+                    ttl=self.litellm_jwtauth.public_key_ttl,  # defaults to 10 mins
+ )
+ else:
+ keys = cached_keys
+
+ public_key = self.parse_keys(keys=keys, kid=kid)
+ if public_key is not None:
+ return cast(dict, public_key)
+
+ raise Exception(
+ f"No matching public key found. keys={keys_url_list}, kid={kid}"
+ )
+
+ def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]:
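+        """
+        Select the JWK matching `kid` from a parsed JWKS key list, e.g.
+        [{"kty": "RSA", "kid": "abc", "n": "...", "e": "AQAB"}] (illustrative).
+
+        If only one key is present and no kid is given, that key is returned.
+        """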
+ public_key: Optional[JWTKeyItem] = None
+ if len(keys) == 1:
+ if isinstance(keys, dict) and (keys.get("kid", None) == kid or kid is None):
+ public_key = keys
+ elif isinstance(keys, list) and (
+ keys[0].get("kid", None) == kid or kid is None
+ ):
+ public_key = keys[0]
+ elif len(keys) > 1:
+ for key in keys:
+ if isinstance(key, dict):
+ key_kid = key.get("kid", None)
+ else:
+ key_kid = None
+ if (
+ kid is not None
+ and isinstance(key, dict)
+ and key_kid is not None
+ and key_kid == kid
+ ):
+ public_key = key
+
+ return public_key
+
+ def is_allowed_domain(self, user_email: str) -> bool:
+ if self.litellm_jwtauth.user_allowed_email_domain is None:
+ return True
+
+ email_domain = user_email.split("@")[-1] # Extract domain from email
+ if email_domain == self.litellm_jwtauth.user_allowed_email_domain:
+ return True
+ else:
+ return False
+
+ async def auth_jwt(self, token: str) -> dict:
+ # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
+ # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret
+ # the key in different ways (e.g. HS* and RS*)."
+ algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]
+
+ audience = os.getenv("JWT_AUDIENCE")
+ decode_options = None
+ if audience is None:
+ decode_options = {"verify_aud": False}
+
+ import jwt
+ from jwt.algorithms import RSAAlgorithm
+
+ header = jwt.get_unverified_header(token)
+
+ verbose_proxy_logger.debug("header: %s", header)
+
+ kid = header.get("kid", None)
+
+ public_key = await self.get_public_key(kid=kid)
+
+ if public_key is not None and isinstance(public_key, dict):
+ jwk = {}
+ if "kty" in public_key:
+ jwk["kty"] = public_key["kty"]
+ if "kid" in public_key:
+ jwk["kid"] = public_key["kid"]
+ if "n" in public_key:
+ jwk["n"] = public_key["n"]
+ if "e" in public_key:
+ jwk["e"] = public_key["e"]
+
+ public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk))
+
+ try:
+ # decode the token using the public key
+ payload = jwt.decode(
+ token,
+ public_key_rsa, # type: ignore
+ algorithms=algorithms,
+ options=decode_options,
+ audience=audience,
+ leeway=self.leeway, # allow testing of expired tokens
+ )
+ return payload
+
+ except jwt.ExpiredSignatureError:
+ # the token is expired, do something to refresh it
+ raise Exception("Token Expired")
+ except Exception as e:
+ raise Exception(f"Validation fails: {str(e)}")
+ elif public_key is not None and isinstance(public_key, str):
+ try:
+ cert = x509.load_pem_x509_certificate(
+ public_key.encode(), default_backend()
+ )
+
+ # Extract public key
+ key = cert.public_key().public_bytes(
+ serialization.Encoding.PEM,
+ serialization.PublicFormat.SubjectPublicKeyInfo,
+ )
+
+ # decode the token using the public key
+ payload = jwt.decode(
+ token,
+ key,
+ algorithms=algorithms,
+ audience=audience,
+ options=decode_options,
+ )
+ return payload
+
+ except jwt.ExpiredSignatureError:
+ # the token is expired, do something to refresh it
+ raise Exception("Token Expired")
+ except Exception as e:
+ raise Exception(f"Validation fails: {str(e)}")
+
+ raise Exception("Invalid JWT Submitted")
+
+ async def close(self):
+ await self.http_handler.close()
+
+
+class JWTAuthManager:
+ """Manages JWT authentication and authorization operations"""
+
+ @staticmethod
+ def can_rbac_role_call_route(
+ rbac_role: RBAC_ROLES,
+ general_settings: dict,
+ route: str,
+ ) -> Literal[True]:
+ """
+ Checks if user is allowed to access the route, based on their role.
+ """
+ role_based_routes = get_role_based_routes(
+ rbac_role=rbac_role, general_settings=general_settings
+ )
+
+ if role_based_routes is None or route is None:
+ return True
+
+ is_allowed = _allowed_routes_check(
+ user_route=route,
+ allowed_routes=role_based_routes,
+ )
+
+ if not is_allowed:
+ raise HTTPException(
+ status_code=403,
+ detail=f"Role={rbac_role} not allowed to call route={route}. Allowed routes={role_based_routes}",
+ )
+
+ return True
+
+ @staticmethod
+ def can_rbac_role_call_model(
+ rbac_role: RBAC_ROLES,
+ general_settings: dict,
+ model: Optional[str],
+ ) -> Literal[True]:
+ """
+ Checks if user is allowed to access the model, based on their role.
+ """
+ role_based_models = get_role_based_models(
+ rbac_role=rbac_role, general_settings=general_settings
+ )
+ if role_based_models is None or model is None:
+ return True
+
+ if model not in role_based_models:
+ raise HTTPException(
+ status_code=403,
+ detail=f"Role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
+ )
+
+ return True
+
+ @staticmethod
+ def check_scope_based_access(
+ scope_mappings: List[ScopeMapping],
+ scopes: List[str],
+ request_data: dict,
+ general_settings: dict,
+ ) -> None:
+ """
+ Check if scope allows access to the requested model
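+
+        Example (illustrative): ScopeMapping(scope="litellm.api.consumer",
+        models=["gpt-3.5-turbo"]) grants tokens carrying that scope access
+        to "gpt-3.5-turbo".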
+ """
+ if not scope_mappings:
+ return None
+
+ allowed_models = []
+ for sm in scope_mappings:
+ if sm.scope in scopes and sm.models:
+ allowed_models.extend(sm.models)
+
+ requested_model = request_data.get("model")
+
+ if not requested_model:
+ return None
+
+ if requested_model not in allowed_models:
+ raise HTTPException(
+ status_code=403,
+ detail={
+ "error": "model={} not allowed. Allowed_models={}".format(
+ requested_model, allowed_models
+ )
+ },
+ )
+ return None
+
+ @staticmethod
+ async def check_rbac_role(
+ jwt_handler: JWTHandler,
+ jwt_valid_token: dict,
+ general_settings: dict,
+ request_data: dict,
+ route: str,
+ rbac_role: Optional[RBAC_ROLES],
+ ) -> None:
+ """Validate RBAC role and model access permissions"""
+ if jwt_handler.litellm_jwtauth.enforce_rbac is True:
+ if rbac_role is None:
+ raise HTTPException(
+ status_code=403,
+ detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
+ )
+ JWTAuthManager.can_rbac_role_call_model(
+ rbac_role=rbac_role,
+ general_settings=general_settings,
+ model=request_data.get("model"),
+ )
+ JWTAuthManager.can_rbac_role_call_route(
+ rbac_role=rbac_role,
+ general_settings=general_settings,
+ route=route,
+ )
+
+ @staticmethod
+ async def check_admin_access(
+ jwt_handler: JWTHandler,
+ scopes: list,
+ route: str,
+ user_id: Optional[str],
+ org_id: Optional[str],
+ api_key: str,
+ ) -> Optional[JWTAuthBuilderResult]:
+ """Check admin status and route access permissions"""
+ if not jwt_handler.is_admin(scopes=scopes):
+ return None
+
+ is_allowed = allowed_routes_check(
+ user_role=LitellmUserRoles.PROXY_ADMIN,
+ user_route=route,
+ litellm_proxy_roles=jwt_handler.litellm_jwtauth,
+ )
+ if not is_allowed:
+ allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
+ actual_routes = get_actual_routes(allowed_routes=allowed_routes)
+ raise Exception(
+ f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
+ )
+
+ return JWTAuthBuilderResult(
+ is_proxy_admin=True,
+ team_object=None,
+ user_object=None,
+ end_user_object=None,
+ org_object=None,
+ token=api_key,
+ team_id=None,
+ user_id=user_id,
+ end_user_id=None,
+ org_id=org_id,
+ )
+
+ @staticmethod
+ async def find_and_validate_specific_team_id(
+ jwt_handler: JWTHandler,
+ jwt_valid_token: dict,
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span],
+ proxy_logging_obj: ProxyLogging,
+ ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
+ """Find and validate specific team ID"""
+ individual_team_id = jwt_handler.get_team_id(
+ token=jwt_valid_token, default_value=None
+ )
+
+ if not individual_team_id and jwt_handler.is_required_team_id() is True:
+ raise Exception(
+ f"No team id found in token. Checked team_id field '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
+ )
+
+ ## VALIDATE TEAM OBJECT ###
+ team_object: Optional[LiteLLM_TeamTable] = None
+ if individual_team_id:
+ team_object = await get_team_object(
+ team_id=individual_team_id,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ parent_otel_span=parent_otel_span,
+ proxy_logging_obj=proxy_logging_obj,
+ team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert,
+ )
+
+ return individual_team_id, team_object
+
+ @staticmethod
+ def get_all_team_ids(jwt_handler: JWTHandler, jwt_valid_token: dict) -> Set[str]:
+ """Get combined team IDs from groups and individual team_id"""
+ team_ids_from_groups = jwt_handler.get_team_ids_from_jwt(token=jwt_valid_token)
+
+ all_team_ids = set(team_ids_from_groups)
+
+ return all_team_ids
+
+ @staticmethod
+ async def find_team_with_model_access(
+ team_ids: Set[str],
+ requested_model: Optional[str],
+ route: str,
+ jwt_handler: JWTHandler,
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span],
+ proxy_logging_obj: ProxyLogging,
+ ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
+ """Find first team with access to the requested model"""
+
+ if not team_ids:
+ if jwt_handler.litellm_jwtauth.enforce_team_based_model_access:
+ raise HTTPException(
+ status_code=403,
+ detail="No teams found in token. `enforce_team_based_model_access` is set to True. Token must belong to a team.",
+ )
+ return None, None
+
+ for team_id in team_ids:
+ try:
+ team_object = await get_team_object(
+ team_id=team_id,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ parent_otel_span=parent_otel_span,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+ if team_object and team_object.models is not None:
+ team_models = team_object.models
+ if isinstance(team_models, list) and (
+ not requested_model
+ or can_team_access_model(
+ model=requested_model,
+ team_object=team_object,
+ llm_router=None,
+ team_model_aliases=None,
+ )
+ ):
+ is_allowed = allowed_routes_check(
+ user_role=LitellmUserRoles.TEAM,
+ user_route=route,
+ litellm_proxy_roles=jwt_handler.litellm_jwtauth,
+ )
+ if is_allowed:
+ return team_id, team_object
+ except Exception:
+ continue
+
+ if requested_model:
+ raise HTTPException(
+ status_code=403,
+ detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}. Check `/models` to see all available models.",
+ )
+
+ return None, None
+
+ @staticmethod
+ async def get_user_info(
+ jwt_handler: JWTHandler,
+ jwt_valid_token: dict,
+ ) -> Tuple[Optional[str], Optional[str], Optional[bool]]:
+ """Get user email and validation status"""
+ user_email = jwt_handler.get_user_email(
+ token=jwt_valid_token, default_value=None
+ )
+ valid_user_email = None
+ if jwt_handler.is_enforced_email_domain():
+ valid_user_email = (
+ False
+ if user_email is None
+ else jwt_handler.is_allowed_domain(user_email=user_email)
+ )
+ user_id = jwt_handler.get_user_id(
+ token=jwt_valid_token, default_value=user_email
+ )
+ return user_id, user_email, valid_user_email
+
+ @staticmethod
+ async def get_objects(
+ user_id: Optional[str],
+ user_email: Optional[str],
+ org_id: Optional[str],
+ end_user_id: Optional[str],
+ valid_user_email: Optional[bool],
+ jwt_handler: JWTHandler,
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span],
+ proxy_logging_obj: ProxyLogging,
+ ) -> Tuple[
+ Optional[LiteLLM_UserTable],
+ Optional[LiteLLM_OrganizationTable],
+ Optional[LiteLLM_EndUserTable],
+ ]:
+ """Get user, org, and end user objects"""
+        org_object: Optional[LiteLLM_OrganizationTable] = None
+        if org_id:
+            org_object = await get_org_object(
+                org_id=org_id,
+                prisma_client=prisma_client,
+                user_api_key_cache=user_api_key_cache,
+                parent_otel_span=parent_otel_span,
+                proxy_logging_obj=proxy_logging_obj,
+            )
+
+        user_object: Optional[LiteLLM_UserTable] = None
+        if user_id:
+            user_object = await get_user_object(
+                user_id=user_id,
+                prisma_client=prisma_client,
+                user_api_key_cache=user_api_key_cache,
+                user_id_upsert=jwt_handler.is_upsert_user_id(
+                    valid_user_email=valid_user_email
+                ),
+                parent_otel_span=parent_otel_span,
+                proxy_logging_obj=proxy_logging_obj,
+                user_email=user_email,
+                sso_user_id=user_id,
+            )
+
+        end_user_object: Optional[LiteLLM_EndUserTable] = None
+        if end_user_id:
+            end_user_object = await get_end_user_object(
+                end_user_id=end_user_id,
+                prisma_client=prisma_client,
+                user_api_key_cache=user_api_key_cache,
+                parent_otel_span=parent_otel_span,
+                proxy_logging_obj=proxy_logging_obj,
+            )
+
+        return user_object, org_object, end_user_object
+
+ @staticmethod
+ def validate_object_id(
+ user_id: Optional[str],
+ team_id: Optional[str],
+ enforce_rbac: bool,
+ is_proxy_admin: bool,
+ ) -> Literal[True]:
+ """If enforce_rbac is true, validate that a valid rbac id is returned for spend tracking"""
+ if enforce_rbac and not is_proxy_admin and not user_id and not team_id:
+ raise HTTPException(
+ status_code=403,
+ detail="No user or team id found in token. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
+ )
+ return True
+
+ @staticmethod
+ async def auth_builder(
+ api_key: str,
+ jwt_handler: JWTHandler,
+ request_data: dict,
+ general_settings: dict,
+ route: str,
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span],
+ proxy_logging_obj: ProxyLogging,
+ ) -> JWTAuthBuilderResult:
+ """Main authentication and authorization builder"""
+ jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
+
+ # Check custom validate
+ if jwt_handler.litellm_jwtauth.custom_validate:
+ if not jwt_handler.litellm_jwtauth.custom_validate(jwt_valid_token):
+ raise HTTPException(
+ status_code=403,
+ detail="Invalid JWT token",
+ )
+
+ # Check RBAC
+ rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
+ await JWTAuthManager.check_rbac_role(
+ jwt_handler,
+ jwt_valid_token,
+ general_settings,
+ request_data,
+ route,
+ rbac_role,
+ )
+
+ # Check Scope Based Access
+ scopes = jwt_handler.get_scopes(token=jwt_valid_token)
+ if (
+ jwt_handler.litellm_jwtauth.enforce_scope_based_access
+ and jwt_handler.litellm_jwtauth.scope_mappings
+ ):
+ JWTAuthManager.check_scope_based_access(
+ scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings,
+ scopes=scopes,
+ request_data=request_data,
+ general_settings=general_settings,
+ )
+
+        # Get basic user info
+ user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
+ jwt_handler, jwt_valid_token
+ )
+
+ # Get IDs
+ org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
+ end_user_id = jwt_handler.get_end_user_id(
+ token=jwt_valid_token, default_value=None
+ )
+ team_id: Optional[str] = None
+ team_object: Optional[LiteLLM_TeamTable] = None
+ object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
+
+        if rbac_role and object_id:
+ if rbac_role == LitellmUserRoles.TEAM:
+ team_id = object_id
+ elif rbac_role == LitellmUserRoles.INTERNAL_USER:
+ user_id = object_id
+
+ # Check admin access
+ admin_result = await JWTAuthManager.check_admin_access(
+ jwt_handler, scopes, route, user_id, org_id, api_key
+ )
+ if admin_result:
+ return admin_result
+
+ # Get team with model access
+ ## SPECIFIC TEAM ID
+
+ if not team_id:
+ team_id, team_object = (
+ await JWTAuthManager.find_and_validate_specific_team_id(
+ jwt_handler,
+ jwt_valid_token,
+ prisma_client,
+ user_api_key_cache,
+ parent_otel_span,
+ proxy_logging_obj,
+ )
+ )
+
+ if not team_object and not team_id:
+ ## CHECK USER GROUP ACCESS
+ all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token)
+ team_id, team_object = await JWTAuthManager.find_team_with_model_access(
+ team_ids=all_team_ids,
+ requested_model=request_data.get("model"),
+ route=route,
+ jwt_handler=jwt_handler,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ parent_otel_span=parent_otel_span,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+ # Get other objects
+ user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
+ user_id=user_id,
+ user_email=user_email,
+ org_id=org_id,
+ end_user_id=end_user_id,
+ valid_user_email=valid_user_email,
+ jwt_handler=jwt_handler,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ parent_otel_span=parent_otel_span,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+ # Validate that a valid rbac id is returned for spend tracking
+ JWTAuthManager.validate_object_id(
+ user_id=user_id,
+ team_id=team_id,
+ enforce_rbac=general_settings.get("enforce_rbac", False),
+ is_proxy_admin=False,
+ )
+
+ return JWTAuthBuilderResult(
+ is_proxy_admin=False,
+ team_id=team_id,
+ team_object=team_object,
+ user_id=user_id,
+ user_object=user_object,
+ org_id=org_id,
+ org_object=org_object,
+ end_user_id=end_user_id,
+ end_user_object=end_user_object,
+ token=api_key,
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/litellm_license.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/litellm_license.py
new file mode 100644
index 00000000..67ec91f5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/litellm_license.py
@@ -0,0 +1,169 @@
+# What is this?
+## If litellm license in env, checks if it's valid
+import base64
+import json
+import os
+from datetime import datetime
+from typing import Optional
+
+import httpx
+
+from litellm._logging import verbose_proxy_logger
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+
+class LicenseCheck:
+ """
+ - Check if license in env
+ - Returns if license is valid
+ """
+
+ base_url = "https://license.litellm.ai"
+
+ def __init__(self) -> None:
+ self.license_str = os.getenv("LITELLM_LICENSE", None)
+ verbose_proxy_logger.debug("License Str value - {}".format(self.license_str))
+ self.http_handler = HTTPHandler(timeout=15)
+ self.public_key = None
+ self.read_public_key()
+
+ def read_public_key(self):
+ try:
+ from cryptography.hazmat.primitives import serialization
+
+ # current dir
+ current_dir = os.path.dirname(os.path.realpath(__file__))
+
+ # check if public_key.pem exists
+ _path_to_public_key = os.path.join(current_dir, "public_key.pem")
+ if os.path.exists(_path_to_public_key):
+ with open(_path_to_public_key, "rb") as key_file:
+ self.public_key = serialization.load_pem_public_key(key_file.read())
+ else:
+ self.public_key = None
+ except Exception as e:
+ verbose_proxy_logger.error(f"Error reading public key: {str(e)}")
+
+    def _verify(self, license_str: str) -> bool:
+ verbose_proxy_logger.debug(
+ "litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format(
+ self.base_url, license_str
+ )
+ )
+ url = "{}/verify_license/{}".format(self.base_url, license_str)
+
+ response: Optional[httpx.Response] = None
+ try: # don't impact user, if call fails
+            num_retries = 3
+            for i in range(num_retries):
+                try:
+                    response = self.http_handler.get(url=url)
+                    if response is None:
+                        raise Exception("No response from license server")
+                    response.raise_for_status()
+                    break  # success - stop retrying
+                except httpx.HTTPStatusError:
+                    if i == num_retries - 1:
+                        raise
+
+ if response is None:
+ raise Exception("No response from license server")
+
+ response_json = response.json()
+
+ premium = response_json["verify"]
+
+ assert isinstance(premium, bool)
+
+ verbose_proxy_logger.debug(
+ "litellm.proxy.auth.litellm_license.py::_verify - License={} is premium={}".format(
+ license_str, premium
+ )
+ )
+ return premium
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License={} via api. - {}".format(
+ license_str, str(e)
+ )
+ )
+ return False
+
+ def is_premium(self) -> bool:
+ """
+        1. verify_license_without_api_request: checks if the license was generated using the private / public key pair
+        2. _verify: checks if the license is valid by calling the LiteLLM API. This is the old way licenses were generated/validated
+ """
+ try:
+ verbose_proxy_logger.debug(
+ "litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - LiteLLM License={}".format(
+ self.license_str
+ )
+ )
+
+ if self.license_str is None:
+ self.license_str = os.getenv("LITELLM_LICENSE", None)
+
+ verbose_proxy_logger.debug(
+ "litellm.proxy.auth.litellm_license.py::is_premium() - Updated 'self.license_str' - {}".format(
+ self.license_str
+ )
+ )
+
+ if self.license_str is None:
+ return False
+ elif (
+ self.verify_license_without_api_request(
+ public_key=self.public_key, license_key=self.license_str
+ )
+ is True
+ ):
+ return True
+ elif self._verify(license_str=self.license_str) is True:
+ return True
+ return False
+ except Exception:
+ return False
+
+ def verify_license_without_api_request(self, public_key, license_key):
+ try:
+ from cryptography.hazmat.primitives import hashes
+ from cryptography.hazmat.primitives.asymmetric import padding
+
+ # Decode the license key
+ decoded = base64.b64decode(license_key)
+ message, signature = decoded.split(b".", 1)
+
+ # Verify the signature
+ public_key.verify(
+ signature,
+ message,
+ padding.PSS(
+ mgf=padding.MGF1(hashes.SHA256()),
+ salt_length=padding.PSS.MAX_LENGTH,
+ ),
+ hashes.SHA256(),
+ )
+
+ # Decode and parse the data
+ license_data = json.loads(message.decode())
+
+ # debug information provided in license data
+ verbose_proxy_logger.debug("License data: %s", license_data)
+
+ # Check expiration date
+ expiration_date = datetime.strptime(
+ license_data["expiration_date"], "%Y-%m-%d"
+ )
+            if expiration_date < datetime.now():
+                verbose_proxy_logger.debug("License has expired")
+                return False
+
+ return True
+
+ except Exception as e:
+ verbose_proxy_logger.debug(
+ "litellm.proxy.auth.litellm_license.py::verify_license_without_api_request - Unable to verify License locally. - {}".format(
+ str(e)
+ )
+ )
+ return False
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py
new file mode 100644
index 00000000..a48ef6ae
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py
@@ -0,0 +1,197 @@
+# What is this?
+## Common checks for /v1/models and `/model/info`
+import copy
+from typing import Dict, List, Optional, Set
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
+from litellm.utils import get_valid_models
+
+
+def _check_wildcard_routing(model: str) -> bool:
+ """
+ Returns True if a model is a provider wildcard.
+
+ eg:
+ - anthropic/*
+ - openai/*
+ - *
+ """
+ if "*" in model:
+ return True
+ return False
+
+
+def get_provider_models(provider: str) -> Optional[List[str]]:
+ """
+ Returns the list of known models by provider
+ """
+ if provider == "*":
+ return get_valid_models()
+
+ if provider in litellm.models_by_provider:
+ provider_models = copy.deepcopy(litellm.models_by_provider[provider])
+ for idx, _model in enumerate(provider_models):
+ if provider not in _model:
+ provider_models[idx] = f"{provider}/{_model}"
+ return provider_models
+ return None
+
+
+def _get_models_from_access_groups(
+ model_access_groups: Dict[str, List[str]],
+ all_models: List[str],
+) -> List[str]:
+ idx_to_remove = []
+ new_models = []
+ for idx, model in enumerate(all_models):
+ if model in model_access_groups:
+ idx_to_remove.append(idx)
+ new_models.extend(model_access_groups[model])
+
+ for idx in sorted(idx_to_remove, reverse=True):
+ all_models.pop(idx)
+
+ all_models.extend(new_models)
+ return all_models
+
+
+def get_key_models(
+ user_api_key_dict: UserAPIKeyAuth,
+ proxy_model_list: List[str],
+ model_access_groups: Dict[str, List[str]],
+) -> List[str]:
+ """
+ Returns:
+ - List of model name strings
+ - Empty list if no models set
+    - Model access group names in the list are expanded into their member models
+ """
+ all_models: List[str] = []
+ if len(user_api_key_dict.models) > 0:
+ all_models = user_api_key_dict.models
+ if SpecialModelNames.all_team_models.value in all_models:
+ all_models = user_api_key_dict.team_models
+ if SpecialModelNames.all_proxy_models.value in all_models:
+ all_models = proxy_model_list
+
+ all_models = _get_models_from_access_groups(
+ model_access_groups=model_access_groups, all_models=all_models
+ )
+
+ verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
+ return all_models
+
+
+def get_team_models(
+ team_models: List[str],
+ proxy_model_list: List[str],
+ model_access_groups: Dict[str, List[str]],
+) -> List[str]:
+ """
+ Returns:
+ - List of model name strings
+ - Empty list if no models set
+    - Model access group names in the list are expanded into their member models
+ """
+ all_models = []
+ if len(team_models) > 0:
+ all_models = team_models
+ if SpecialModelNames.all_team_models.value in all_models:
+ all_models = team_models
+ if SpecialModelNames.all_proxy_models.value in all_models:
+ all_models = proxy_model_list
+
+ all_models = _get_models_from_access_groups(
+ model_access_groups=model_access_groups, all_models=all_models
+ )
+
+ verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
+ return all_models
+
+
+def get_complete_model_list(
+ key_models: List[str],
+ team_models: List[str],
+ proxy_model_list: List[str],
+ user_model: Optional[str],
+ infer_model_from_keys: Optional[bool],
+ return_wildcard_routes: Optional[bool] = False,
+) -> List[str]:
+ """Logic for returning complete model list for a given key + team pair"""
+
+ """
+ - If key list is empty -> defer to team list
+ - If team list is empty -> defer to proxy model list
+
+ If list contains wildcard -> return known provider models
+ """
+ unique_models: Set[str] = set()
+ if key_models:
+ unique_models.update(key_models)
+ elif team_models:
+ unique_models.update(team_models)
+ else:
+ unique_models.update(proxy_model_list)
+
+ if user_model:
+ unique_models.add(user_model)
+
+ if infer_model_from_keys:
+ valid_models = get_valid_models()
+ unique_models.update(valid_models)
+
+ all_wildcard_models = _get_wildcard_models(
+ unique_models=unique_models, return_wildcard_routes=return_wildcard_routes
+ )
+
+ return list(unique_models) + all_wildcard_models
+
+
+def get_known_models_from_wildcard(wildcard_model: str) -> List[str]:
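+    """
+    Expand a wildcard model entry into known provider models, e.g. (illustrative):
+    - "anthropic/*" -> all known anthropic models
+    - "openai/gpt-4*" -> known openai models whose names start with "gpt-4"
+    """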
+ try:
+ provider, model = wildcard_model.split("/", 1)
+ except ValueError: # safely fail
+ return []
+ # get all known provider models
+ wildcard_models = get_provider_models(provider=provider)
+ if wildcard_models is None:
+ return []
+ if model == "*":
+ return wildcard_models or []
+ else:
+ model_prefix = model.replace("*", "")
+ filtered_wildcard_models = [
+ wc_model
+ for wc_model in wildcard_models
+ if wc_model.split("/")[1].startswith(model_prefix)
+ ]
+
+ return filtered_wildcard_models
+
+
+def _get_wildcard_models(
+ unique_models: Set[str], return_wildcard_routes: Optional[bool] = False
+) -> List[str]:
+ models_to_remove = set()
+ all_wildcard_models = []
+ for model in unique_models:
+        if _check_wildcard_routing(model=model):
+            if return_wildcard_routes:
+                # add the wildcard route itself to the list, e.g. anthropic/*
+                all_wildcard_models.append(model)
+
+ # get all known provider models
+ wildcard_models = get_known_models_from_wildcard(wildcard_model=model)
+
+ if wildcard_models is not None:
+ models_to_remove.add(model)
+ all_wildcard_models.extend(wildcard_models)
+
+ for model in models_to_remove:
+ unique_models.remove(model)
+
+ return all_wildcard_models
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_check.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_check.py
new file mode 100644
index 00000000..4851c270
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_check.py
@@ -0,0 +1,80 @@
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
+ """
+ Makes a request to the token info endpoint to validate the OAuth2 token.
+
+ Args:
+ token (str): The OAuth2 token to validate.
+
+ Returns:
+        UserAPIKeyAuth: The authenticated user object if the token is valid.
+
+ Raises:
+ ValueError: If the token is invalid, the request fails, or the token info endpoint is not set.
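+
+    Environment variables read:
+    - OAUTH_TOKEN_INFO_ENDPOINT: required - URL of the token info endpoint
+    - OAUTH_USER_ID_FIELD_NAME: claim holding the user id (default "sub")
+    - OAUTH_USER_ROLE_FIELD_NAME: claim holding the user role (default "role")
+    - OAUTH_USER_TEAM_ID_FIELD_NAME: claim holding the team id (default "team_id")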
+ """
+ import os
+
+ import httpx
+
+ from litellm._logging import verbose_proxy_logger
+ from litellm.llms.custom_httpx.http_handler import (
+ get_async_httpx_client,
+ httpxSpecialProvider,
+ )
+ from litellm.proxy._types import CommonProxyErrors
+ from litellm.proxy.proxy_server import premium_user
+
+ if premium_user is not True:
+ raise ValueError(
+ "Oauth2 token validation is only available for premium users"
+ + CommonProxyErrors.not_premium_user.value
+ )
+
+ verbose_proxy_logger.debug("Oauth2 token validation for token=%s", token)
+ # Get the token info endpoint from environment variable
+ token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
+ user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
+ user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
+ user_team_id_field_name = os.environ.get("OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id")
+
+ if not token_info_endpoint:
+ raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")
+
+ client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
+ headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
+
+ try:
+ response = await client.get(token_info_endpoint, headers=headers)
+
+ # if it's a bad token we expect it to raise an HTTPStatusError
+ response.raise_for_status()
+
+ # If we get here, the request was successful
+ data = response.json()
+
+ verbose_proxy_logger.debug(
+ "Oauth2 token validation for token=%s, response from /token/info=%s",
+ token,
+ data,
+ )
+
+ # You might want to add additional checks here based on the response
+ # For example, checking if the token is expired or has the correct scope
+ user_id = data.get(user_id_field_name)
+ user_team_id = data.get(user_team_id_field_name)
+ user_role = data.get(user_role_field_name)
+
+ return UserAPIKeyAuth(
+ api_key=token,
+ team_id=user_team_id,
+ user_id=user_id,
+ user_role=user_role,
+ )
+ except httpx.HTTPStatusError as e:
+ # This will catch any 4xx or 5xx errors
+ raise ValueError(f"Oauth 2.0 Token validation failed: {e}")
+ except Exception as e:
+ # This will catch any other errors (like network issues)
+ raise ValueError(f"An error occurred during token validation: {e}")
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_proxy_hook.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_proxy_hook.py
new file mode 100644
index 00000000..a1db5d84
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_proxy_hook.py
@@ -0,0 +1,45 @@
+from typing import Any, Dict, Optional
+
+from fastapi import Request
+
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+async def handle_oauth2_proxy_request(request: Request) -> UserAPIKeyAuth:
+ """
+ Handle request from oauth2 proxy.
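+
+    Expects `oauth2_config_mappings` in general_settings, mapping UserAPIKeyAuth
+    field names to request header names set by the oauth2 proxy (header names
+    illustrative):
+
+    ```yaml
+    general_settings:
+      oauth2_config_mappings:
+        user_id: X-Auth-Request-User
+        max_budget: X-Auth-Request-Max-Budget
+        models: X-Auth-Request-Models
+    ```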
+ """
+ from litellm.proxy.proxy_server import general_settings
+
+ verbose_proxy_logger.debug("Handling oauth2 proxy request")
+ # Define the OAuth2 config mappings
+    oauth2_config_mappings: Optional[Dict[str, str]] = general_settings.get(
+        "oauth2_config_mappings", None
+    )
+ verbose_proxy_logger.debug(f"Oauth2 config mappings: {oauth2_config_mappings}")
+
+ if not oauth2_config_mappings:
+ raise ValueError("Oauth2 config mappings not found in general_settings")
+ # Initialize a dictionary to store the mapped values
+ auth_data: Dict[str, Any] = {}
+
+ # Extract values from headers based on the mappings
+ for key, header in oauth2_config_mappings.items():
+ value = request.headers.get(header)
+ if value:
+ # Convert max_budget to float if present
+ if key == "max_budget":
+ auth_data[key] = float(value)
+ # Convert models to list if present
+ elif key == "models":
+ auth_data[key] = [model.strip() for model in value.split(",")]
+ else:
+ auth_data[key] = value
+ verbose_proxy_logger.debug(
+ f"Auth data before creating UserAPIKeyAuth object: {auth_data}"
+ )
+    # Create and return UserAPIKeyAuth object
+    user_api_key_auth = UserAPIKeyAuth(**auth_data)
+    verbose_proxy_logger.debug(f"UserAPIKeyAuth object created: {user_api_key_auth}")
+    return user_api_key_auth
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/public_key.pem b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/public_key.pem
new file mode 100644
index 00000000..0962794a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/public_key.pem
@@ -0,0 +1,9 @@
+-----BEGIN PUBLIC KEY-----
+MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwcNBabWBZzrDhFAuA4Fh
+FhIcA3rF7vrLb8+1yhF2U62AghQp9nStyuJRjxMUuldWgJ1yRJ2s7UffVw5r8DeA
+dqXPD+w+3LCNwqJGaIKN08QGJXNArM3QtMaN0RTzAyQ4iibN1r6609W5muK9wGp0
+b1j5+iDUmf0ynItnhvaX6B8Xoaflc3WD/UBdrygLmsU5uR3XC86+/8ILoSZH3HtN
+6FJmWhlhjS2TR1cKZv8K5D0WuADTFf5MF8jYFR+uORPj5Pe/EJlLGN26Lfn2QnGu
+XgbPF6nCGwZ0hwH1Xkn3xzGaJ4xBEC761wqp5cHxWSDktHyFKnLbP3jVeegjVIHh
+pQIDAQAB
+-----END PUBLIC KEY----- \ No newline at end of file
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/rds_iam_token.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/rds_iam_token.py
new file mode 100644
index 00000000..053cdb91
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/rds_iam_token.py
@@ -0,0 +1,187 @@
+import os
+from typing import Any, Optional, Union
+
+import httpx
+
+
+def init_rds_client(
+ aws_access_key_id: Optional[str] = None,
+ aws_secret_access_key: Optional[str] = None,
+ aws_region_name: Optional[str] = None,
+ aws_session_name: Optional[str] = None,
+ aws_profile_name: Optional[str] = None,
+ aws_role_name: Optional[str] = None,
+ aws_web_identity_token: Optional[str] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+):
+ from litellm.secret_managers.main import get_secret
+
+    # check for custom AWS_REGION_NAME and use it if not passed to init_rds_client
+ litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
+ standard_aws_region_name = get_secret("AWS_REGION", None)
+    ## CHECK IF 'os.environ/' passed in
+ # Define the list of parameters to check
+ params_to_check = [
+ aws_access_key_id,
+ aws_secret_access_key,
+ aws_region_name,
+ aws_session_name,
+ aws_profile_name,
+ aws_role_name,
+ aws_web_identity_token,
+ ]
+
+ # Iterate over parameters and update if needed
+ for i, param in enumerate(params_to_check):
+ if param and param.startswith("os.environ/"):
+ params_to_check[i] = get_secret(param) # type: ignore
+ # Assign updated values back to parameters
+ (
+ aws_access_key_id,
+ aws_secret_access_key,
+ aws_region_name,
+ aws_session_name,
+ aws_profile_name,
+ aws_role_name,
+ aws_web_identity_token,
+ ) = params_to_check
+
+ ### SET REGION NAME
+    if aws_region_name:
+        region_name = aws_region_name
+ elif litellm_aws_region_name:
+ region_name = litellm_aws_region_name
+ elif standard_aws_region_name:
+ region_name = standard_aws_region_name
+ else:
+ raise Exception(
+ "AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
+ )
+
+ import boto3
+
+ if isinstance(timeout, float):
+ config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout) # type: ignore
+ elif isinstance(timeout, httpx.Timeout):
+ config = boto3.session.Config( # type: ignore
+ connect_timeout=timeout.connect, read_timeout=timeout.read
+ )
+ else:
+ config = boto3.session.Config() # type: ignore
+
+ ### CHECK STS ###
+ if (
+ aws_web_identity_token is not None
+ and aws_role_name is not None
+ and aws_session_name is not None
+ ):
+ try:
+            # check if it's a filepath to the web identity token
+            with open(aws_web_identity_token) as token_file:
+                oidc_token = token_file.read()
+ except Exception:
+ oidc_token = get_secret(aws_web_identity_token)
+
+ if oidc_token is None:
+ raise Exception(
+ "OIDC token could not be retrieved from secret manager.",
+ )
+
+ sts_client = boto3.client("sts")
+
+ # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
+ sts_response = sts_client.assume_role_with_web_identity(
+ RoleArn=aws_role_name,
+ RoleSessionName=aws_session_name,
+ WebIdentityToken=oidc_token,
+ DurationSeconds=3600,
+ )
+
+ client = boto3.client(
+ service_name="rds",
+ aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
+ aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
+ aws_session_token=sts_response["Credentials"]["SessionToken"],
+ region_name=region_name,
+ config=config,
+ )
+
+ elif aws_role_name is not None and aws_session_name is not None:
+ # use sts if role name passed in
+ sts_client = boto3.client(
+ "sts",
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ )
+
+ sts_response = sts_client.assume_role(
+ RoleArn=aws_role_name, RoleSessionName=aws_session_name
+ )
+
+ client = boto3.client(
+ service_name="rds",
+ aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
+ aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
+ aws_session_token=sts_response["Credentials"]["SessionToken"],
+ region_name=region_name,
+ config=config,
+ )
+ elif aws_access_key_id is not None:
+ # uses auth params passed to completion
+ # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
+
+ client = boto3.client(
+ service_name="rds",
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name=region_name,
+ config=config,
+ )
+ elif aws_profile_name is not None:
+ # uses auth values from AWS profile usually stored in ~/.aws/credentials
+
+ client = boto3.Session(profile_name=aws_profile_name).client(
+ service_name="rds",
+ region_name=region_name,
+ config=config,
+ )
+
+ else:
+ # aws_access_key_id is None, assume user is trying to auth using env variables
+ # boto3 automatically reads env variables
+
+ client = boto3.client(
+ service_name="rds",
+ region_name=region_name,
+ config=config,
+ )
+
+ return client
+
+
+def generate_iam_auth_token(
+ db_host, db_port, db_user, client: Optional[Any] = None
+) -> str:
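+    """
+    Generate a short-lived RDS IAM auth token via boto3's `generate_db_auth_token`
+    and URL-encode it so it can be used as the password in a database connection
+    string.
+    """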
+ from urllib.parse import quote
+
+ if client is None:
+ boto_client = init_rds_client(
+ aws_region_name=os.getenv("AWS_REGION_NAME"),
+ aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
+ aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
+ aws_session_name=os.getenv("AWS_SESSION_NAME"),
+ aws_profile_name=os.getenv("AWS_PROFILE_NAME"),
+ aws_role_name=os.getenv("AWS_ROLE_NAME", os.getenv("AWS_ROLE_ARN")),
+ aws_web_identity_token=os.getenv(
+ "AWS_WEB_IDENTITY_TOKEN", os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
+ ),
+ )
+ else:
+ boto_client = client
+
+ token = boto_client.generate_db_auth_token(
+ DBHostname=db_host, Port=db_port, DBUsername=db_user
+ )
+ cleaned_token = quote(token, safe="")
+
+ return cleaned_token
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/route_checks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/route_checks.py
new file mode 100644
index 00000000..a18a7ab5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/route_checks.py
@@ -0,0 +1,257 @@
+import re
+from typing import List, Optional
+
+from fastapi import HTTPException, Request, status
+
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import (
+ CommonProxyErrors,
+ LiteLLM_UserTable,
+ LiteLLMRoutes,
+ LitellmUserRoles,
+ UserAPIKeyAuth,
+)
+
+from .auth_checks_organization import _user_is_org_admin
+
+
+class RouteChecks:
+
+ @staticmethod
+ def non_proxy_admin_allowed_routes_check(
+ user_obj: Optional[LiteLLM_UserTable],
+ _user_role: Optional[LitellmUserRoles],
+ route: str,
+ request: Request,
+ valid_token: UserAPIKeyAuth,
+ api_key: str,
+ request_data: dict,
+ ):
+ """
+ Checks if Non Proxy Admin User is allowed to access the route
+ """
+
+ # Check user has defined custom admin routes
+ RouteChecks.custom_admin_only_route_check(
+ route=route,
+ )
+
+ if RouteChecks.is_llm_api_route(route=route):
+ pass
+ elif (
+ route in LiteLLMRoutes.info_routes.value
+ ): # check if user allowed to call an info route
+ if route == "/key/info":
+ # handled by function itself
+ pass
+ elif route == "/user/info":
+ # check if user can access this route
+ query_params = request.query_params
+ user_id = query_params.get("user_id")
+ verbose_proxy_logger.debug(
+ f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}"
+ )
+ if user_id and user_id != valid_token.user_id:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="key not allowed to access this user's info. user_id={}, key's user_id={}".format(
+ user_id, valid_token.user_id
+ ),
+ )
+ elif route == "/model/info":
+ # /model/info just shows models user has access to
+ pass
+ elif route == "/team/info":
+ pass # handled by function itself
+ elif (
+ route in LiteLLMRoutes.global_spend_tracking_routes.value
+ and getattr(valid_token, "permissions", None) is not None
+ and "get_spend_routes" in getattr(valid_token, "permissions", [])
+ ):
+            pass
+ elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value:
+ if RouteChecks.is_llm_api_route(route=route):
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+                    detail=f"user not allowed to access this OpenAI route, role={_user_role}",
+ )
+ if RouteChecks.check_route_access(
+ route=route, allowed_routes=LiteLLMRoutes.management_routes.value
+ ):
+                # the Admin Viewer is only allowed to call /user/update for their own user_id, and may only update user_email and password
+ if route == "/user/update":
+
+ # Check the Request params are valid for PROXY_ADMIN_VIEW_ONLY
+ if request_data is not None and isinstance(request_data, dict):
+ _params_updated = request_data.keys()
+ for param in _params_updated:
+ if param not in ["user_email", "password"]:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+                                detail=f"user not allowed to access this route, role={_user_role}. Trying to access: {route} and updating invalid param: {param}. Only user_email and password can be updated",
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+                        detail=f"user not allowed to access this route, role={_user_role}. Trying to access: {route}",
+ )
+
+ elif (
+ _user_role == LitellmUserRoles.INTERNAL_USER.value
+ and RouteChecks.check_route_access(
+ route=route, allowed_routes=LiteLLMRoutes.internal_user_routes.value
+ )
+ ):
+ pass
+ elif _user_is_org_admin(
+ request_data=request_data, user_object=user_obj
+ ) and RouteChecks.check_route_access(
+ route=route, allowed_routes=LiteLLMRoutes.org_admin_allowed_routes.value
+ ):
+ pass
+ elif (
+ _user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
+ and RouteChecks.check_route_access(
+ route=route,
+ allowed_routes=LiteLLMRoutes.internal_user_view_only_routes.value,
+ )
+ ):
+ pass
+ elif RouteChecks.check_route_access(
+ route=route, allowed_routes=LiteLLMRoutes.self_managed_routes.value
+ ): # routes that manage their own allowed/disallowed logic
+ pass
+ else:
+ user_role = "unknown"
+ user_id = "unknown"
+ if user_obj is not None:
+ user_role = user_obj.user_role or "unknown"
+ user_id = user_obj.user_id or "unknown"
+ raise Exception(
+ f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}"
+ )
+
+ @staticmethod
+ def custom_admin_only_route_check(route: str):
+ from litellm.proxy.proxy_server import general_settings, premium_user
+
+ if "admin_only_routes" in general_settings:
+ if premium_user is not True:
+ verbose_proxy_logger.error(
+                    f"Trying to use 'admin_only_routes', but this is an Enterprise-only feature. {CommonProxyErrors.not_premium_user.value}"
+ )
+ return
+ if route in general_settings["admin_only_routes"]:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=f"user not allowed to access this route. Route={route} is an admin only route",
+ )
+
+ @staticmethod
+ def is_llm_api_route(route: str) -> bool:
+ """
+        Helper to check if the provided route is an OpenAI route
+
+ Returns:
+ - True: if route is an OpenAI route
+ - False: if route is not an OpenAI route
+ """
+
+ if route in LiteLLMRoutes.openai_routes.value:
+ return True
+
+ if route in LiteLLMRoutes.anthropic_routes.value:
+ return True
+
+ # fuzzy match routes like "/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
+ # Check for routes with placeholders
+ for openai_route in LiteLLMRoutes.openai_routes.value:
+ # Replace placeholders with regex pattern
+ # placeholders are written as "/threads/{thread_id}"
+ if "{" in openai_route:
+ if RouteChecks._route_matches_pattern(
+ route=route, pattern=openai_route
+ ):
+ return True
+
+ if RouteChecks._is_azure_openai_route(route=route):
+ return True
+
+ for _llm_passthrough_route in LiteLLMRoutes.mapped_pass_through_routes.value:
+ if _llm_passthrough_route in route:
+ return True
+
+ return False
+
+ @staticmethod
+ def _is_azure_openai_route(route: str) -> bool:
+ """
+        Check if the route is one used by the AzureOpenAI SDK client
+
+ eg.
+ route='/openai/deployments/vertex_ai/gemini-1.5-flash/chat/completions'
+ """
+ # Add support for deployment and engine model paths
+ deployment_pattern = r"^/openai/deployments/[^/]+/[^/]+/chat/completions$"
+ engine_pattern = r"^/engines/[^/]+/chat/completions$"
+
+ if re.match(deployment_pattern, route) or re.match(engine_pattern, route):
+ return True
+ return False
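+
+    # Illustrative matches for the patterns above (hypothetical routes):
+    #   "/openai/deployments/vertex_ai/gemini-1.5-flash/chat/completions" -> True
+    #   "/engines/gpt-35-turbo/chat/completions"                          -> True
+    #   "/openai/deployments/gpt-4o/chat/completions"                     -> False
+    #   (deployment_pattern expects two path segments after /deployments/)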
+
+ @staticmethod
+ def _route_matches_pattern(route: str, pattern: str) -> bool:
+ """
+ Check if route matches the pattern placed in proxy/_types.py
+
+ Example:
+ - pattern: "/threads/{thread_id}"
+ - route: "/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
+ - returns: True
+
+        - pattern: "/key/{token_id}/regenerate"
+        - route: "/key/regenerate/82akk800000000jjsk"
+        - returns: False, because the route's segments do not line up with the pattern
+ """
+ pattern = re.sub(r"\{[^}]+\}", r"[^/]+", pattern)
+ # Anchor the pattern to match the entire string
+ pattern = f"^{pattern}$"
+ if re.match(pattern, route):
+ return True
+ return False
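+
+    # A sketch of the conversion above (hypothetical values): the pattern
+    # "/threads/{thread_id}" becomes the regex "^/threads/[^/]+$", so
+    # "/threads/thread_49EIN5QF32s4mH20M7GFKdlZ" matches, while "/threads/a/b"
+    # does not, since [^/]+ stops at path separators.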
+
+ @staticmethod
+ def check_route_access(route: str, allowed_routes: List[str]) -> bool:
+ """
+ Check if a route has access by checking both exact matches and patterns
+
+ Args:
+ route (str): The route to check
+ allowed_routes (list): List of allowed routes/patterns
+
+ Returns:
+ bool: True if route is allowed, False otherwise
+ """
+ return route in allowed_routes or any( # Check exact match
+ RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
+ for allowed_route in allowed_routes
+ ) # Check pattern match
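+
+    # Illustrative usage (hypothetical allowed_routes):
+    #   RouteChecks.check_route_access(
+    #       route="/key/sk-123/regenerate",
+    #       allowed_routes=["/key/{token_id}/regenerate", "/key/info"],
+    #   )  # -> True, via the pattern branch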
+
+ @staticmethod
+ def _is_assistants_api_request(request: Request) -> bool:
+ """
+ Returns True if `thread` or `assistant` is in the request path
+
+ Args:
+ request (Request): The request object
+
+ Returns:
+ bool: True if `thread` or `assistant` is in the request path, False otherwise
+ """
+ if "thread" in request.url.path or "assistant" in request.url.path:
+ return True
+ return False
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/service_account_checks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/service_account_checks.py
new file mode 100644
index 00000000..87d7d668
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/service_account_checks.py
@@ -0,0 +1,53 @@
+"""
+Checks for LiteLLM service account keys
+"""
+
+from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth
+
+
+def check_if_token_is_service_account(valid_token: UserAPIKeyAuth) -> bool:
+ """
+ Checks if the token is a service account
+
+ Returns:
+ bool: True if token is a service account
+
+ """
+ if valid_token.metadata:
+ if "service_account_id" in valid_token.metadata:
+ return True
+ return False
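+
+
+# Illustrative behavior (a sketch; the metadata values are hypothetical):
+#
+#     token = UserAPIKeyAuth(metadata={"service_account_id": "svc-123"})
+#     check_if_token_is_service_account(token)  # -> True
+#     check_if_token_is_service_account(UserAPIKeyAuth())  # -> False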
+
+
+async def service_account_checks(
+ valid_token: UserAPIKeyAuth, request_data: dict
+) -> bool:
+ """
+    If a virtual key is a service account, checks that it's a valid service account
+
+    A token is a service account if it has a service_account_id in its metadata
+
+    Service Account Specific Checks:
+    - Check that all `enforced_params` are present in the request body
+ """
+
+ if check_if_token_is_service_account(valid_token) is not True:
+ return True
+
+ from litellm.proxy.proxy_server import general_settings
+
+ if "service_account_settings" in general_settings:
+ service_account_settings = general_settings["service_account_settings"]
+ if "enforced_params" in service_account_settings:
+ _enforced_params = service_account_settings["enforced_params"]
+ for param in _enforced_params:
+ if param not in request_data:
+ raise ProxyException(
+ type=ProxyErrorTypes.bad_request_error.value,
+ code=400,
+ param=param,
+                        message=f"BadRequest: please pass param={param} in the request body. This is a required param for service account keys",
+ )
+
+ return True
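+
+
+# Example proxy config that activates the enforced-params check above (a
+# sketch; the keys are the ones read by this function, the value is hypothetical):
+#
+#     general_settings:
+#       service_account_settings:
+#         enforced_params: ["user"]
+#
+# With this config, a service-account key sending a request body without a
+# "user" field receives a 400 ProxyException.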
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/user_api_key_auth.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/user_api_key_auth.py
new file mode 100644
index 00000000..ace0bf49
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/user_api_key_auth.py
@@ -0,0 +1,1337 @@
+"""
+This file handles authentication for the LiteLLM Proxy.
+
+It checks if the user passed a valid API Key to the LiteLLM Proxy.
+
+Returns a UserAPIKeyAuth object if the API key is valid
+
+"""
+
+import asyncio
+import re
+import secrets
+from datetime import datetime, timezone
+from typing import Optional, cast
+
+import fastapi
+from fastapi import HTTPException, Request, WebSocket, status
+from fastapi.security.api_key import APIKeyHeader
+
+import litellm
+from litellm._logging import verbose_logger, verbose_proxy_logger
+from litellm._service_logger import ServiceLogging
+from litellm.caching import DualCache
+from litellm.litellm_core_utils.dd_tracing import tracer
+from litellm.proxy._types import *
+from litellm.proxy.auth.auth_checks import (
+ _cache_key_object,
+ _handle_failed_db_connection_for_get_key_object,
+ _virtual_key_max_budget_check,
+ _virtual_key_soft_budget_check,
+ can_key_call_model,
+ common_checks,
+ get_end_user_object,
+ get_key_object,
+ get_team_object,
+ get_user_object,
+ is_valid_fallback_model,
+)
+from litellm.proxy.auth.auth_utils import (
+ _get_request_ip_address,
+ get_end_user_id_from_request_body,
+ get_request_route,
+ is_pass_through_provider_route,
+ pre_db_read_auth_checks,
+ route_in_additonal_public_routes,
+ should_run_auth_on_pass_through_provider_route,
+)
+from litellm.proxy.auth.handle_jwt import JWTAuthManager, JWTHandler
+from litellm.proxy.auth.oauth2_check import check_oauth2_token
+from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
+from litellm.proxy.auth.route_checks import RouteChecks
+from litellm.proxy.auth.service_account_checks import service_account_checks
+from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
+from litellm.proxy.utils import PrismaClient, ProxyLogging
+from litellm.types.services import ServiceTypes
+
+user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
+
+
+api_key_header = APIKeyHeader(
+ name=SpecialHeaders.openai_authorization.value,
+ auto_error=False,
+ description="Bearer token",
+)
+azure_api_key_header = APIKeyHeader(
+ name=SpecialHeaders.azure_authorization.value,
+ auto_error=False,
+    description="Some older versions of the openai Python package send an api-key header containing just the API key.",
+)
+anthropic_api_key_header = APIKeyHeader(
+ name=SpecialHeaders.anthropic_authorization.value,
+ auto_error=False,
+    description="Set if the Anthropic client is used.",
+)
+google_ai_studio_api_key_header = APIKeyHeader(
+ name=SpecialHeaders.google_ai_studio_authorization.value,
+ auto_error=False,
+    description="Set if the Google AI Studio client is used.",
+)
+azure_apim_header = APIKeyHeader(
+ name=SpecialHeaders.azure_apim_authorization.value,
+ auto_error=False,
+    description="The default name of Azure APIM's subscription key header.",
+)
+
+
+def _get_bearer_token(
+ api_key: str,
+):
+ if api_key.startswith("Bearer "): # ensure Bearer token passed in
+ api_key = api_key.replace("Bearer ", "") # extract the token
+ elif api_key.startswith("Basic "):
+ api_key = api_key.replace("Basic ", "") # handle langfuse input
+ elif api_key.startswith("bearer "):
+ api_key = api_key.replace("bearer ", "")
+ else:
+ api_key = ""
+ return api_key
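+
+
+# Illustrative behavior (hypothetical inputs):
+#   _get_bearer_token("Bearer sk-1234")      -> "sk-1234"
+#   _get_bearer_token("Basic c2stMTIzNA==")  -> "c2stMTIzNA=="  (langfuse-style)
+#   _get_bearer_token("sk-1234")             -> ""  (no recognized prefix)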
+
+
+def _is_ui_route(
+ route: str,
+ user_obj: Optional[LiteLLM_UserTable] = None,
+) -> bool:
+ """
+    - Check if the route is one used by the LiteLLM UI
+ """
+ # this token is only used for managing the ui
+ allowed_routes = LiteLLMRoutes.ui_routes.value
+ # check if the current route startswith any of the allowed routes
+ if (
+ route is not None
+ and isinstance(route, str)
+ and any(route.startswith(allowed_route) for allowed_route in allowed_routes)
+ ):
+        # the current route starts with one of the allowed UI routes
+ return True
+ elif any(
+ RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
+ for allowed_route in allowed_routes
+ ):
+ return True
+ return False
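+
+
+# Illustrative check (assuming, for this sketch only, that
+# LiteLLMRoutes.ui_routes contains "/sso"-prefixed routes):
+#   _is_ui_route("/sso/key/generate")  -> True   (prefix match)
+#   _is_ui_route("/chat/completions")  -> False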
+
+
+def _is_api_route_allowed(
+ route: str,
+ request: Request,
+ request_data: dict,
+ api_key: str,
+ valid_token: Optional[UserAPIKeyAuth],
+ user_obj: Optional[LiteLLM_UserTable] = None,
+) -> bool:
+ """
+    - Check whether an API token is allowed to access the given route
+ """
+ _user_role = _get_user_role(user_obj=user_obj)
+
+ if valid_token is None:
+ raise Exception("Invalid proxy server token passed. valid_token=None.")
+
+ if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
+ RouteChecks.non_proxy_admin_allowed_routes_check(
+ user_obj=user_obj,
+ _user_role=_user_role,
+ route=route,
+ request=request,
+ request_data=request_data,
+ api_key=api_key,
+ valid_token=valid_token,
+ )
+ return True
+
+
+def _is_allowed_route(
+ route: str,
+ token_type: Literal["ui", "api"],
+ request: Request,
+ request_data: dict,
+ api_key: str,
+ valid_token: Optional[UserAPIKeyAuth],
+ user_obj: Optional[LiteLLM_UserTable] = None,
+) -> bool:
+ """
+    - Route between the UI token check and the API token check
+ """
+
+ if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj):
+ return True
+ else:
+ return _is_api_route_allowed(
+ route=route,
+ request=request,
+ request_data=request_data,
+ api_key=api_key,
+ valid_token=valid_token,
+ user_obj=user_obj,
+ )
+
+
+async def user_api_key_auth_websocket(websocket: WebSocket):
+    # Build a minimal HTTP-style Request object from the WebSocket connection
+
+ request = Request(scope={"type": "http"})
+ request._url = websocket.url
+
+ query_params = websocket.query_params
+
+ model = query_params.get("model")
+
+ async def return_body():
+ return_string = f'{{"model": "{model}"}}'
+ # return string as bytes
+ return return_string.encode()
+
+ request.body = return_body # type: ignore
+
+ # Extract the Authorization header
+ authorization = websocket.headers.get("authorization")
+
+ # If no Authorization header, try the api-key header
+ if not authorization:
+ api_key = websocket.headers.get("api-key")
+ if not api_key:
+ await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
+ raise HTTPException(status_code=403, detail="No API key provided")
+ else:
+ # Extract the API key from the Bearer token
+ if not authorization.startswith("Bearer "):
+ await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
+ raise HTTPException(
+ status_code=403, detail="Invalid Authorization header format"
+ )
+
+ api_key = authorization[len("Bearer ") :].strip()
+
+ # Call user_api_key_auth with the extracted API key
+ # Note: You'll need to modify this to work with WebSocket context if needed
+ try:
+ return await user_api_key_auth(request=request, api_key=f"Bearer {api_key}")
+ except Exception as e:
+ verbose_proxy_logger.exception(e)
+ await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
+ raise HTTPException(status_code=403, detail=str(e))
+
+
+def update_valid_token_with_end_user_params(
+ valid_token: UserAPIKeyAuth, end_user_params: dict
+) -> UserAPIKeyAuth:
+ valid_token.end_user_id = end_user_params.get("end_user_id")
+ valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
+ valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
+ valid_token.allowed_model_region = end_user_params.get("allowed_model_region")
+ return valid_token
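+
+
+# Illustrative call (hypothetical values): copies per-request end-user limits
+# onto the token object so downstream budget / rate-limit checks can see them.
+#
+#     valid_token = update_valid_token_with_end_user_params(
+#         valid_token=valid_token,
+#         end_user_params={"end_user_id": "customer-1", "end_user_tpm_limit": 1000},
+#     )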
+
+
+async def get_global_proxy_spend(
+ litellm_proxy_admin_name: str,
+ user_api_key_cache: DualCache,
+ prisma_client: Optional[PrismaClient],
+ token: str,
+ proxy_logging_obj: ProxyLogging,
+) -> Optional[float]:
+ global_proxy_spend = None
+ if litellm.max_budget > 0: # user set proxy max budget
+ # check cache
+ global_proxy_spend = await user_api_key_cache.async_get_cache(
+ key="{}:spend".format(litellm_proxy_admin_name)
+ )
+ if global_proxy_spend is None and prisma_client is not None:
+ # get from db
+ sql_query = (
+ """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
+ )
+
+ response = await prisma_client.db.query_raw(query=sql_query)
+
+ global_proxy_spend = response[0]["total_spend"]
+
+ await user_api_key_cache.async_set_cache(
+ key="{}:spend".format(litellm_proxy_admin_name),
+ value=global_proxy_spend,
+ )
+ if global_proxy_spend is not None:
+ user_info = CallInfo(
+ user_id=litellm_proxy_admin_name,
+ max_budget=litellm.max_budget,
+ spend=global_proxy_spend,
+ token=token,
+ )
+ asyncio.create_task(
+ proxy_logging_obj.budget_alerts(
+ type="proxy_budget",
+ user_info=user_info,
+ )
+ )
+ return global_proxy_spend
+
+
+def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
+ is_admin = jwt_handler.is_admin(scopes=scopes)
+ if is_admin:
+ return LitellmUserRoles.PROXY_ADMIN
+ else:
+ return LitellmUserRoles.TEAM
+
+
+def get_model_from_request(request_data: dict, route: str) -> Optional[str]:
+
+ # First try to get model from request_data
+ model = request_data.get("model")
+
+ # If model not in request_data, try to extract from route
+ if model is None:
+ # Parse model from route that follows the pattern /openai/deployments/{model}/*
+ match = re.match(r"/openai/deployments/([^/]+)", route)
+ if match:
+ model = match.group(1)
+
+ return model
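+
+
+# Illustrative behavior (hypothetical inputs):
+#   get_model_from_request({"model": "gpt-4o"}, "/chat/completions")           -> "gpt-4o"
+#   get_model_from_request({}, "/openai/deployments/gpt-4o/chat/completions")  -> "gpt-4o"
+#   get_model_from_request({}, "/chat/completions")                            -> None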
+
+
+async def _user_api_key_auth_builder( # noqa: PLR0915
+ request: Request,
+ api_key: str,
+ azure_api_key_header: str,
+ anthropic_api_key_header: Optional[str],
+ google_ai_studio_api_key_header: Optional[str],
+ azure_apim_header: Optional[str],
+ request_data: dict,
+) -> UserAPIKeyAuth:
+
+ from litellm.proxy.proxy_server import (
+ general_settings,
+ jwt_handler,
+ litellm_proxy_admin_name,
+ llm_model_list,
+ llm_router,
+ master_key,
+ model_max_budget_limiter,
+ open_telemetry_logger,
+ prisma_client,
+ proxy_logging_obj,
+ user_api_key_cache,
+ user_custom_auth,
+ )
+
+ parent_otel_span: Optional[Span] = None
+ start_time = datetime.now()
+ route: str = get_request_route(request=request)
+ try:
+
+ # get the request body
+
+ await pre_db_read_auth_checks(
+ request_data=request_data,
+ request=request,
+ route=route,
+ )
+ pass_through_endpoints: Optional[List[dict]] = general_settings.get(
+ "pass_through_endpoints", None
+ )
+ passed_in_key: Optional[str] = None
+ if isinstance(api_key, str):
+ passed_in_key = api_key
+ api_key = _get_bearer_token(api_key=api_key)
+ elif isinstance(azure_api_key_header, str):
+ api_key = azure_api_key_header
+ elif isinstance(anthropic_api_key_header, str):
+ api_key = anthropic_api_key_header
+ elif isinstance(google_ai_studio_api_key_header, str):
+ api_key = google_ai_studio_api_key_header
+ elif isinstance(azure_apim_header, str):
+ api_key = azure_apim_header
+ elif pass_through_endpoints is not None:
+ for endpoint in pass_through_endpoints:
+ if endpoint.get("path", "") == route:
+ headers: Optional[dict] = endpoint.get("headers", None)
+ if headers is not None:
+ header_key: str = headers.get("litellm_user_api_key", "")
+ if request.headers.get(key=header_key) is not None:
+ api_key = request.headers.get(key=header_key)
+
+ # if user wants to pass LiteLLM_Master_Key as a custom header, example pass litellm keys as X-LiteLLM-Key: Bearer sk-1234
+ custom_litellm_key_header_name = general_settings.get("litellm_key_header_name")
+ if custom_litellm_key_header_name is not None:
+ api_key = get_api_key_from_custom_header(
+ request=request,
+ custom_litellm_key_header_name=custom_litellm_key_header_name,
+ )
+
+ if open_telemetry_logger is not None:
+ parent_otel_span = (
+ open_telemetry_logger.create_litellm_proxy_request_started_span(
+ start_time=start_time,
+ headers=dict(request.headers),
+ )
+ )
+
+ ### USER-DEFINED AUTH FUNCTION ###
+ if user_custom_auth is not None:
+ response = await user_custom_auth(request=request, api_key=api_key) # type: ignore
+ return UserAPIKeyAuth.model_validate(response)
+
+ ### LITELLM-DEFINED AUTH FUNCTION ###
+ #### IF JWT ####
+ """
+ LiteLLM supports using JWTs.
+
+ Enable this in proxy config, by setting
+ ```
+ general_settings:
+ enable_jwt_auth: true
+ ```
+ """
+
+ ######## Route Checks Before Reading DB / Cache for "token" ################
+ if (
+ route in LiteLLMRoutes.public_routes.value # type: ignore
+ or route_in_additonal_public_routes(current_route=route)
+ ):
+ # check if public endpoint
+ return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY)
+ elif is_pass_through_provider_route(route=route):
+ if should_run_auth_on_pass_through_provider_route(route=route) is False:
+ return UserAPIKeyAuth(
+ user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
+ )
+
+ ########## End of Route Checks Before Reading DB / Cache for "token" ########
+
+ if general_settings.get("enable_oauth2_auth", False) is True:
+ # return UserAPIKeyAuth object
+ # helper to check if the api_key is a valid oauth2 token
+ from litellm.proxy.proxy_server import premium_user
+
+ if premium_user is not True:
+ raise ValueError(
+ "Oauth2 token validation is only available for premium users"
+ + CommonProxyErrors.not_premium_user.value
+ )
+
+ return await check_oauth2_token(token=api_key)
+
+ if general_settings.get("enable_oauth2_proxy_auth", False) is True:
+ return await handle_oauth2_proxy_request(request=request)
+
+ if general_settings.get("enable_jwt_auth", False) is True:
+ from litellm.proxy.proxy_server import premium_user
+
+ if premium_user is not True:
+ raise ValueError(
+ f"JWT Auth is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
+ )
+ is_jwt = jwt_handler.is_jwt(token=api_key)
+ verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
+ if is_jwt:
+ result = await JWTAuthManager.auth_builder(
+ request_data=request_data,
+ general_settings=general_settings,
+ api_key=api_key,
+ jwt_handler=jwt_handler,
+ route=route,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ parent_otel_span=parent_otel_span,
+ )
+
+ is_proxy_admin = result["is_proxy_admin"]
+ team_id = result["team_id"]
+ team_object = result["team_object"]
+ user_id = result["user_id"]
+ user_object = result["user_object"]
+ end_user_id = result["end_user_id"]
+ end_user_object = result["end_user_object"]
+ org_id = result["org_id"]
+ token = result["token"]
+
+ global_proxy_spend = await get_global_proxy_spend(
+ litellm_proxy_admin_name=litellm_proxy_admin_name,
+ user_api_key_cache=user_api_key_cache,
+ prisma_client=prisma_client,
+ token=token,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+ if is_proxy_admin:
+ return UserAPIKeyAuth(
+ user_role=LitellmUserRoles.PROXY_ADMIN,
+ parent_otel_span=parent_otel_span,
+ )
+ # run through common checks
+ _ = await common_checks(
+ request_body=request_data,
+ team_object=team_object,
+ user_object=user_object,
+ end_user_object=end_user_object,
+ general_settings=general_settings,
+ global_proxy_spend=global_proxy_spend,
+ route=route,
+ llm_router=llm_router,
+ proxy_logging_obj=proxy_logging_obj,
+ valid_token=None,
+ )
+
+ # return UserAPIKeyAuth object
+ return UserAPIKeyAuth(
+ api_key=None,
+ team_id=team_id,
+ team_tpm_limit=(
+ team_object.tpm_limit if team_object is not None else None
+ ),
+ team_rpm_limit=(
+ team_object.rpm_limit if team_object is not None else None
+ ),
+ team_models=team_object.models if team_object is not None else [],
+ user_role=LitellmUserRoles.INTERNAL_USER,
+ user_id=user_id,
+ org_id=org_id,
+ parent_otel_span=parent_otel_span,
+ end_user_id=end_user_id,
+ )
+
+ #### ELSE ####
+ ## CHECK PASS-THROUGH ENDPOINTS ##
+ is_mapped_pass_through_route: bool = False
+ for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: # type: ignore
+ if route.startswith(mapped_route):
+ is_mapped_pass_through_route = True
+ if is_mapped_pass_through_route:
+ if request.headers.get("litellm_user_api_key") is not None:
+ api_key = request.headers.get("litellm_user_api_key") or ""
+ if pass_through_endpoints is not None:
+ for endpoint in pass_through_endpoints:
+ if isinstance(endpoint, dict) and endpoint.get("path", "") == route:
+ ## IF AUTH DISABLED
+ if endpoint.get("auth") is not True:
+ return UserAPIKeyAuth()
+ ## IF AUTH ENABLED
+ ### IF CUSTOM PARSER REQUIRED
+ if (
+ endpoint.get("custom_auth_parser") is not None
+ and endpoint.get("custom_auth_parser") == "langfuse"
+ ):
+ """
+ - langfuse returns {'Authorization': 'Basic YW55dGhpbmc6YW55dGhpbmc'}
+ - check the langfuse public key if it contains the litellm api key
+ """
+ import base64
+
+ api_key = api_key.replace("Basic ", "").strip()
+ decoded_bytes = base64.b64decode(api_key)
+ decoded_str = decoded_bytes.decode("utf-8")
+ api_key = decoded_str.split(":")[0]
+ else:
+ headers = endpoint.get("headers", None)
+ if headers is not None:
+ header_key = headers.get("litellm_user_api_key", "")
+ if (
+ isinstance(request.headers, dict)
+ and request.headers.get(key=header_key) is not None # type: ignore
+ ):
+ api_key = request.headers.get(key=header_key) # type: ignore
+ if master_key is None:
+ if isinstance(api_key, str):
+ return UserAPIKeyAuth(
+ api_key=api_key,
+ user_role=LitellmUserRoles.PROXY_ADMIN,
+ parent_otel_span=parent_otel_span,
+ )
+ else:
+ return UserAPIKeyAuth(
+ user_role=LitellmUserRoles.PROXY_ADMIN,
+ parent_otel_span=parent_otel_span,
+ )
+ elif api_key is None: # only require api key if master key is set
+ raise Exception("No api key passed in.")
+ elif api_key == "":
+ # missing 'Bearer ' prefix
+ raise Exception(
+ f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
+ )
+
+ if route == "/user/auth":
+ if general_settings.get("allow_user_auth", False) is True:
+ return UserAPIKeyAuth()
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="'allow_user_auth' not set or set to False",
+ )
+
+ ## Check END-USER OBJECT
+ _end_user_object = None
+ end_user_params = {}
+
+ end_user_id = get_end_user_id_from_request_body(request_data)
+ if end_user_id:
+ try:
+ end_user_params["end_user_id"] = end_user_id
+
+ # get end-user object
+ _end_user_object = await get_end_user_object(
+ end_user_id=end_user_id,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ parent_otel_span=parent_otel_span,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+ if _end_user_object is not None:
+ end_user_params["allowed_model_region"] = (
+ _end_user_object.allowed_model_region
+ )
+ if _end_user_object.litellm_budget_table is not None:
+ budget_info = _end_user_object.litellm_budget_table
+ if budget_info.tpm_limit is not None:
+ end_user_params["end_user_tpm_limit"] = (
+ budget_info.tpm_limit
+ )
+ if budget_info.rpm_limit is not None:
+ end_user_params["end_user_rpm_limit"] = (
+ budget_info.rpm_limit
+ )
+ if budget_info.max_budget is not None:
+ end_user_params["end_user_max_budget"] = (
+ budget_info.max_budget
+ )
+ except Exception as e:
+ if isinstance(e, litellm.BudgetExceededError):
+ raise e
+ verbose_proxy_logger.debug(
+ "Unable to find user in db. Error - {}".format(str(e))
+ )
+ pass
+
+        ### CHECK IF ADMIN ###
+        # note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
+ ## Check CACHE
+ try:
+ valid_token = await get_key_object(
+ hashed_token=hash_token(api_key),
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ parent_otel_span=parent_otel_span,
+ proxy_logging_obj=proxy_logging_obj,
+ check_cache_only=True,
+ )
+ except Exception:
+ verbose_logger.debug("api key not found in cache.")
+ valid_token = None
+
+ if (
+ valid_token is not None
+ and isinstance(valid_token, UserAPIKeyAuth)
+ and valid_token.user_role == LitellmUserRoles.PROXY_ADMIN
+ ):
+ # update end-user params on valid token
+ valid_token = update_valid_token_with_end_user_params(
+ valid_token=valid_token, end_user_params=end_user_params
+ )
+ valid_token.parent_otel_span = parent_otel_span
+
+ return valid_token
+
+ if (
+ valid_token is not None
+ and isinstance(valid_token, UserAPIKeyAuth)
+ and valid_token.team_id is not None
+ ):
+ ## UPDATE TEAM VALUES BASED ON CACHED TEAM OBJECT - allows `/team/update` values to work for cached token
+ try:
+ team_obj: LiteLLM_TeamTableCachedObj = await get_team_object(
+ team_id=valid_token.team_id,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ parent_otel_span=parent_otel_span,
+ proxy_logging_obj=proxy_logging_obj,
+ check_cache_only=True,
+ )
+
+ if (
+ team_obj.last_refreshed_at is not None
+ and valid_token.last_refreshed_at is not None
+ and team_obj.last_refreshed_at > valid_token.last_refreshed_at
+ ):
+ team_obj_dict = team_obj.__dict__
+
+ for k, v in team_obj_dict.items():
+ field_name = f"team_{k}"
+ if field_name in valid_token.__fields__:
+ setattr(valid_token, field_name, v)
+ except Exception as e:
+ verbose_logger.debug(
+ e
+ ) # moving from .warning to .debug as it spams logs when team missing from cache.
+
+ try:
+ is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore
+ except Exception:
+ is_master_key_valid = False
+
+ ## VALIDATE MASTER KEY ##
+ try:
+ assert isinstance(master_key, str)
+ except Exception:
+ raise HTTPException(
+ status_code=500,
+ detail={
+ "Master key must be a valid string. Current type={}".format(
+ type(master_key)
+ )
+ },
+ )
+
+ if is_master_key_valid:
+ _user_api_key_obj = await _return_user_api_key_auth_obj(
+ user_obj=None,
+ user_role=LitellmUserRoles.PROXY_ADMIN,
+ api_key=master_key,
+ parent_otel_span=parent_otel_span,
+ valid_token_dict={
+ **end_user_params,
+ "user_id": litellm_proxy_admin_name,
+ },
+ route=route,
+ start_time=start_time,
+ )
+ asyncio.create_task(
+ _cache_key_object(
+ hashed_token=hash_token(master_key),
+ user_api_key_obj=_user_api_key_obj,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+ )
+
+ _user_api_key_obj = update_valid_token_with_end_user_params(
+ valid_token=_user_api_key_obj, end_user_params=end_user_params
+ )
+
+ return _user_api_key_obj
+
+ ## IF it's not a master key
+ ## Route should not be in master_key_only_routes
+ if route in LiteLLMRoutes.master_key_only_routes.value: # type: ignore
+ raise Exception(
+ f"Tried to access route={route}, which is only for MASTER KEY"
+ )
+
+ ## Check DB
+ if isinstance(
+ api_key, str
+ ): # if generated token, make sure it starts with sk-.
+ assert api_key.startswith(
+ "sk-"
+ ), "LiteLLM Virtual Key expected. Received={}, expected to start with 'sk-'.".format(
+ api_key
+ ) # prevent token hashes from being used
+ else:
+ verbose_logger.warning(
+ "litellm.proxy.proxy_server.user_api_key_auth(): Warning - Key={} is not a string.".format(
+ api_key
+ )
+ )
+
+ if (
+ prisma_client is None
+ ): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
+ return await _handle_failed_db_connection_for_get_key_object(
+ e=Exception("No connected db.")
+ )
+
+ ## check for cache hit (In-Memory Cache)
+ _user_role = None
+ if api_key.startswith("sk-"):
+ api_key = hash_token(token=api_key)
+
+ if valid_token is None:
+ try:
+ valid_token = await get_key_object(
+ hashed_token=api_key,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ parent_otel_span=parent_otel_span,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+ # update end-user params on valid token
+ # These can change per request - it's important to update them here
+ valid_token.end_user_id = end_user_params.get("end_user_id")
+ valid_token.end_user_tpm_limit = end_user_params.get(
+ "end_user_tpm_limit"
+ )
+ valid_token.end_user_rpm_limit = end_user_params.get(
+ "end_user_rpm_limit"
+ )
+ valid_token.allowed_model_region = end_user_params.get(
+ "allowed_model_region"
+ )
+ # update key budget with temp budget increase
+ valid_token = _update_key_budget_with_temp_budget_increase(
+ valid_token
+ ) # updating it here, allows all downstream reporting / checks to use the updated budget
+ except Exception:
+ verbose_logger.info(
+                    "litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to find token={} in cache or `LiteLLM_VerificationTokenTable`. Defaulting 'valid_token' to None".format(
+ api_key
+ )
+ )
+ valid_token = None
+
+ if valid_token is None:
+ raise Exception(
+ "Invalid proxy server token passed. Received API Key (hashed)={}. Unable to find token in cache or `LiteLLM_VerificationTokenTable`".format(
+ api_key
+ )
+ )
+
+ user_obj: Optional[LiteLLM_UserTable] = None
+ valid_token_dict: dict = {}
+ if valid_token is not None:
+ # Got Valid Token from Cache, DB
+            # Run checks for
+            # 1. If token can call model
+            ## 1a. If token can call fallback models (if client-side fallbacks given)
+            # 2. If user_id for this token is in budget
+            # 3. If the user's spend within their own team is within budget
+            # 4. If token is expired
+            # 5. If token spend is under budget for the token
+            # 6. If token spend is within the soft budget (alerting check)
+            # 7. If token spend per model is under budget per model
+            # 8. Common checks (end-user budget, team budget, route access), run in common_checks()
+            # 9. If key is a service account key, its enforced params are present
+
+ ## base case ## key is disabled
+ if valid_token.blocked is True:
+ raise Exception(
+ "Key is blocked. Update via `/key/unblock` if you're admin."
+ )
+ config = valid_token.config
+
+ if config != {}:
+ model_list = config.get("model_list", [])
+ new_model_list = model_list
+ verbose_proxy_logger.debug(
+ f"\n new llm router model list {new_model_list}"
+ )
+ elif (
+ isinstance(valid_token.models, list)
+ and "all-team-models" in valid_token.models
+ ):
+ # Do not do any validation at this step
+ # the validation will occur when checking the team has access to this model
+ pass
+ else:
+ model = get_model_from_request(request_data, route)
+ fallback_models = cast(
+ Optional[List[ALL_FALLBACK_MODEL_VALUES]],
+ request_data.get("fallbacks", None),
+ )
+
+ if model is not None:
+ await can_key_call_model(
+ model=model,
+ llm_model_list=llm_model_list,
+ valid_token=valid_token,
+ llm_router=llm_router,
+ )
+
+ if fallback_models is not None:
+ for m in fallback_models:
+ await can_key_call_model(
+ model=m["model"] if isinstance(m, dict) else m,
+ llm_model_list=llm_model_list,
+ valid_token=valid_token,
+ llm_router=llm_router,
+ )
+ await is_valid_fallback_model(
+ model=m["model"] if isinstance(m, dict) else m,
+ llm_router=llm_router,
+ user_model=None,
+ )
+
+ # Check 2. If user_id for this token is in budget - done in common_checks()
+ if valid_token.user_id is not None:
+ try:
+ user_obj = await get_user_object(
+ user_id=valid_token.user_id,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ user_id_upsert=False,
+ parent_otel_span=parent_otel_span,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+ except Exception as e:
+ verbose_logger.debug(
+ "litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to get user from db/cache. Setting user_obj to None. Exception received - {}".format(
+ str(e)
+ )
+ )
+ user_obj = None
+
+ # Check 3. Check if user is in their team budget
+ if valid_token.team_member_spend is not None:
+ if prisma_client is not None:
+
+ _cache_key = f"{valid_token.team_id}_{valid_token.user_id}"
+
+ team_member_info = await user_api_key_cache.async_get_cache(
+ key=_cache_key
+ )
+ if team_member_info is None:
+ # read from DB
+ _user_id = valid_token.user_id
+ _team_id = valid_token.team_id
+
+ if _user_id is not None and _team_id is not None:
+ team_member_info = await prisma_client.db.litellm_teammembership.find_first(
+ where={
+ "user_id": _user_id,
+ "team_id": _team_id,
+ }, # type: ignore
+ include={"litellm_budget_table": True},
+ )
+ await user_api_key_cache.async_set_cache(
+ key=_cache_key,
+ value=team_member_info,
+ )
+
+ if (
+ team_member_info is not None
+ and team_member_info.litellm_budget_table is not None
+ ):
+ team_member_budget = (
+ team_member_info.litellm_budget_table.max_budget
+ )
+ if team_member_budget is not None and team_member_budget > 0:
+ if valid_token.team_member_spend > team_member_budget:
+ raise litellm.BudgetExceededError(
+ current_cost=valid_token.team_member_spend,
+ max_budget=team_member_budget,
+ )
+
+            # Check 4. If token is expired
+ if valid_token.expires is not None:
+ current_time = datetime.now(timezone.utc)
+ if isinstance(valid_token.expires, datetime):
+ expiry_time = valid_token.expires
+ else:
+ expiry_time = datetime.fromisoformat(valid_token.expires)
+ if (
+ expiry_time.tzinfo is None
+ or expiry_time.tzinfo.utcoffset(expiry_time) is None
+ ):
+ expiry_time = expiry_time.replace(tzinfo=timezone.utc)
+ verbose_proxy_logger.debug(
+ f"Checking if token expired, expiry time {expiry_time} and current time {current_time}"
+ )
+ if expiry_time < current_time:
+ # Token exists but is expired.
+ raise ProxyException(
+ message=f"Authentication Error - Expired Key. Key Expiry time {expiry_time} and current time {current_time}",
+ type=ProxyErrorTypes.expired_key,
+ code=400,
+ param=api_key,
+ )
+
+            # Check 5. Token spend is under budget
+ await _virtual_key_max_budget_check(
+ valid_token=valid_token,
+ proxy_logging_obj=proxy_logging_obj,
+ user_obj=user_obj,
+ )
+
+            # Check 6. Soft budget check
+ await _virtual_key_soft_budget_check(
+ valid_token=valid_token,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+            # Check 7. Token model spend is under model budget
+ max_budget_per_model = valid_token.model_max_budget
+ current_model = request_data.get("model", None)
+
+ if (
+ max_budget_per_model is not None
+ and isinstance(max_budget_per_model, dict)
+ and len(max_budget_per_model) > 0
+ and prisma_client is not None
+ and current_model is not None
+ and valid_token.token is not None
+ ):
+ ## GET THE SPEND FOR THIS MODEL
+ await model_max_budget_limiter.is_key_within_model_budget(
+ user_api_key_dict=valid_token,
+ model=current_model,
+ )
+
+            # Check 8: Additional common checks across jwt + key auth
+ if valid_token.team_id is not None:
+ _team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable(
+ team_id=valid_token.team_id,
+ max_budget=valid_token.team_max_budget,
+ spend=valid_token.team_spend,
+ tpm_limit=valid_token.team_tpm_limit,
+ rpm_limit=valid_token.team_rpm_limit,
+ blocked=valid_token.team_blocked,
+ models=valid_token.team_models,
+ metadata=valid_token.team_metadata,
+ )
+ else:
+ _team_obj = None
+
+            # Check 9: Check if key is a service account key
+ await service_account_checks(
+ valid_token=valid_token,
+ request_data=request_data,
+ )
+
+ user_api_key_cache.set_cache(
+ key=valid_token.team_id, value=_team_obj
+ ) # save team table in cache - used for tpm/rpm limiting - tpm_rpm_limiter.py
+
+ global_proxy_spend = None
+ if (
+ litellm.max_budget > 0 and prisma_client is not None
+ ): # user set proxy max budget
+ # check cache
+ global_proxy_spend = await user_api_key_cache.async_get_cache(
+ key="{}:spend".format(litellm_proxy_admin_name)
+ )
+ if global_proxy_spend is None:
+ # get from db
+ sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
+
+ response = await prisma_client.db.query_raw(query=sql_query)
+
+ global_proxy_spend = response[0]["total_spend"]
+ await user_api_key_cache.async_set_cache(
+ key="{}:spend".format(litellm_proxy_admin_name),
+ value=global_proxy_spend,
+ )
+
+ if global_proxy_spend is not None:
+ call_info = CallInfo(
+ token=valid_token.token,
+ spend=global_proxy_spend,
+ max_budget=litellm.max_budget,
+ user_id=litellm_proxy_admin_name,
+ team_id=valid_token.team_id,
+ )
+ asyncio.create_task(
+ proxy_logging_obj.budget_alerts(
+ type="proxy_budget",
+ user_info=call_info,
+ )
+ )
+ _ = await common_checks(
+ request_body=request_data,
+ team_object=_team_obj,
+ user_object=user_obj,
+ end_user_object=_end_user_object,
+ general_settings=general_settings,
+ global_proxy_spend=global_proxy_spend,
+ route=route,
+ llm_router=llm_router,
+ proxy_logging_obj=proxy_logging_obj,
+ valid_token=valid_token,
+ )
+ # Token passed all checks
+ if valid_token is None:
+ raise HTTPException(401, detail="Invalid API key")
+ if valid_token.token is None:
+ raise HTTPException(401, detail="Invalid API key, no token associated")
+ api_key = valid_token.token
+
+ # Add hashed token to cache
+ asyncio.create_task(
+ _cache_key_object(
+ hashed_token=api_key,
+ user_api_key_obj=valid_token,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+ )
+
+ valid_token_dict = valid_token.model_dump(exclude_none=True)
+ valid_token_dict.pop("token", None)
+
+ if _end_user_object is not None:
+ valid_token_dict.update(end_user_params)
+
+        # Check if the token is from the LiteLLM UI. The UI creates keys so users can log in with SSO.
+        # These keys can only be used for LiteLLM UI functions
+        # (sso/login, ui/login, /key functions and /user functions)
+        # and are never allowed to call /chat/completions.
+ token_team = getattr(valid_token, "team_id", None)
+ token_type: Literal["ui", "api"] = (
+ "ui"
+ if token_team is not None and token_team == "litellm-dashboard"
+ else "api"
+ )
+ _is_route_allowed = _is_allowed_route(
+ route=route,
+ token_type=token_type,
+ user_obj=user_obj,
+ request=request,
+ request_data=request_data,
+ api_key=api_key,
+ valid_token=valid_token,
+ )
+ if not _is_route_allowed:
+ raise HTTPException(401, detail="Invalid route for UI token")
+
+ if valid_token is None:
+ # No token was found when looking up in the DB
+ raise Exception("Invalid proxy server token passed")
+ if valid_token_dict is not None:
+ return await _return_user_api_key_auth_obj(
+ user_obj=user_obj,
+ api_key=api_key,
+ parent_otel_span=parent_otel_span,
+ valid_token_dict=valid_token_dict,
+ route=route,
+ start_time=start_time,
+ )
+ else:
+ raise Exception()
+ except Exception as e:
+ requester_ip = _get_request_ip_address(
+ request=request,
+ use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
+ )
+ verbose_proxy_logger.exception(
+            "litellm.proxy.proxy_server.user_api_key_auth(): Exception occurred - {}\nRequester IP Address:{}".format(
+ str(e),
+ requester_ip,
+ ),
+ extra={"requester_ip": requester_ip},
+ )
+
+ # Log this exception to OTEL, Datadog etc
+ user_api_key_dict = UserAPIKeyAuth(
+ parent_otel_span=parent_otel_span,
+ api_key=api_key,
+ )
+ asyncio.create_task(
+ proxy_logging_obj.post_call_failure_hook(
+ request_data=request_data,
+ original_exception=e,
+ user_api_key_dict=user_api_key_dict,
+ error_type=ProxyErrorTypes.auth_error,
+ route=route,
+ )
+ )
+
+ if isinstance(e, litellm.BudgetExceededError):
+ raise ProxyException(
+ message=e.message,
+ type=ProxyErrorTypes.budget_exceeded,
+ param=None,
+ code=400,
+ )
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "detail", f"Authentication Error({str(e)})"),
+ type=ProxyErrorTypes.auth_error,
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
+ )
+ elif isinstance(e, ProxyException):
+ raise e
+ raise ProxyException(
+ message="Authentication Error, " + str(e),
+ type=ProxyErrorTypes.auth_error,
+ param=getattr(e, "param", "None"),
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+
+@tracer.wrap()
+async def user_api_key_auth(
+ request: Request,
+ api_key: str = fastapi.Security(api_key_header),
+ azure_api_key_header: str = fastapi.Security(azure_api_key_header),
+ anthropic_api_key_header: Optional[str] = fastapi.Security(
+ anthropic_api_key_header
+ ),
+ google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
+ google_ai_studio_api_key_header
+ ),
+ azure_apim_header: Optional[str] = fastapi.Security(azure_apim_header),
+) -> UserAPIKeyAuth:
+ """
+ Parent function to authenticate user api key / jwt token.
+ """
+
+ request_data = await _read_request_body(request=request)
+
+ user_api_key_auth_obj = await _user_api_key_auth_builder(
+ request=request,
+ api_key=api_key,
+ azure_api_key_header=azure_api_key_header,
+ anthropic_api_key_header=anthropic_api_key_header,
+ google_ai_studio_api_key_header=google_ai_studio_api_key_header,
+ azure_apim_header=azure_apim_header,
+ request_data=request_data,
+ )
+
+ end_user_id = get_end_user_id_from_request_body(request_data)
+ if end_user_id is not None:
+ user_api_key_auth_obj.end_user_id = end_user_id
+
+ return user_api_key_auth_obj
+
+
+async def _return_user_api_key_auth_obj(
+ user_obj: Optional[LiteLLM_UserTable],
+ api_key: str,
+ parent_otel_span: Optional[Span],
+ valid_token_dict: dict,
+ route: str,
+ start_time: datetime,
+ user_role: Optional[LitellmUserRoles] = None,
+) -> UserAPIKeyAuth:
+ end_time = datetime.now()
+
+ asyncio.create_task(
+ user_api_key_service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.AUTH,
+ call_type=route,
+ start_time=start_time,
+ end_time=end_time,
+ duration=end_time.timestamp() - start_time.timestamp(),
+ parent_otel_span=parent_otel_span,
+ )
+ )
+
+ retrieved_user_role = (
+ user_role or _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
+ )
+
+ user_api_key_kwargs = {
+ "api_key": api_key,
+ "parent_otel_span": parent_otel_span,
+ "user_role": retrieved_user_role,
+ **valid_token_dict,
+ }
+ if user_obj is not None:
+ user_api_key_kwargs.update(
+ user_tpm_limit=user_obj.tpm_limit,
+ user_rpm_limit=user_obj.rpm_limit,
+ user_email=user_obj.user_email,
+ )
+ if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
+ user_api_key_kwargs.update(
+ user_role=LitellmUserRoles.PROXY_ADMIN,
+ )
+ return UserAPIKeyAuth(**user_api_key_kwargs)
+ else:
+ return UserAPIKeyAuth(**user_api_key_kwargs)
+
+
+def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
+ if user_obj is None:
+ return False
+
+ if (
+ user_obj.user_role is not None
+ and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
+ ):
+ return True
+
+ return False
+
+
+def _get_user_role(
+ user_obj: Optional[LiteLLM_UserTable],
+) -> Optional[LitellmUserRoles]:
+ if user_obj is None:
+ return None
+
+ _user = user_obj
+
+ _user_role = _user.user_role
+ try:
+ role = LitellmUserRoles(_user_role)
+ except ValueError:
+ return LitellmUserRoles.INTERNAL_USER
+
+ return role
+
+
+def get_api_key_from_custom_header(
+ request: Request, custom_litellm_key_header_name: str
+) -> str:
+ """
+ Get API key from custom header
+
+ Args:
+ request (Request): Request object
+ custom_litellm_key_header_name (str): Custom header name
+
+    Returns:
+        str: the API key, or an empty string if the header is missing
+ """
+ api_key: str = ""
+ # use this as the virtual key passed to litellm proxy
+ custom_litellm_key_header_name = custom_litellm_key_header_name.lower()
+ _headers = {k.lower(): v for k, v in request.headers.items()}
+ verbose_proxy_logger.debug(
+ "searching for custom_litellm_key_header_name= %s, in headers=%s",
+ custom_litellm_key_header_name,
+ _headers,
+ )
+ custom_api_key = _headers.get(custom_litellm_key_header_name)
+ if custom_api_key:
+ api_key = _get_bearer_token(api_key=custom_api_key)
+ verbose_proxy_logger.debug(
+ "Found custom API key using header: {}, setting api_key={}".format(
+ custom_litellm_key_header_name, api_key
+ )
+ )
+ else:
+ verbose_proxy_logger.exception(
+            f"No LiteLLM Virtual Key passed. Please set header={custom_litellm_key_header_name}: Bearer <api_key>"
+ )
+ return api_key
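+
+
+# Example setup (a sketch; the header name is hypothetical):
+#
+#     general_settings:
+#       litellm_key_header_name: "X-LiteLLM-Key"
+#
+# A request with "X-LiteLLM-Key: Bearer sk-1234" then resolves to
+# api_key="sk-1234". Lookup is case-insensitive because both the configured
+# name and the request headers are lower-cased above.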
+
+
+def _get_temp_budget_increase(valid_token: UserAPIKeyAuth):
+ valid_token_metadata = valid_token.metadata
+ if (
+ "temp_budget_increase" in valid_token_metadata
+ and "temp_budget_expiry" in valid_token_metadata
+ ):
+ expiry = datetime.fromisoformat(valid_token_metadata["temp_budget_expiry"])
+ if expiry > datetime.now():
+ return valid_token_metadata["temp_budget_increase"]
+ return None
+
+
+def _update_key_budget_with_temp_budget_increase(
+ valid_token: UserAPIKeyAuth,
+) -> UserAPIKeyAuth:
+ if valid_token.max_budget is None:
+ return valid_token
+ temp_budget_increase = _get_temp_budget_increase(valid_token) or 0.0
+ valid_token.max_budget = valid_token.max_budget + temp_budget_increase
+ return valid_token
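+
+
+# Illustrative metadata shape consumed above (hypothetical values):
+#
+#     valid_token.metadata = {
+#         "temp_budget_increase": 10.0,
+#         "temp_budget_expiry": "2025-01-01T00:00:00",
+#     }
+#
+# If the expiry is in the future, max_budget is raised by 10.0 for this
+# request's downstream budget checks; otherwise the increase is ignored.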