aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py1373
1 files changed, 1373 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py
new file mode 100644
index 00000000..f029511d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_checks.py
@@ -0,0 +1,1373 @@
+# What is this?
+## Common auth checks between jwt + key based auth
+"""
+Got Valid Token from Cache, DB
+Run checks for:
+
+1. If user can call model
+2. If user is in budget
+3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
+"""
+import asyncio
+import re
+import time
+import traceback
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast
+
+from fastapi import status
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.caching.dual_cache import LimitedSizeOrderedDict
+from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
+from litellm.proxy._types import (
+ DB_CONNECTION_ERROR_TYPES,
+ RBAC_ROLES,
+ CallInfo,
+ LiteLLM_EndUserTable,
+ LiteLLM_JWTAuth,
+ LiteLLM_OrganizationMembershipTable,
+ LiteLLM_OrganizationTable,
+ LiteLLM_TeamTable,
+ LiteLLM_TeamTableCachedObj,
+ LiteLLM_UserTable,
+ LiteLLMRoutes,
+ LitellmUserRoles,
+ ProxyErrorTypes,
+ ProxyException,
+ RoleBasedPermissions,
+ SpecialModelNames,
+ UserAPIKeyAuth,
+)
+from litellm.proxy.auth.route_checks import RouteChecks
+from litellm.proxy.route_llm_request import route_request
+from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
+from litellm.router import Router
+from litellm.types.services import ServiceTypes
+
+from .auth_checks_organization import organization_role_based_access_check
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ Span = _Span
+else:
+ Span = Any
+
+
+last_db_access_time = LimitedSizeOrderedDict(max_size=100)
+db_cache_expiry = 5 # refresh every 5s
+
+all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
+
+
+async def common_checks(
+ request_body: dict,
+ team_object: Optional[LiteLLM_TeamTable],
+ user_object: Optional[LiteLLM_UserTable],
+ end_user_object: Optional[LiteLLM_EndUserTable],
+ global_proxy_spend: Optional[float],
+ general_settings: dict,
+ route: str,
+ llm_router: Optional[Router],
+ proxy_logging_obj: ProxyLogging,
+ valid_token: Optional[UserAPIKeyAuth],
+) -> bool:
+ """
+ Common checks across jwt + key-based auth.
+
+ 1. If team is blocked
+ 2. If team can call model
+ 3. If team is in budget
+ 4. If user passed in (JWT or key.user_id) - is in budget
+ 5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
+ 6. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
+ 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
+ 8. [OPTIONAL] If guardrails modified - is request allowed to change this
+ 9. Check if request body is safe
+ 10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
+ """
+ _model = request_body.get("model", None)
+
+ # 1. If team is blocked
+ if team_object is not None and team_object.blocked is True:
+ raise Exception(
+ f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
+ )
+
+ # 2. If team can call model
+ if _model and team_object:
+ if not await can_team_access_model(
+ model=_model,
+ team_object=team_object,
+ llm_router=llm_router,
+ team_model_aliases=valid_token.team_model_aliases if valid_token else None,
+ ):
+ raise ProxyException(
+ message=f"Team not allowed to access model. Team={team_object.team_id}, Model={_model}. Allowed team models = {team_object.models}",
+ type=ProxyErrorTypes.team_model_access_denied,
+ param="model",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+ ## 2.1 If user can call model (if personal key)
+ if team_object is None and user_object is not None:
+ await can_user_call_model(
+ model=_model,
+ llm_router=llm_router,
+ user_object=user_object,
+ )
+
+ # 3. If team is in budget
+ await _team_max_budget_check(
+ team_object=team_object,
+ proxy_logging_obj=proxy_logging_obj,
+ valid_token=valid_token,
+ )
+
+ # 4. If user is in budget
+ ## 4.1 check personal budget, if personal key
+ if (
+ (team_object is None or team_object.team_id is None)
+ and user_object is not None
+ and user_object.max_budget is not None
+ ):
+ user_budget = user_object.max_budget
+ if user_budget < user_object.spend:
+ raise litellm.BudgetExceededError(
+ current_cost=user_object.spend,
+ max_budget=user_budget,
+ message=f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}",
+ )
+
+ ## 4.2 check team member budget, if team key
+ # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
+ if end_user_object is not None and end_user_object.litellm_budget_table is not None:
+ end_user_budget = end_user_object.litellm_budget_table.max_budget
+ if end_user_budget is not None and end_user_object.spend > end_user_budget:
+ raise litellm.BudgetExceededError(
+ current_cost=end_user_object.spend,
+ max_budget=end_user_budget,
+ message=f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}",
+ )
+
+ # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
+ if (
+ general_settings.get("enforce_user_param", None) is not None
+ and general_settings["enforce_user_param"] is True
+ ):
+ if RouteChecks.is_llm_api_route(route=route) and "user" not in request_body:
+ raise Exception(
+ f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
+ )
+ # 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
+ if (
+ litellm.max_budget > 0
+ and global_proxy_spend is not None
+ # only run global budget checks for OpenAI routes
+ # Reason - the Admin UI should continue working if the proxy crosses it's global budget
+ and RouteChecks.is_llm_api_route(route=route)
+ and route != "/v1/models"
+ and route != "/models"
+ ):
+ if global_proxy_spend > litellm.max_budget:
+ raise litellm.BudgetExceededError(
+ current_cost=global_proxy_spend, max_budget=litellm.max_budget
+ )
+
+ _request_metadata: dict = request_body.get("metadata", {}) or {}
+ if _request_metadata.get("guardrails"):
+ # check if team allowed to modify guardrails
+ from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails
+
+ can_modify: bool = can_modify_guardrails(team_object)
+ if can_modify is False:
+ from fastapi import HTTPException
+
+ raise HTTPException(
+ status_code=403,
+ detail={
+ "error": "Your team does not have permission to modify guardrails."
+ },
+ )
+
+ # 10 [OPTIONAL] Organization RBAC checks
+ organization_role_based_access_check(
+ user_object=user_object, route=route, request_body=request_body
+ )
+
+ return True
+
+
+def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
+ """
+ Return if a user is allowed to access route. Helper function for `allowed_routes_check`.
+
+ Parameters:
+ - user_route: str - the route the user is trying to call
+ - allowed_routes: List[str|LiteLLMRoutes] - the list of allowed routes for the user.
+ """
+
+ for allowed_route in allowed_routes:
+ if (
+ allowed_route in LiteLLMRoutes.__members__
+ and user_route in LiteLLMRoutes[allowed_route].value
+ ):
+ return True
+ elif allowed_route == user_route:
+ return True
+ return False
+
+
+def allowed_routes_check(
+ user_role: Literal[
+ LitellmUserRoles.PROXY_ADMIN,
+ LitellmUserRoles.TEAM,
+ LitellmUserRoles.INTERNAL_USER,
+ ],
+ user_route: str,
+ litellm_proxy_roles: LiteLLM_JWTAuth,
+) -> bool:
+ """
+ Check if user -> not admin - allowed to access these routes
+ """
+
+ if user_role == LitellmUserRoles.PROXY_ADMIN:
+ is_allowed = _allowed_routes_check(
+ user_route=user_route,
+ allowed_routes=litellm_proxy_roles.admin_allowed_routes,
+ )
+ return is_allowed
+
+ elif user_role == LitellmUserRoles.TEAM:
+ if litellm_proxy_roles.team_allowed_routes is None:
+ """
+ By default allow a team to call openai + info routes
+ """
+ is_allowed = _allowed_routes_check(
+ user_route=user_route, allowed_routes=["openai_routes", "info_routes"]
+ )
+ return is_allowed
+ elif litellm_proxy_roles.team_allowed_routes is not None:
+ is_allowed = _allowed_routes_check(
+ user_route=user_route,
+ allowed_routes=litellm_proxy_roles.team_allowed_routes,
+ )
+ return is_allowed
+ return False
+
+
+def allowed_route_check_inside_route(
+ user_api_key_dict: UserAPIKeyAuth,
+ requested_user_id: Optional[str],
+) -> bool:
+ ret_val = True
+ if (
+ user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
+ and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY
+ ):
+ ret_val = False
+ if requested_user_id is not None and user_api_key_dict.user_id is not None:
+ if user_api_key_dict.user_id == requested_user_id:
+ ret_val = True
+ return ret_val
+
+
+def get_actual_routes(allowed_routes: list) -> list:
+ actual_routes: list = []
+ for route_name in allowed_routes:
+ try:
+ route_value = LiteLLMRoutes[route_name].value
+ if isinstance(route_value, set):
+ actual_routes.extend(list(route_value))
+ else:
+ actual_routes.extend(route_value)
+
+ except KeyError:
+ actual_routes.append(route_name)
+ return actual_routes
+
+
+@log_db_metrics
+async def get_end_user_object(
+ end_user_id: Optional[str],
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span] = None,
+ proxy_logging_obj: Optional[ProxyLogging] = None,
+) -> Optional[LiteLLM_EndUserTable]:
+ """
+ Returns end user object, if in db.
+
+ Do a isolated check for end user in table vs. doing a combined key + team + user + end-user check, as key might come in frequently for different end-users. Larger call will slowdown query time. This way we get to cache the constant (key/team/user info) and only update based on the changing value (end-user).
+ """
+ if prisma_client is None:
+ raise Exception("No db connected")
+
+ if end_user_id is None:
+ return None
+ _key = "end_user_id:{}".format(end_user_id)
+
+ def check_in_budget(end_user_obj: LiteLLM_EndUserTable):
+ if end_user_obj.litellm_budget_table is None:
+ return
+ end_user_budget = end_user_obj.litellm_budget_table.max_budget
+ if end_user_budget is not None and end_user_obj.spend > end_user_budget:
+ raise litellm.BudgetExceededError(
+ current_cost=end_user_obj.spend, max_budget=end_user_budget
+ )
+
+ # check if in cache
+ cached_user_obj = await user_api_key_cache.async_get_cache(key=_key)
+ if cached_user_obj is not None:
+ if isinstance(cached_user_obj, dict):
+ return_obj = LiteLLM_EndUserTable(**cached_user_obj)
+ check_in_budget(end_user_obj=return_obj)
+ return return_obj
+ elif isinstance(cached_user_obj, LiteLLM_EndUserTable):
+ return_obj = cached_user_obj
+ check_in_budget(end_user_obj=return_obj)
+ return return_obj
+ # else, check db
+ try:
+ response = await prisma_client.db.litellm_endusertable.find_unique(
+ where={"user_id": end_user_id},
+ include={"litellm_budget_table": True},
+ )
+
+ if response is None:
+ raise Exception
+
+ # save the end-user object to cache
+ await user_api_key_cache.async_set_cache(
+ key="end_user_id:{}".format(end_user_id), value=response
+ )
+
+ _response = LiteLLM_EndUserTable(**response.dict())
+
+ check_in_budget(end_user_obj=_response)
+
+ return _response
+ except Exception as e: # if end-user not in db
+ if isinstance(e, litellm.BudgetExceededError):
+ raise e
+ return None
+
+
+def model_in_access_group(
+ model: str, team_models: Optional[List[str]], llm_router: Optional[Router]
+) -> bool:
+ from collections import defaultdict
+
+ if team_models is None:
+ return True
+ if model in team_models:
+ return True
+
+ access_groups: dict[str, list[str]] = defaultdict(list)
+ if llm_router:
+ access_groups = llm_router.get_model_access_groups(model_name=model)
+
+ if len(access_groups) > 0: # check if token contains any model access groups
+ for idx, m in enumerate(
+ team_models
+ ): # loop token models, if any of them are an access group add the access group
+ if m in access_groups:
+ return True
+
+ # Filter out models that are access_groups
+ filtered_models = [m for m in team_models if m not in access_groups]
+
+ if model in filtered_models:
+ return True
+
+ return False
+
+
+def _should_check_db(
+ key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int
+) -> bool:
+ """
+ Prevent calling db repeatedly for items that don't exist in the db.
+ """
+ current_time = time.time()
+ # if key doesn't exist in last_db_access_time -> check db
+ if key not in last_db_access_time:
+ return True
+ elif (
+ last_db_access_time[key][0] is not None
+ ): # check db for non-null values (for refresh operations)
+ return True
+ elif last_db_access_time[key][0] is None:
+ if current_time - last_db_access_time[key] >= db_cache_expiry:
+ return True
+ return False
+
+
+def _update_last_db_access_time(
+ key: str, value: Optional[Any], last_db_access_time: LimitedSizeOrderedDict
+):
+ last_db_access_time[key] = (value, time.time())
+
+
+def _get_role_based_permissions(
+ rbac_role: RBAC_ROLES,
+ general_settings: dict,
+ key: Literal["models", "routes"],
+) -> Optional[List[str]]:
+ """
+ Get the role based permissions from the general settings.
+ """
+ role_based_permissions = cast(
+ Optional[List[RoleBasedPermissions]],
+ general_settings.get("role_permissions", []),
+ )
+ if role_based_permissions is None:
+ return None
+
+ for role_based_permission in role_based_permissions:
+
+ if role_based_permission.role == rbac_role:
+ return getattr(role_based_permission, key)
+
+ return None
+
+
+def get_role_based_models(
+ rbac_role: RBAC_ROLES,
+ general_settings: dict,
+) -> Optional[List[str]]:
+ """
+ Get the models allowed for a user role.
+
+ Used by JWT Auth.
+ """
+
+ return _get_role_based_permissions(
+ rbac_role=rbac_role,
+ general_settings=general_settings,
+ key="models",
+ )
+
+
+def get_role_based_routes(
+ rbac_role: RBAC_ROLES,
+ general_settings: dict,
+) -> Optional[List[str]]:
+ """
+ Get the routes allowed for a user role.
+ """
+
+ return _get_role_based_permissions(
+ rbac_role=rbac_role,
+ general_settings=general_settings,
+ key="routes",
+ )
+
+
+async def _get_fuzzy_user_object(
+ prisma_client: PrismaClient,
+ sso_user_id: Optional[str] = None,
+ user_email: Optional[str] = None,
+) -> Optional[LiteLLM_UserTable]:
+ """
+ Checks if sso user is in db.
+
+ Called when user id match is not found in db.
+
+ - Check if sso_user_id is user_id in db
+ - Check if sso_user_id is sso_user_id in db
+ - Check if user_email is user_email in db
+ - If not, create new user with user_email and sso_user_id and user_id = sso_user_id
+ """
+ response = None
+ if sso_user_id is not None:
+ response = await prisma_client.db.litellm_usertable.find_unique(
+ where={"sso_user_id": sso_user_id},
+ include={"organization_memberships": True},
+ )
+
+ if response is None and user_email is not None:
+ response = await prisma_client.db.litellm_usertable.find_first(
+ where={"user_email": user_email},
+ include={"organization_memberships": True},
+ )
+
+ if response is not None and sso_user_id is not None: # update sso_user_id
+ asyncio.create_task( # background task to update user with sso id
+ prisma_client.db.litellm_usertable.update(
+ where={"user_id": response.user_id},
+ data={"sso_user_id": sso_user_id},
+ )
+ )
+
+ return response
+
+
+@log_db_metrics
+async def get_user_object(
+ user_id: Optional[str],
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ user_id_upsert: bool,
+ parent_otel_span: Optional[Span] = None,
+ proxy_logging_obj: Optional[ProxyLogging] = None,
+ sso_user_id: Optional[str] = None,
+ user_email: Optional[str] = None,
+) -> Optional[LiteLLM_UserTable]:
+ """
+ - Check if user id in proxy User Table
+ - if valid, return LiteLLM_UserTable object with defined limits
+ - if not, then raise an error
+ """
+
+ if user_id is None:
+ return None
+
+ # check if in cache
+ cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
+ if cached_user_obj is not None:
+ if isinstance(cached_user_obj, dict):
+ return LiteLLM_UserTable(**cached_user_obj)
+ elif isinstance(cached_user_obj, LiteLLM_UserTable):
+ return cached_user_obj
+ # else, check db
+ if prisma_client is None:
+ raise Exception("No db connected")
+ try:
+ db_access_time_key = "user_id:{}".format(user_id)
+ should_check_db = _should_check_db(
+ key=db_access_time_key,
+ last_db_access_time=last_db_access_time,
+ db_cache_expiry=db_cache_expiry,
+ )
+
+ if should_check_db:
+ response = await prisma_client.db.litellm_usertable.find_unique(
+ where={"user_id": user_id}, include={"organization_memberships": True}
+ )
+
+ if response is None:
+ response = await _get_fuzzy_user_object(
+ prisma_client=prisma_client,
+ sso_user_id=sso_user_id,
+ user_email=user_email,
+ )
+
+ else:
+ response = None
+
+ if response is None:
+ if user_id_upsert:
+ response = await prisma_client.db.litellm_usertable.create(
+ data={"user_id": user_id},
+ include={"organization_memberships": True},
+ )
+ else:
+ raise Exception
+
+ if (
+ response.organization_memberships is not None
+ and len(response.organization_memberships) > 0
+ ):
+ # dump each organization membership to type LiteLLM_OrganizationMembershipTable
+ _dumped_memberships = [
+ LiteLLM_OrganizationMembershipTable(**membership.model_dump())
+ for membership in response.organization_memberships
+ if membership is not None
+ ]
+ response.organization_memberships = _dumped_memberships
+
+ _response = LiteLLM_UserTable(**dict(response))
+ response_dict = _response.model_dump()
+
+ # save the user object to cache
+ await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)
+
+ # save to db access time
+ _update_last_db_access_time(
+ key=db_access_time_key,
+ value=response_dict,
+ last_db_access_time=last_db_access_time,
+ )
+
+ return _response
+ except Exception as e: # if user not in db
+ raise ValueError(
+ f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call. Got error - {e}"
+ )
+
+
+async def _cache_management_object(
+ key: str,
+ value: BaseModel,
+ user_api_key_cache: DualCache,
+ proxy_logging_obj: Optional[ProxyLogging],
+):
+ await user_api_key_cache.async_set_cache(key=key, value=value)
+
+
+async def _cache_team_object(
+ team_id: str,
+ team_table: LiteLLM_TeamTableCachedObj,
+ user_api_key_cache: DualCache,
+ proxy_logging_obj: Optional[ProxyLogging],
+):
+ key = "team_id:{}".format(team_id)
+
+ ## CACHE REFRESH TIME!
+ team_table.last_refreshed_at = time.time()
+
+ await _cache_management_object(
+ key=key,
+ value=team_table,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+
+async def _cache_key_object(
+ hashed_token: str,
+ user_api_key_obj: UserAPIKeyAuth,
+ user_api_key_cache: DualCache,
+ proxy_logging_obj: Optional[ProxyLogging],
+):
+ key = hashed_token
+
+ ## CACHE REFRESH TIME
+ user_api_key_obj.last_refreshed_at = time.time()
+
+ await _cache_management_object(
+ key=key,
+ value=user_api_key_obj,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+
+async def _delete_cache_key_object(
+ hashed_token: str,
+ user_api_key_cache: DualCache,
+ proxy_logging_obj: Optional[ProxyLogging],
+):
+ key = hashed_token
+
+ user_api_key_cache.delete_cache(key=key)
+
+ ## UPDATE REDIS CACHE ##
+ if proxy_logging_obj is not None:
+ await proxy_logging_obj.internal_usage_cache.dual_cache.async_delete_cache(
+ key=key
+ )
+
+
+@log_db_metrics
+async def _get_team_db_check(
+ team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None
+):
+ response = await prisma_client.db.litellm_teamtable.find_unique(
+ where={"team_id": team_id}
+ )
+
+ if response is None and team_id_upsert:
+ response = await prisma_client.db.litellm_teamtable.create(
+ data={"team_id": team_id}
+ )
+
+ return response
+
+
+async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
+ return await prisma_client.db.litellm_teamtable.find_unique(
+ where={"team_id": team_id}
+ )
+
+
+async def _get_team_object_from_user_api_key_cache(
+ team_id: str,
+ prisma_client: PrismaClient,
+ user_api_key_cache: DualCache,
+ last_db_access_time: LimitedSizeOrderedDict,
+ db_cache_expiry: int,
+ proxy_logging_obj: Optional[ProxyLogging],
+ key: str,
+ team_id_upsert: Optional[bool] = None,
+) -> LiteLLM_TeamTableCachedObj:
+ db_access_time_key = key
+ should_check_db = _should_check_db(
+ key=db_access_time_key,
+ last_db_access_time=last_db_access_time,
+ db_cache_expiry=db_cache_expiry,
+ )
+ if should_check_db:
+ response = await _get_team_db_check(
+ team_id=team_id, prisma_client=prisma_client, team_id_upsert=team_id_upsert
+ )
+ else:
+ response = None
+
+ if response is None:
+ raise Exception
+
+ _response = LiteLLM_TeamTableCachedObj(**response.dict())
+ # save the team object to cache
+ await _cache_team_object(
+ team_id=team_id,
+ team_table=_response,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+ # save to db access time
+ # save to db access time
+ _update_last_db_access_time(
+ key=db_access_time_key,
+ value=_response,
+ last_db_access_time=last_db_access_time,
+ )
+
+ return _response
+
+
+async def _get_team_object_from_cache(
+ key: str,
+ proxy_logging_obj: Optional[ProxyLogging],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span],
+) -> Optional[LiteLLM_TeamTableCachedObj]:
+ cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None
+
+ ## CHECK REDIS CACHE ##
+ if (
+ proxy_logging_obj is not None
+ and proxy_logging_obj.internal_usage_cache.dual_cache
+ ):
+
+ cached_team_obj = (
+ await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
+ key=key, parent_otel_span=parent_otel_span
+ )
+ )
+
+ if cached_team_obj is None:
+ cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
+
+ if cached_team_obj is not None:
+ if isinstance(cached_team_obj, dict):
+ return LiteLLM_TeamTableCachedObj(**cached_team_obj)
+ elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj):
+ return cached_team_obj
+
+ return None
+
+
+async def get_team_object(
+ team_id: str,
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span] = None,
+ proxy_logging_obj: Optional[ProxyLogging] = None,
+ check_cache_only: Optional[bool] = None,
+ check_db_only: Optional[bool] = None,
+ team_id_upsert: Optional[bool] = None,
+) -> LiteLLM_TeamTableCachedObj:
+ """
+ - Check if team id in proxy Team Table
+ - if valid, return LiteLLM_TeamTable object with defined limits
+ - if not, then raise an error
+
+ Raises:
+ - Exception: If team doesn't exist in db or cache
+ """
+ if prisma_client is None:
+ raise Exception(
+ "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
+ )
+
+ # check if in cache
+ key = "team_id:{}".format(team_id)
+
+ if not check_db_only:
+ cached_team_obj = await _get_team_object_from_cache(
+ key=key,
+ proxy_logging_obj=proxy_logging_obj,
+ user_api_key_cache=user_api_key_cache,
+ parent_otel_span=parent_otel_span,
+ )
+
+ if cached_team_obj is not None:
+ return cached_team_obj
+
+ if check_cache_only:
+ raise Exception(
+ f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
+ )
+
+ # else, check db
+ try:
+ return await _get_team_object_from_user_api_key_cache(
+ team_id=team_id,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ last_db_access_time=last_db_access_time,
+ db_cache_expiry=db_cache_expiry,
+ key=key,
+ team_id_upsert=team_id_upsert,
+ )
+ except Exception:
+ raise Exception(
+ f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
+ )
+
+
+@log_db_metrics
+async def get_key_object(
+ hashed_token: str,
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span] = None,
+ proxy_logging_obj: Optional[ProxyLogging] = None,
+ check_cache_only: Optional[bool] = None,
+) -> UserAPIKeyAuth:
+ """
+ - Check if team id in proxy Team Table
+ - if valid, return LiteLLM_TeamTable object with defined limits
+ - if not, then raise an error
+ """
+ if prisma_client is None:
+ raise Exception(
+ "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
+ )
+
+ # check if in cache
+ key = hashed_token
+
+ cached_key_obj: Optional[UserAPIKeyAuth] = await user_api_key_cache.async_get_cache(
+ key=key
+ )
+
+ if cached_key_obj is not None:
+ if isinstance(cached_key_obj, dict):
+ return UserAPIKeyAuth(**cached_key_obj)
+ elif isinstance(cached_key_obj, UserAPIKeyAuth):
+ return cached_key_obj
+
+ if check_cache_only:
+ raise Exception(
+ f"Key doesn't exist in cache + check_cache_only=True. key={key}."
+ )
+
+ # else, check db
+ try:
+ _valid_token: Optional[BaseModel] = await prisma_client.get_data(
+ token=hashed_token,
+ table_name="combined_view",
+ parent_otel_span=parent_otel_span,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+ if _valid_token is None:
+ raise Exception
+
+ _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
+
+ # save the key object to cache
+ await _cache_key_object(
+ hashed_token=hashed_token,
+ user_api_key_obj=_response,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+ return _response
+ except DB_CONNECTION_ERROR_TYPES as e:
+ return await _handle_failed_db_connection_for_get_key_object(e=e)
+ except Exception:
+ traceback.print_exc()
+ raise Exception(
+ f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
+ )
+
+
+async def _handle_failed_db_connection_for_get_key_object(
+ e: Exception,
+) -> UserAPIKeyAuth:
+ """
+ Handles httpx.ConnectError when reading a Virtual Key from LiteLLM DB
+
+ Use this if you don't want failed DB queries to block LLM API reqiests
+
+ Returns:
+ - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True
+
+ Raises:
+ - Orignal Exception in all other cases
+ """
+ from litellm.proxy.proxy_server import (
+ general_settings,
+ litellm_proxy_admin_name,
+ proxy_logging_obj,
+ )
+
+ # If this flag is on, requests failing to connect to the DB will be allowed
+ if general_settings.get("allow_requests_on_db_unavailable", False) is True:
+ # log this as a DB failure on prometheus
+ proxy_logging_obj.service_logging_obj.service_failure_hook(
+ service=ServiceTypes.DB,
+ call_type="get_key_object",
+ error=e,
+ duration=0.0,
+ )
+
+ return UserAPIKeyAuth(
+ key_name="failed-to-connect-to-db",
+ token="failed-to-connect-to-db",
+ user_id=litellm_proxy_admin_name,
+ )
+ else:
+ # raise the original exception, the wrapper on `get_key_object` handles logging db failure to prometheus
+ raise e
+
+
+@log_db_metrics
+async def get_org_object(
+ org_id: str,
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ parent_otel_span: Optional[Span] = None,
+ proxy_logging_obj: Optional[ProxyLogging] = None,
+) -> Optional[LiteLLM_OrganizationTable]:
+ """
+ - Check if org id in proxy Org Table
+ - if valid, return LiteLLM_OrganizationTable object
+ - if not, then raise an error
+ """
+ if prisma_client is None:
+ raise Exception(
+ "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
+ )
+
+ # check if in cache
+ cached_org_obj = user_api_key_cache.async_get_cache(key="org_id:{}".format(org_id))
+ if cached_org_obj is not None:
+ if isinstance(cached_org_obj, dict):
+ return LiteLLM_OrganizationTable(**cached_org_obj)
+ elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
+ return cached_org_obj
+ # else, check db
+ try:
+ response = await prisma_client.db.litellm_organizationtable.find_unique(
+ where={"organization_id": org_id}
+ )
+
+ if response is None:
+ raise Exception
+
+ return response
+ except Exception:
+ raise Exception(
+ f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
+ )
+
+
+async def _can_object_call_model(
+ model: str,
+ llm_router: Optional[Router],
+ models: List[str],
+ team_model_aliases: Optional[Dict[str, str]] = None,
+ object_type: Literal["user", "team", "key"] = "user",
+) -> Literal[True]:
+ """
+ Checks if token can call a given model
+
+ Args:
+ - model: str
+ - llm_router: Optional[Router]
+ - models: List[str]
+ - team_model_aliases: Optional[Dict[str, str]]
+ - object_type: Literal["user", "team", "key"]. We use the object type to raise the correct exception type
+
+ Returns:
+ - True: if token allowed to call model
+
+ Raises:
+ - Exception: If token not allowed to call model
+ """
+ if model in litellm.model_alias_map:
+ model = litellm.model_alias_map[model]
+
+ ## check if model in allowed model names
+ from collections import defaultdict
+
+ access_groups: Dict[str, List[str]] = defaultdict(list)
+
+ if llm_router:
+ access_groups = llm_router.get_model_access_groups(model_name=model)
+ if (
+ len(access_groups) > 0 and llm_router is not None
+ ): # check if token contains any model access groups
+ for idx, m in enumerate(
+ models
+ ): # loop token models, if any of them are an access group add the access group
+ if m in access_groups:
+ return True
+
+ # Filter out models that are access_groups
+ filtered_models = [m for m in models if m not in access_groups]
+
+ verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
+
+ if _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases):
+ return True
+
+ if _model_matches_any_wildcard_pattern_in_list(
+ model=model, allowed_model_list=filtered_models
+ ):
+ return True
+
+ all_model_access: bool = False
+
+ if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models:
+ all_model_access = True
+
+ if SpecialModelNames.all_proxy_models.value in filtered_models:
+ all_model_access = True
+
+ if model is not None and model not in filtered_models and all_model_access is False:
+ raise ProxyException(
+ message=f"{object_type} not allowed to access model. This {object_type} can only access models={models}. Tried to access {model}",
+ type=ProxyErrorTypes.get_model_access_error_type_for_object(
+ object_type=object_type
+ ),
+ param="model",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+ verbose_proxy_logger.debug(
+ f"filtered allowed_models: {filtered_models}; models: {models}"
+ )
+ return True
+
+
+def _model_in_team_aliases(
+ model: str, team_model_aliases: Optional[Dict[str, str]] = None
+) -> bool:
+ """
+ Returns True if `model` being accessed is an alias of a team model
+
+ - `model=gpt-4o`
+ - `team_model_aliases={"gpt-4o": "gpt-4o-team-1"}`
+ - returns True
+
+ - `model=gp-4o`
+ - `team_model_aliases={"o-3": "o3-preview"}`
+ - returns False
+ """
+ if team_model_aliases:
+ if model in team_model_aliases:
+ return True
+ return False
+
+
+async def can_key_call_model(
+ model: str,
+ llm_model_list: Optional[list],
+ valid_token: UserAPIKeyAuth,
+ llm_router: Optional[litellm.Router],
+) -> Literal[True]:
+ """
+ Checks if token can call a given model
+
+ Returns:
+ - True: if token allowed to call model
+
+ Raises:
+ - Exception: If token not allowed to call model
+ """
+ return await _can_object_call_model(
+ model=model,
+ llm_router=llm_router,
+ models=valid_token.models,
+ team_model_aliases=valid_token.team_model_aliases,
+ object_type="key",
+ )
+
+
+async def can_team_access_model(
+ model: str,
+ team_object: Optional[LiteLLM_TeamTable],
+ llm_router: Optional[Router],
+ team_model_aliases: Optional[Dict[str, str]] = None,
+) -> Literal[True]:
+ """
+ Returns True if the team can access a specific model.
+
+ """
+ return await _can_object_call_model(
+ model=model,
+ llm_router=llm_router,
+ models=team_object.models if team_object else [],
+ team_model_aliases=team_model_aliases,
+ object_type="team",
+ )
+
+
+async def can_user_call_model(
+ model: str,
+ llm_router: Optional[Router],
+ user_object: Optional[LiteLLM_UserTable],
+) -> Literal[True]:
+
+ if user_object is None:
+ return True
+
+ if SpecialModelNames.no_default_models.value in user_object.models:
+ raise ProxyException(
+ message=f"User not allowed to access model. No default model access, only team models allowed. Tried to access {model}",
+ type=ProxyErrorTypes.key_model_access_denied,
+ param="model",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+ return await _can_object_call_model(
+ model=model,
+ llm_router=llm_router,
+ models=user_object.models,
+ object_type="user",
+ )
+
+
+async def is_valid_fallback_model(
+ model: str,
+ llm_router: Optional[Router],
+ user_model: Optional[str],
+) -> Literal[True]:
+ """
+ Try to route the fallback model request.
+
+ Validate if it can't be routed.
+
+ Help catch invalid fallback models.
+ """
+ await route_request(
+ data={
+ "model": model,
+ "messages": [{"role": "user", "content": "Who was Alexander?"}],
+ },
+ llm_router=llm_router,
+ user_model=user_model,
+ route_type="acompletion", # route type shouldn't affect the fallback model check
+ )
+
+ return True
+
+
+async def _virtual_key_max_budget_check(
+ valid_token: UserAPIKeyAuth,
+ proxy_logging_obj: ProxyLogging,
+ user_obj: Optional[LiteLLM_UserTable] = None,
+):
+ """
+ Raises:
+ BudgetExceededError if the token is over it's max budget.
+ Triggers a budget alert if the token is over it's max budget.
+
+ """
+ if valid_token.spend is not None and valid_token.max_budget is not None:
+ ####################################
+ # collect information for alerting #
+ ####################################
+
+ user_email = None
+ # Check if the token has any user id information
+ if user_obj is not None:
+ user_email = user_obj.user_email
+
+ call_info = CallInfo(
+ token=valid_token.token,
+ spend=valid_token.spend,
+ max_budget=valid_token.max_budget,
+ user_id=valid_token.user_id,
+ team_id=valid_token.team_id,
+ user_email=user_email,
+ key_alias=valid_token.key_alias,
+ )
+ asyncio.create_task(
+ proxy_logging_obj.budget_alerts(
+ type="token_budget",
+ user_info=call_info,
+ )
+ )
+
+ ####################################
+ # collect information for alerting #
+ ####################################
+
+ if valid_token.spend >= valid_token.max_budget:
+ raise litellm.BudgetExceededError(
+ current_cost=valid_token.spend,
+ max_budget=valid_token.max_budget,
+ )
+
+
+async def _virtual_key_soft_budget_check(
+ valid_token: UserAPIKeyAuth,
+ proxy_logging_obj: ProxyLogging,
+):
+ """
+ Triggers a budget alert if the token is over it's soft budget.
+
+ """
+
+ if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
+ verbose_proxy_logger.debug(
+ "Crossed Soft Budget for token %s, spend %s, soft_budget %s",
+ valid_token.token,
+ valid_token.spend,
+ valid_token.soft_budget,
+ )
+ call_info = CallInfo(
+ token=valid_token.token,
+ spend=valid_token.spend,
+ max_budget=valid_token.max_budget,
+ soft_budget=valid_token.soft_budget,
+ user_id=valid_token.user_id,
+ team_id=valid_token.team_id,
+ team_alias=valid_token.team_alias,
+ user_email=None,
+ key_alias=valid_token.key_alias,
+ )
+ asyncio.create_task(
+ proxy_logging_obj.budget_alerts(
+ type="soft_budget",
+ user_info=call_info,
+ )
+ )
+
+
+async def _team_max_budget_check(
+ team_object: Optional[LiteLLM_TeamTable],
+ valid_token: Optional[UserAPIKeyAuth],
+ proxy_logging_obj: ProxyLogging,
+):
+ """
+ Check if the team is over it's max budget.
+
+ Raises:
+ BudgetExceededError if the team is over it's max budget.
+ Triggers a budget alert if the team is over it's max budget.
+ """
+ if (
+ team_object is not None
+ and team_object.max_budget is not None
+ and team_object.spend is not None
+ and team_object.spend > team_object.max_budget
+ ):
+ if valid_token:
+ call_info = CallInfo(
+ token=valid_token.token,
+ spend=team_object.spend,
+ max_budget=team_object.max_budget,
+ user_id=valid_token.user_id,
+ team_id=valid_token.team_id,
+ team_alias=valid_token.team_alias,
+ )
+ asyncio.create_task(
+ proxy_logging_obj.budget_alerts(
+ type="team_budget",
+ user_info=call_info,
+ )
+ )
+
+ raise litellm.BudgetExceededError(
+ current_cost=team_object.spend,
+ max_budget=team_object.max_budget,
+ message=f"Budget has been exceeded! Team={team_object.team_id} Current cost: {team_object.spend}, Max budget: {team_object.max_budget}",
+ )
+
+
+def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
+ """
+ Check if a model matches an allowed pattern.
+ Handles exact matches and wildcard patterns.
+
+ Args:
+ model (str): The model to check (e.g., "bedrock/anthropic.claude-3-5-sonnet-20240620")
+ allowed_model_pattern (str): The allowed pattern (e.g., "bedrock/*", "*", "openai/*")
+
+ Returns:
+ bool: True if model matches the pattern, False otherwise
+ """
+ if "*" in allowed_model_pattern:
+ pattern = f"^{allowed_model_pattern.replace('*', '.*')}$"
+ return bool(re.match(pattern, model))
+
+ return False
+
+
+def _model_matches_any_wildcard_pattern_in_list(
+ model: str, allowed_model_list: list
+) -> bool:
+ """
+ Returns True if a model matches any wildcard pattern in a list.
+
+ eg.
+ - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns True
+ - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
+ - model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
+ """
+
+ if any(
+ _is_wildcard_pattern(allowed_model_pattern)
+ and is_model_allowed_by_pattern(
+ model=model, allowed_model_pattern=allowed_model_pattern
+ )
+ for allowed_model_pattern in allowed_model_list
+ ):
+ return True
+
+ if any(
+ _is_wildcard_pattern(allowed_model_pattern)
+ and _model_custom_llm_provider_matches_wildcard_pattern(
+ model=model, allowed_model_pattern=allowed_model_pattern
+ )
+ for allowed_model_pattern in allowed_model_list
+ ):
+ return True
+
+ return False
+
+
+def _model_custom_llm_provider_matches_wildcard_pattern(
+ model: str, allowed_model_pattern: str
+) -> bool:
+ """
+ Returns True for this scenario:
+ - `model=gpt-4o`
+ - `allowed_model_pattern=openai/*`
+
+ or
+ - `model=claude-3-5-sonnet-20240620`
+ - `allowed_model_pattern=anthropic/*`
+ """
+ try:
+ model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+ except Exception:
+ return False
+
+ return is_model_allowed_by_pattern(
+ model=f"{custom_llm_provider}/{model}",
+ allowed_model_pattern=allowed_model_pattern,
+ )
+
+
+def _is_wildcard_pattern(allowed_model_pattern: str) -> bool:
+ """
+ Returns True if the pattern is a wildcard pattern.
+
+ Checks if `*` is in the pattern.
+ """
+ return "*" in allowed_model_pattern