diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/auth/handle_jwt.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/auth/handle_jwt.py | 1001 |
1 files changed, 1001 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/handle_jwt.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/handle_jwt.py new file mode 100644 index 00000000..cc410501 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/handle_jwt.py @@ -0,0 +1,1001 @@ +""" +Supports using JWT's for authenticating into the proxy. + +Currently only supports admin. + +JWT token must have 'litellm_proxy_admin' in scope. +""" + +import json +import os +from typing import Any, List, Literal, Optional, Set, Tuple, cast + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from fastapi import HTTPException + +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value +from litellm.llms.custom_httpx.httpx_handler import HTTPHandler +from litellm.proxy._types import ( + RBAC_ROLES, + JWKKeyValue, + JWTAuthBuilderResult, + JWTKeyItem, + LiteLLM_EndUserTable, + LiteLLM_JWTAuth, + LiteLLM_OrganizationTable, + LiteLLM_TeamTable, + LiteLLM_UserTable, + LitellmUserRoles, + ScopeMapping, + Span, +) +from litellm.proxy.auth.auth_checks import can_team_access_model +from litellm.proxy.utils import PrismaClient, ProxyLogging + +from .auth_checks import ( + _allowed_routes_check, + allowed_routes_check, + get_actual_routes, + get_end_user_object, + get_org_object, + get_role_based_models, + get_role_based_routes, + get_team_object, + get_user_object, +) + + +class JWTHandler: + """ + - treat the sub id passed in as the user id + - return an error if id making request doesn't exist in proxy user table + - track spend against the user id + - if role="litellm_proxy_user" -> allow making calls + info. Can not edit budgets + """ + + prisma_client: Optional[PrismaClient] + user_api_key_cache: DualCache + + def __init__( + self, + ) -> None: + self.http_handler = HTTPHandler() + self.leeway = 0 + + def update_environment( + self, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + litellm_jwtauth: LiteLLM_JWTAuth, + leeway: int = 0, + ) -> None: + self.prisma_client = prisma_client + self.user_api_key_cache = user_api_key_cache + self.litellm_jwtauth = litellm_jwtauth + self.leeway = leeway + + def is_jwt(self, token: str): + parts = token.split(".") + return len(parts) == 3 + + def _rbac_role_from_role_mapping(self, token: dict) -> Optional[RBAC_ROLES]: + """ + Returns the RBAC role the token 'belongs' to based on role mappings. + + Args: + token (dict): The JWT token containing role information + + Returns: + Optional[RBAC_ROLES]: The mapped internal RBAC role if a mapping exists, + None otherwise + + Note: + The function handles both single string roles and lists of roles from the JWT. + If multiple mappings match the JWT roles, the first matching mapping is returned. + """ + if self.litellm_jwtauth.role_mappings is None: + return None + + jwt_role = self.get_jwt_role(token=token, default_value=None) + if not jwt_role: + return None + + jwt_role_set = set(jwt_role) + + for role_mapping in self.litellm_jwtauth.role_mappings: + # Check if the mapping role matches any of the JWT roles + if role_mapping.role in jwt_role_set: + return role_mapping.internal_role + + return None + + def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]: + """ + Returns the RBAC role the token 'belongs' to. + + RBAC roles allowed to make requests: + - PROXY_ADMIN: can make requests to all routes + - TEAM: can make requests to routes associated with a team + - INTERNAL_USER: can make requests to routes associated with a user + + Resolves: https://github.com/BerriAI/litellm/issues/6793 + + Returns: + - PROXY_ADMIN: if token is admin + - TEAM: if token is associated with a team + - INTERNAL_USER: if token is associated with a user + - None: if token is not associated with a team or user + """ + scopes = self.get_scopes(token=token) + is_admin = self.is_admin(scopes=scopes) + user_roles = self.get_user_roles(token=token, default_value=None) + + if is_admin: + return LitellmUserRoles.PROXY_ADMIN + elif self.get_team_id(token=token, default_value=None) is not None: + return LitellmUserRoles.TEAM + elif self.get_user_id(token=token, default_value=None) is not None: + return LitellmUserRoles.INTERNAL_USER + elif user_roles is not None and self.is_allowed_user_role( + user_roles=user_roles + ): + return LitellmUserRoles.INTERNAL_USER + elif rbac_role := self._rbac_role_from_role_mapping(token=token): + return rbac_role + + return None + + def is_admin(self, scopes: list) -> bool: + if self.litellm_jwtauth.admin_jwt_scope in scopes: + return True + return False + + def get_team_ids_from_jwt(self, token: dict) -> List[str]: + if ( + self.litellm_jwtauth.team_ids_jwt_field is not None + and token.get(self.litellm_jwtauth.team_ids_jwt_field) is not None + ): + return token[self.litellm_jwtauth.team_ids_jwt_field] + return [] + + def get_end_user_id( + self, token: dict, default_value: Optional[str] + ) -> Optional[str]: + try: + + if self.litellm_jwtauth.end_user_id_jwt_field is not None: + user_id = token[self.litellm_jwtauth.end_user_id_jwt_field] + else: + user_id = None + except KeyError: + user_id = default_value + + return user_id + + def is_required_team_id(self) -> bool: + """ + Returns: + - True: if 'team_id_jwt_field' is set + - False: if not + """ + if self.litellm_jwtauth.team_id_jwt_field is None: + return False + return True + + def is_enforced_email_domain(self) -> bool: + """ + Returns: + - True: if 'user_allowed_email_domain' is set + - False: if 'user_allowed_email_domain' is None + """ + + if self.litellm_jwtauth.user_allowed_email_domain is not None and isinstance( + self.litellm_jwtauth.user_allowed_email_domain, str + ): + return True + return False + + def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: + try: + if self.litellm_jwtauth.team_id_jwt_field is not None: + team_id = token[self.litellm_jwtauth.team_id_jwt_field] + elif self.litellm_jwtauth.team_id_default is not None: + team_id = self.litellm_jwtauth.team_id_default + else: + team_id = None + except KeyError: + team_id = default_value + return team_id + + def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool: + """ + Returns: + - True: if 'user_id_upsert' is set AND valid_user_email is not False + - False: if not + """ + if valid_user_email is False: + return False + return self.litellm_jwtauth.user_id_upsert + + def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: + try: + if self.litellm_jwtauth.user_id_jwt_field is not None: + user_id = token[self.litellm_jwtauth.user_id_jwt_field] + else: + user_id = default_value + except KeyError: + user_id = default_value + return user_id + + def get_user_roles( + self, token: dict, default_value: Optional[List[str]] + ) -> Optional[List[str]]: + """ + Returns the user role from the token. + + Set via 'user_roles_jwt_field' in the config. + """ + try: + if self.litellm_jwtauth.user_roles_jwt_field is not None: + user_roles = get_nested_value( + data=token, + key_path=self.litellm_jwtauth.user_roles_jwt_field, + default=default_value, + ) + else: + user_roles = default_value + except KeyError: + user_roles = default_value + return user_roles + + def get_jwt_role( + self, token: dict, default_value: Optional[List[str]] + ) -> Optional[List[str]]: + """ + Generic implementation of `get_user_roles` that can be used for both user and team roles. + + Returns the jwt role from the token. + + Set via 'roles_jwt_field' in the config. + """ + try: + if self.litellm_jwtauth.roles_jwt_field is not None: + user_roles = get_nested_value( + data=token, + key_path=self.litellm_jwtauth.roles_jwt_field, + default=default_value, + ) + else: + user_roles = default_value + except KeyError: + user_roles = default_value + return user_roles + + def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool: + """ + Returns the user role from the token. + + Set via 'user_allowed_roles' in the config. + """ + if ( + user_roles is not None + and self.litellm_jwtauth.user_allowed_roles is not None + and any( + role in self.litellm_jwtauth.user_allowed_roles for role in user_roles + ) + ): + return True + return False + + def get_user_email( + self, token: dict, default_value: Optional[str] + ) -> Optional[str]: + try: + if self.litellm_jwtauth.user_email_jwt_field is not None: + user_email = token[self.litellm_jwtauth.user_email_jwt_field] + else: + user_email = None + except KeyError: + user_email = default_value + return user_email + + def get_object_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: + try: + if self.litellm_jwtauth.object_id_jwt_field is not None: + object_id = token[self.litellm_jwtauth.object_id_jwt_field] + else: + object_id = default_value + except KeyError: + object_id = default_value + return object_id + + def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: + try: + if self.litellm_jwtauth.org_id_jwt_field is not None: + org_id = token[self.litellm_jwtauth.org_id_jwt_field] + else: + org_id = None + except KeyError: + org_id = default_value + return org_id + + def get_scopes(self, token: dict) -> List[str]: + try: + if isinstance(token["scope"], str): + # Assuming the scopes are stored in 'scope' claim and are space-separated + scopes = token["scope"].split() + elif isinstance(token["scope"], list): + scopes = token["scope"] + else: + raise Exception( + f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str." + ) + except KeyError: + scopes = [] + return scopes + + async def get_public_key(self, kid: Optional[str]) -> dict: + + keys_url = os.getenv("JWT_PUBLIC_KEY_URL") + + if keys_url is None: + raise Exception("Missing JWT Public Key URL from environment.") + + keys_url_list = [url.strip() for url in keys_url.split(",")] + + for key_url in keys_url_list: + + cache_key = f"litellm_jwt_auth_keys_{key_url}" + + cached_keys = await self.user_api_key_cache.async_get_cache(cache_key) + + if cached_keys is None: + response = await self.http_handler.get(key_url) + + response_json = response.json() + if "keys" in response_json: + keys: JWKKeyValue = response.json()["keys"] + else: + keys = response_json + + await self.user_api_key_cache.async_set_cache( + key=cache_key, + value=keys, + ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins + ) + else: + keys = cached_keys + + public_key = self.parse_keys(keys=keys, kid=kid) + if public_key is not None: + return cast(dict, public_key) + + raise Exception( + f"No matching public key found. keys={keys_url_list}, kid={kid}" + ) + + def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]: + public_key: Optional[JWTKeyItem] = None + if len(keys) == 1: + if isinstance(keys, dict) and (keys.get("kid", None) == kid or kid is None): + public_key = keys + elif isinstance(keys, list) and ( + keys[0].get("kid", None) == kid or kid is None + ): + public_key = keys[0] + elif len(keys) > 1: + for key in keys: + if isinstance(key, dict): + key_kid = key.get("kid", None) + else: + key_kid = None + if ( + kid is not None + and isinstance(key, dict) + and key_kid is not None + and key_kid == kid + ): + public_key = key + + return public_key + + def is_allowed_domain(self, user_email: str) -> bool: + if self.litellm_jwtauth.user_allowed_email_domain is None: + return True + + email_domain = user_email.split("@")[-1] # Extract domain from email + if email_domain == self.litellm_jwtauth.user_allowed_email_domain: + return True + else: + return False + + async def auth_jwt(self, token: str) -> dict: + # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html + # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret + # the key in different ways (e.g. HS* and RS*)." + algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"] + + audience = os.getenv("JWT_AUDIENCE") + decode_options = None + if audience is None: + decode_options = {"verify_aud": False} + + import jwt + from jwt.algorithms import RSAAlgorithm + + header = jwt.get_unverified_header(token) + + verbose_proxy_logger.debug("header: %s", header) + + kid = header.get("kid", None) + + public_key = await self.get_public_key(kid=kid) + + if public_key is not None and isinstance(public_key, dict): + jwk = {} + if "kty" in public_key: + jwk["kty"] = public_key["kty"] + if "kid" in public_key: + jwk["kid"] = public_key["kid"] + if "n" in public_key: + jwk["n"] = public_key["n"] + if "e" in public_key: + jwk["e"] = public_key["e"] + + public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk)) + + try: + # decode the token using the public key + payload = jwt.decode( + token, + public_key_rsa, # type: ignore + algorithms=algorithms, + options=decode_options, + audience=audience, + leeway=self.leeway, # allow testing of expired tokens + ) + return payload + + except jwt.ExpiredSignatureError: + # the token is expired, do something to refresh it + raise Exception("Token Expired") + except Exception as e: + raise Exception(f"Validation fails: {str(e)}") + elif public_key is not None and isinstance(public_key, str): + try: + cert = x509.load_pem_x509_certificate( + public_key.encode(), default_backend() + ) + + # Extract public key + key = cert.public_key().public_bytes( + serialization.Encoding.PEM, + serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + # decode the token using the public key + payload = jwt.decode( + token, + key, + algorithms=algorithms, + audience=audience, + options=decode_options, + ) + return payload + + except jwt.ExpiredSignatureError: + # the token is expired, do something to refresh it + raise Exception("Token Expired") + except Exception as e: + raise Exception(f"Validation fails: {str(e)}") + + raise Exception("Invalid JWT Submitted") + + async def close(self): + await self.http_handler.close() + + +class JWTAuthManager: + """Manages JWT authentication and authorization operations""" + + @staticmethod + def can_rbac_role_call_route( + rbac_role: RBAC_ROLES, + general_settings: dict, + route: str, + ) -> Literal[True]: + """ + Checks if user is allowed to access the route, based on their role. + """ + role_based_routes = get_role_based_routes( + rbac_role=rbac_role, general_settings=general_settings + ) + + if role_based_routes is None or route is None: + return True + + is_allowed = _allowed_routes_check( + user_route=route, + allowed_routes=role_based_routes, + ) + + if not is_allowed: + raise HTTPException( + status_code=403, + detail=f"Role={rbac_role} not allowed to call route={route}. Allowed routes={role_based_routes}", + ) + + return True + + @staticmethod + def can_rbac_role_call_model( + rbac_role: RBAC_ROLES, + general_settings: dict, + model: Optional[str], + ) -> Literal[True]: + """ + Checks if user is allowed to access the model, based on their role. + """ + role_based_models = get_role_based_models( + rbac_role=rbac_role, general_settings=general_settings + ) + if role_based_models is None or model is None: + return True + + if model not in role_based_models: + raise HTTPException( + status_code=403, + detail=f"Role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}", + ) + + return True + + @staticmethod + def check_scope_based_access( + scope_mappings: List[ScopeMapping], + scopes: List[str], + request_data: dict, + general_settings: dict, + ) -> None: + """ + Check if scope allows access to the requested model + """ + if not scope_mappings: + return None + + allowed_models = [] + for sm in scope_mappings: + if sm.scope in scopes and sm.models: + allowed_models.extend(sm.models) + + requested_model = request_data.get("model") + + if not requested_model: + return None + + if requested_model not in allowed_models: + raise HTTPException( + status_code=403, + detail={ + "error": "model={} not allowed. Allowed_models={}".format( + requested_model, allowed_models + ) + }, + ) + return None + + @staticmethod + async def check_rbac_role( + jwt_handler: JWTHandler, + jwt_valid_token: dict, + general_settings: dict, + request_data: dict, + route: str, + rbac_role: Optional[RBAC_ROLES], + ) -> None: + """Validate RBAC role and model access permissions""" + if jwt_handler.litellm_jwtauth.enforce_rbac is True: + if rbac_role is None: + raise HTTPException( + status_code=403, + detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.", + ) + JWTAuthManager.can_rbac_role_call_model( + rbac_role=rbac_role, + general_settings=general_settings, + model=request_data.get("model"), + ) + JWTAuthManager.can_rbac_role_call_route( + rbac_role=rbac_role, + general_settings=general_settings, + route=route, + ) + + @staticmethod + async def check_admin_access( + jwt_handler: JWTHandler, + scopes: list, + route: str, + user_id: Optional[str], + org_id: Optional[str], + api_key: str, + ) -> Optional[JWTAuthBuilderResult]: + """Check admin status and route access permissions""" + if not jwt_handler.is_admin(scopes=scopes): + return None + + is_allowed = allowed_routes_check( + user_role=LitellmUserRoles.PROXY_ADMIN, + user_route=route, + litellm_proxy_roles=jwt_handler.litellm_jwtauth, + ) + if not is_allowed: + allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes + actual_routes = get_actual_routes(allowed_routes=allowed_routes) + raise Exception( + f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}" + ) + + return JWTAuthBuilderResult( + is_proxy_admin=True, + team_object=None, + user_object=None, + end_user_object=None, + org_object=None, + token=api_key, + team_id=None, + user_id=user_id, + end_user_id=None, + org_id=org_id, + ) + + @staticmethod + async def find_and_validate_specific_team_id( + jwt_handler: JWTHandler, + jwt_valid_token: dict, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span], + proxy_logging_obj: ProxyLogging, + ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]: + """Find and validate specific team ID""" + individual_team_id = jwt_handler.get_team_id( + token=jwt_valid_token, default_value=None + ) + + if not individual_team_id and jwt_handler.is_required_team_id() is True: + raise Exception( + f"No team id found in token. Checked team_id field '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'" + ) + + ## VALIDATE TEAM OBJECT ### + team_object: Optional[LiteLLM_TeamTable] = None + if individual_team_id: + team_object = await get_team_object( + team_id=individual_team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert, + ) + + return individual_team_id, team_object + + @staticmethod + def get_all_team_ids(jwt_handler: JWTHandler, jwt_valid_token: dict) -> Set[str]: + """Get combined team IDs from groups and individual team_id""" + team_ids_from_groups = jwt_handler.get_team_ids_from_jwt(token=jwt_valid_token) + + all_team_ids = set(team_ids_from_groups) + + return all_team_ids + + @staticmethod + async def find_team_with_model_access( + team_ids: Set[str], + requested_model: Optional[str], + route: str, + jwt_handler: JWTHandler, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span], + proxy_logging_obj: ProxyLogging, + ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]: + """Find first team with access to the requested model""" + + if not team_ids: + if jwt_handler.litellm_jwtauth.enforce_team_based_model_access: + raise HTTPException( + status_code=403, + detail="No teams found in token. `enforce_team_based_model_access` is set to True. Token must belong to a team.", + ) + return None, None + + for team_id in team_ids: + try: + team_object = await get_team_object( + team_id=team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + + if team_object and team_object.models is not None: + team_models = team_object.models + if isinstance(team_models, list) and ( + not requested_model + or can_team_access_model( + model=requested_model, + team_object=team_object, + llm_router=None, + team_model_aliases=None, + ) + ): + is_allowed = allowed_routes_check( + user_role=LitellmUserRoles.TEAM, + user_route=route, + litellm_proxy_roles=jwt_handler.litellm_jwtauth, + ) + if is_allowed: + return team_id, team_object + except Exception: + continue + + if requested_model: + raise HTTPException( + status_code=403, + detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}. Check `/models` to see all available models.", + ) + + return None, None + + @staticmethod + async def get_user_info( + jwt_handler: JWTHandler, + jwt_valid_token: dict, + ) -> Tuple[Optional[str], Optional[str], Optional[bool]]: + """Get user email and validation status""" + user_email = jwt_handler.get_user_email( + token=jwt_valid_token, default_value=None + ) + valid_user_email = None + if jwt_handler.is_enforced_email_domain(): + valid_user_email = ( + False + if user_email is None + else jwt_handler.is_allowed_domain(user_email=user_email) + ) + user_id = jwt_handler.get_user_id( + token=jwt_valid_token, default_value=user_email + ) + return user_id, user_email, valid_user_email + + @staticmethod + async def get_objects( + user_id: Optional[str], + user_email: Optional[str], + org_id: Optional[str], + end_user_id: Optional[str], + valid_user_email: Optional[bool], + jwt_handler: JWTHandler, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span], + proxy_logging_obj: ProxyLogging, + ) -> Tuple[ + Optional[LiteLLM_UserTable], + Optional[LiteLLM_OrganizationTable], + Optional[LiteLLM_EndUserTable], + ]: + """Get user, org, and end user objects""" + org_object: Optional[LiteLLM_OrganizationTable] = None + if org_id: + org_object = ( + await get_org_object( + org_id=org_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + if org_id + else None + ) + + user_object: Optional[LiteLLM_UserTable] = None + if user_id: + user_object = ( + await get_user_object( + user_id=user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + user_id_upsert=jwt_handler.is_upsert_user_id( + valid_user_email=valid_user_email + ), + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + user_email=user_email, + sso_user_id=user_id, + ) + if user_id + else None + ) + + end_user_object: Optional[LiteLLM_EndUserTable] = None + if end_user_id: + end_user_object = ( + await get_end_user_object( + end_user_id=end_user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + if end_user_id + else None + ) + + return user_object, org_object, end_user_object + + @staticmethod + def validate_object_id( + user_id: Optional[str], + team_id: Optional[str], + enforce_rbac: bool, + is_proxy_admin: bool, + ) -> Literal[True]: + """If enforce_rbac is true, validate that a valid rbac id is returned for spend tracking""" + if enforce_rbac and not is_proxy_admin and not user_id and not team_id: + raise HTTPException( + status_code=403, + detail="No user or team id found in token. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.", + ) + return True + + @staticmethod + async def auth_builder( + api_key: str, + jwt_handler: JWTHandler, + request_data: dict, + general_settings: dict, + route: str, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span], + proxy_logging_obj: ProxyLogging, + ) -> JWTAuthBuilderResult: + """Main authentication and authorization builder""" + jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key) + + # Check custom validate + if jwt_handler.litellm_jwtauth.custom_validate: + if not jwt_handler.litellm_jwtauth.custom_validate(jwt_valid_token): + raise HTTPException( + status_code=403, + detail="Invalid JWT token", + ) + + # Check RBAC + rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token) + await JWTAuthManager.check_rbac_role( + jwt_handler, + jwt_valid_token, + general_settings, + request_data, + route, + rbac_role, + ) + + # Check Scope Based Access + scopes = jwt_handler.get_scopes(token=jwt_valid_token) + if ( + jwt_handler.litellm_jwtauth.enforce_scope_based_access + and jwt_handler.litellm_jwtauth.scope_mappings + ): + JWTAuthManager.check_scope_based_access( + scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings, + scopes=scopes, + request_data=request_data, + general_settings=general_settings, + ) + + object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None) + + # Get basic user info + scopes = jwt_handler.get_scopes(token=jwt_valid_token) + user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info( + jwt_handler, jwt_valid_token + ) + + # Get IDs + org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None) + end_user_id = jwt_handler.get_end_user_id( + token=jwt_valid_token, default_value=None + ) + team_id: Optional[str] = None + team_object: Optional[LiteLLM_TeamTable] = None + object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None) + + if rbac_role and object_id: + + if rbac_role == LitellmUserRoles.TEAM: + team_id = object_id + elif rbac_role == LitellmUserRoles.INTERNAL_USER: + user_id = object_id + + # Check admin access + admin_result = await JWTAuthManager.check_admin_access( + jwt_handler, scopes, route, user_id, org_id, api_key + ) + if admin_result: + return admin_result + + # Get team with model access + ## SPECIFIC TEAM ID + + if not team_id: + team_id, team_object = ( + await JWTAuthManager.find_and_validate_specific_team_id( + jwt_handler, + jwt_valid_token, + prisma_client, + user_api_key_cache, + parent_otel_span, + proxy_logging_obj, + ) + ) + + if not team_object and not team_id: + ## CHECK USER GROUP ACCESS + all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token) + team_id, team_object = await JWTAuthManager.find_team_with_model_access( + team_ids=all_team_ids, + requested_model=request_data.get("model"), + route=route, + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + + # Get other objects + user_object, org_object, end_user_object = await JWTAuthManager.get_objects( + user_id=user_id, + user_email=user_email, + org_id=org_id, + end_user_id=end_user_id, + valid_user_email=valid_user_email, + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + + # Validate that a valid rbac id is returned for spend tracking + JWTAuthManager.validate_object_id( + user_id=user_id, + team_id=team_id, + enforce_rbac=general_settings.get("enforce_rbac", False), + is_proxy_admin=False, + ) + + return JWTAuthBuilderResult( + is_proxy_admin=False, + team_id=team_id, + team_object=team_object, + user_id=user_id, + user_object=user_object, + org_id=org_id, + org_object=org_object, + end_user_id=end_user_id, + end_user_object=end_user_object, + token=api_key, + ) |