aboutsummaryrefslogtreecommitdiff
"""
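Supports using JWTs for authenticating into the proxy.

A token can resolve to a proxy admin, a team, or an internal user role.

For admin access, the JWT's `scope` claim must contain the configured admin
scope (e.g. 'litellm_proxy_admin').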
"""

import json
import os
from typing import Any, List, Literal, Optional, Set, Tuple, cast

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from fastapi import HTTPException

from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import (
    RBAC_ROLES,
    JWKKeyValue,
    JWTAuthBuilderResult,
    JWTKeyItem,
    LiteLLM_EndUserTable,
    LiteLLM_JWTAuth,
    LiteLLM_OrganizationTable,
    LiteLLM_TeamTable,
    LiteLLM_UserTable,
    LitellmUserRoles,
    ScopeMapping,
    Span,
)
from litellm.proxy.auth.auth_checks import can_team_access_model
from litellm.proxy.utils import PrismaClient, ProxyLogging

from .auth_checks import (
    _allowed_routes_check,
    allowed_routes_check,
    get_actual_routes,
    get_end_user_object,
    get_org_object,
    get_role_based_models,
    get_role_based_routes,
    get_team_object,
    get_user_object,
)


class JWTHandler:
    """
    Validates JWTs presented to the proxy and reads identity claims from them.

    - treat the sub id passed in as the user id
    - return an error if id making request doesn't exist in proxy user table
    - track spend against the user id
    - if role="litellm_proxy_user" -> allow making calls + info. Can not edit budgets

    `__init__` does not set `litellm_jwtauth`, `prisma_client` or
    `user_api_key_cache`; call `update_environment` before using the claim
    helpers or `auth_jwt`.
    """

    prisma_client: Optional[PrismaClient]
    user_api_key_cache: DualCache

    def __init__(
        self,
    ) -> None:
        self.http_handler = HTTPHandler()
        # Leeway passed to jwt.decode; overridden by update_environment.
        self.leeway = 0

    def update_environment(
        self,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        litellm_jwtauth: LiteLLM_JWTAuth,
        leeway: int = 0,
    ) -> None:
        """
        Attach the DB client, cache and JWT-auth settings used by this handler.

        Args:
            prisma_client: DB client (may be None).
            user_api_key_cache: cache; used here to store fetched public keys.
            litellm_jwtauth: JWT auth configuration (claim field names, scopes, ...).
            leeway: leeway forwarded to `jwt.decode` for time-based claims.
        """
        self.prisma_client = prisma_client
        self.user_api_key_cache = user_api_key_cache
        self.litellm_jwtauth = litellm_jwtauth
        self.leeway = leeway

    def is_jwt(self, token: str):
        """Cheap structural check: True if `token` has three '.'-separated parts (no verification)."""
        parts = token.split(".")
        return len(parts) == 3

    @staticmethod
    def _normalize_to_list(value: Any) -> Any:
        """Wrap a bare string claim in a list so it is not iterated character by character."""
        if isinstance(value, str):
            return [value]
        return value

    def _rbac_role_from_role_mapping(self, token: dict) -> Optional[RBAC_ROLES]:
        """
        Returns the RBAC role the token 'belongs' to based on role mappings.

        Args:
            token (dict): The JWT token containing role information

        Returns:
            Optional[RBAC_ROLES]: The mapped internal RBAC role if a mapping exists,
                                None otherwise

        Note:
            The function handles both single string roles and lists of roles from the JWT.
            If multiple mappings match the JWT roles, the first matching mapping is returned.
        """
        if self.litellm_jwtauth.role_mappings is None:
            return None

        jwt_role = self.get_jwt_role(token=token, default_value=None)
        if not jwt_role:
            return None

        # A single-role claim may arrive as a plain string; wrap it so the set
        # holds the role itself rather than its individual characters.
        jwt_role_set = set(self._normalize_to_list(jwt_role))

        for role_mapping in self.litellm_jwtauth.role_mappings:
            # Check if the mapping role matches any of the JWT roles
            if role_mapping.role in jwt_role_set:
                return role_mapping.internal_role

        return None

    def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
        """
        Returns the RBAC role the token 'belongs' to.

        RBAC roles allowed to make requests:
        - PROXY_ADMIN: can make requests to all routes
        - TEAM: can make requests to routes associated with a team
        - INTERNAL_USER: can make requests to routes associated with a user

        Resolves: https://github.com/BerriAI/litellm/issues/6793

        Returns:
        - PROXY_ADMIN: if token is admin
        - TEAM: if token is associated with a team
        - INTERNAL_USER: if token is associated with a user
        - the mapped role: if a configured role mapping matches the token's roles
        - None: if token is not associated with a team or user
        """
        scopes = self.get_scopes(token=token)
        is_admin = self.is_admin(scopes=scopes)
        user_roles = self.get_user_roles(token=token, default_value=None)

        if is_admin:
            return LitellmUserRoles.PROXY_ADMIN
        elif self.get_team_id(token=token, default_value=None) is not None:
            return LitellmUserRoles.TEAM
        elif self.get_user_id(token=token, default_value=None) is not None:
            return LitellmUserRoles.INTERNAL_USER
        elif user_roles is not None and self.is_allowed_user_role(
            user_roles=user_roles
        ):
            return LitellmUserRoles.INTERNAL_USER
        elif rbac_role := self._rbac_role_from_role_mapping(token=token):
            return rbac_role

        return None

    def is_admin(self, scopes: list) -> bool:
        """True if the configured admin scope is among `scopes`."""
        return self.litellm_jwtauth.admin_jwt_scope in scopes

    def get_team_ids_from_jwt(self, token: dict) -> List[str]:
        """
        Return the team ids listed under `team_ids_jwt_field`, or [] if unset/absent.

        A single team id given as a plain string is returned as a one-element list.
        """
        field = self.litellm_jwtauth.team_ids_jwt_field
        if field is None:
            return []
        team_ids = token.get(field)
        if team_ids is None:
            return []
        return self._normalize_to_list(team_ids)

    def get_end_user_id(
        self, token: dict, default_value: Optional[str]
    ) -> Optional[str]:
        """
        Read the end-user id from `end_user_id_jwt_field`.

        Returns None when the field is not configured, and `default_value`
        when it is configured but missing from the token.
        """
        try:
            if self.litellm_jwtauth.end_user_id_jwt_field is not None:
                user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
            else:
                user_id = None
        except KeyError:
            user_id = default_value

        return user_id

    def is_required_team_id(self) -> bool:
        """
        Returns:
        - True: if 'team_id_jwt_field' is set
        - False: if not
        """
        return self.litellm_jwtauth.team_id_jwt_field is not None

    def is_enforced_email_domain(self) -> bool:
        """
        Returns:
        - True: if 'user_allowed_email_domain' is set to a string
        - False: otherwise
        """
        return isinstance(self.litellm_jwtauth.user_allowed_email_domain, str)

    def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Resolve the team id for the token.

        Uses `team_id_jwt_field` if configured (falling back to `default_value`
        when the claim is missing), else `team_id_default`, else None.
        """
        try:
            if self.litellm_jwtauth.team_id_jwt_field is not None:
                team_id = token[self.litellm_jwtauth.team_id_jwt_field]
            elif self.litellm_jwtauth.team_id_default is not None:
                team_id = self.litellm_jwtauth.team_id_default
            else:
                team_id = None
        except KeyError:
            team_id = default_value
        return team_id

    def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool:
        """
        Returns:
        - True: if 'user_id_upsert' is set AND valid_user_email is not False
        - False: if not
        """
        if valid_user_email is False:
            return False
        return self.litellm_jwtauth.user_id_upsert

    def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """Read the user id from `user_id_jwt_field`, or `default_value` if unset/missing."""
        try:
            if self.litellm_jwtauth.user_id_jwt_field is not None:
                user_id = token[self.litellm_jwtauth.user_id_jwt_field]
            else:
                user_id = default_value
        except KeyError:
            user_id = default_value
        return user_id

    def get_user_roles(
        self, token: dict, default_value: Optional[List[str]]
    ) -> Optional[List[str]]:
        """
        Returns the user role from the token.

        Set via 'user_roles_jwt_field' in the config (dot-notation paths are
        resolved with `get_nested_value`).
        """
        try:
            if self.litellm_jwtauth.user_roles_jwt_field is not None:
                user_roles = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.user_roles_jwt_field,
                    default=default_value,
                )
            else:
                user_roles = default_value
        except KeyError:
            user_roles = default_value
        return user_roles

    def get_jwt_role(
        self, token: dict, default_value: Optional[List[str]]
    ) -> Optional[List[str]]:
        """
        Generic implementation of `get_user_roles` that can be used for both user and team roles.

        Returns the jwt role from the token.

        Set via 'roles_jwt_field' in the config.
        """
        try:
            if self.litellm_jwtauth.roles_jwt_field is not None:
                user_roles = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.roles_jwt_field,
                    default=default_value,
                )
            else:
                user_roles = default_value
        except KeyError:
            user_roles = default_value
        return user_roles

    def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
        """
        True if any of `user_roles` is listed in the config's 'user_allowed_roles'.

        A single role given as a plain string is treated as one role.
        """
        allowed_roles = self.litellm_jwtauth.user_allowed_roles
        if user_roles is None or allowed_roles is None:
            return False
        # Without normalization a string role would be matched per character.
        return any(
            role in allowed_roles for role in self._normalize_to_list(user_roles)
        )

    def get_user_email(
        self, token: dict, default_value: Optional[str]
    ) -> Optional[str]:
        """
        Read the user email from `user_email_jwt_field`.

        Returns None when the field is not configured, and `default_value`
        when it is configured but missing from the token.
        """
        try:
            if self.litellm_jwtauth.user_email_jwt_field is not None:
                user_email = token[self.litellm_jwtauth.user_email_jwt_field]
            else:
                user_email = None
        except KeyError:
            user_email = default_value
        return user_email

    def get_object_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """Read the object id from `object_id_jwt_field`, or `default_value` if unset/missing."""
        try:
            if self.litellm_jwtauth.object_id_jwt_field is not None:
                object_id = token[self.litellm_jwtauth.object_id_jwt_field]
            else:
                object_id = default_value
        except KeyError:
            object_id = default_value
        return object_id

    def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Read the org id from `org_id_jwt_field`.

        Returns None when the field is not configured, and `default_value`
        when it is configured but missing from the token.
        """
        try:
            if self.litellm_jwtauth.org_id_jwt_field is not None:
                org_id = token[self.litellm_jwtauth.org_id_jwt_field]
            else:
                org_id = None
        except KeyError:
            org_id = default_value
        return org_id

    def get_scopes(self, token: dict) -> List[str]:
        """
        Return the token's `scope` claim as a list ([] if absent).

        Raises:
            Exception: if the claim is neither a str nor a list.
        """
        try:
            if isinstance(token["scope"], str):
                # Assuming the scopes are stored in 'scope' claim and are space-separated
                scopes = token["scope"].split()
            elif isinstance(token["scope"], list):
                scopes = token["scope"]
            else:
                raise Exception(
                    f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str."
                )
        except KeyError:
            scopes = []
        return scopes

    async def get_public_key(self, kid: Optional[str]) -> dict:
        """
        Find the public key matching `kid` across the URLs in `JWT_PUBLIC_KEY_URL`.

        `JWT_PUBLIC_KEY_URL` may hold several comma-separated URLs; they are
        tried in order. Each URL's key set is cached under a per-URL cache key
        with `litellm_jwtauth.public_key_ttl` as the TTL.

        Raises:
            Exception: if the env var is missing or no URL yields a matching key.
        """
        keys_url = os.getenv("JWT_PUBLIC_KEY_URL")

        if keys_url is None:
            raise Exception("Missing JWT Public Key URL from environment.")

        keys_url_list = [url.strip() for url in keys_url.split(",")]

        for key_url in keys_url_list:
            cache_key = f"litellm_jwt_auth_keys_{key_url}"

            cached_keys = await self.user_api_key_cache.async_get_cache(cache_key)

            if cached_keys is None:
                response = await self.http_handler.get(key_url)
                response_json = response.json()
                # JWKS endpoints wrap keys as {"keys": [...]}; otherwise use the body as-is.
                keys: JWKKeyValue = (
                    response_json["keys"] if "keys" in response_json else response_json
                )

                await self.user_api_key_cache.async_set_cache(
                    key=cache_key,
                    value=keys,
                    ttl=self.litellm_jwtauth.public_key_ttl,
                )
            else:
                keys = cached_keys

            public_key = self.parse_keys(keys=keys, kid=kid)
            if public_key is not None:
                return cast(dict, public_key)

        raise Exception(
            f"No matching public key found. keys={keys_url_list}, kid={kid}"
        )

    def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]:
        """
        Select the key matching `kid` from a fetched key set.

        `keys` is either a list of key dicts (from a `{"keys": [...]}` response)
        or a single bare key dict (when the endpoint returned the key directly).

        Returns:
        - the only key, when exactly one key is available and its `kid`
          matches (or the token header carried no `kid`)
        - the key whose `kid` equals `kid`, when several keys are available
          (the last one if duplicated)
        - None otherwise
        """
        if not keys:
            return None

        if isinstance(keys, dict):
            # `len()` of a dict counts its fields (kty, n, e, ...), not keys,
            # so a bare key object must be handled explicitly.
            if kid is None or keys.get("kid", None) == kid:
                return keys
            return None

        if len(keys) == 1:
            only_key = keys[0]
            if isinstance(only_key, dict) and (
                kid is None or only_key.get("kid", None) == kid
            ):
                return only_key
            return None

        # Several keys: only an exact `kid` match is acceptable.
        if kid is None:
            return None
        matches = [
            key for key in keys if isinstance(key, dict) and key.get("kid", None) == kid
        ]
        return matches[-1] if matches else None

    def is_allowed_domain(self, user_email: str) -> bool:
        """True if no domain restriction is configured or `user_email`'s domain matches it exactly."""
        if self.litellm_jwtauth.user_allowed_email_domain is None:
            return True

        email_domain = user_email.split("@")[-1]  # Extract domain from email
        return email_domain == self.litellm_jwtauth.user_allowed_email_domain

    async def auth_jwt(self, token: str) -> dict:
        """
        Verify `token`'s signature with the configured public key and return its claims.

        The audience is verified only when the `JWT_AUDIENCE` env var is set.
        `self.leeway` is applied to both the JWK and the PEM-certificate paths.

        Raises:
            Exception: "Token Expired", "Validation fails: ..." on decode
                errors, or "Invalid JWT Submitted" if no usable key was found.
        """
        # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
        # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret
        #   the key in different ways (e.g. HS* and RS*)."
        algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]

        audience = os.getenv("JWT_AUDIENCE")
        decode_options = None
        if audience is None:
            decode_options = {"verify_aud": False}

        import jwt
        from jwt.algorithms import RSAAlgorithm

        header = jwt.get_unverified_header(token)

        verbose_proxy_logger.debug("header: %s", header)

        kid = header.get("kid", None)

        public_key = await self.get_public_key(kid=kid)

        if public_key is not None and isinstance(public_key, dict):
            # Keep only the RSA JWK members needed to build the verification key.
            jwk = {
                member: public_key[member]
                for member in ("kty", "kid", "n", "e")
                if member in public_key
            }

            public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk))

            try:
                # decode the token using the public key
                payload = jwt.decode(
                    token,
                    public_key_rsa,  # type: ignore
                    algorithms=algorithms,
                    options=decode_options,
                    audience=audience,
                    leeway=self.leeway,  # allow testing of expired tokens
                )
                return payload

            except jwt.ExpiredSignatureError as e:
                raise Exception("Token Expired") from e
            except Exception as e:
                raise Exception(f"Validation fails: {str(e)}") from e
        elif public_key is not None and isinstance(public_key, str):
            try:
                cert = x509.load_pem_x509_certificate(
                    public_key.encode(), default_backend()
                )

                # Extract public key
                key = cert.public_key().public_bytes(
                    serialization.Encoding.PEM,
                    serialization.PublicFormat.SubjectPublicKeyInfo,
                )

                # decode the token using the public key
                payload = jwt.decode(
                    token,
                    key,
                    algorithms=algorithms,
                    audience=audience,
                    options=decode_options,
                    leeway=self.leeway,  # same tolerance as the JWK path
                )
                return payload

            except jwt.ExpiredSignatureError as e:
                raise Exception("Token Expired") from e
            except Exception as e:
                raise Exception(f"Validation fails: {str(e)}") from e

        raise Exception("Invalid JWT Submitted")

    async def close(self):
        """Close the underlying HTTP handler."""
        await self.http_handler.close()


class JWTAuthManager:
    """Manages JWT authentication and authorization operations"""

    @staticmethod
    def can_rbac_role_call_route(
        rbac_role: RBAC_ROLES,
        general_settings: dict,
        route: str,
    ) -> Literal[True]:
        """
        Checks if user is allowed to access the route, based on their role.

        Returns True when no role-based routes are configured for the role.

        Raises:
            HTTPException(403): if the route is not in the role's allowed routes.
        """
        role_based_routes = get_role_based_routes(
            rbac_role=rbac_role, general_settings=general_settings
        )

        if role_based_routes is None or route is None:
            return True

        is_allowed = _allowed_routes_check(
            user_route=route,
            allowed_routes=role_based_routes,
        )

        if not is_allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Role={rbac_role} not allowed to call route={route}. Allowed routes={role_based_routes}",
            )

        return True

    @staticmethod
    def can_rbac_role_call_model(
        rbac_role: RBAC_ROLES,
        general_settings: dict,
        model: Optional[str],
    ) -> Literal[True]:
        """
        Checks if user is allowed to access the model, based on their role.

        Returns True when no role-based models are configured or no model was requested.

        Raises:
            HTTPException(403): if the model is not in the role's allowed models.
        """
        role_based_models = get_role_based_models(
            rbac_role=rbac_role, general_settings=general_settings
        )
        if role_based_models is None or model is None:
            return True

        if model not in role_based_models:
            raise HTTPException(
                status_code=403,
                detail=f"Role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
            )

        return True

    @staticmethod
    def check_scope_based_access(
        scope_mappings: List[ScopeMapping],
        scopes: List[str],
        request_data: dict,
        general_settings: dict,
    ) -> None:
        """
        Check if scope allows access to the requested model.

        Allowed models are the union of `models` over every mapping whose
        scope appears in `scopes`.

        Raises:
            HTTPException(403): if a model is requested and no matching scope allows it.
        """
        if not scope_mappings:
            return None

        allowed_models = []
        for sm in scope_mappings:
            if sm.scope in scopes and sm.models:
                allowed_models.extend(sm.models)

        requested_model = request_data.get("model")

        if not requested_model:
            return None

        if requested_model not in allowed_models:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "model={} not allowed. Allowed_models={}".format(
                        requested_model, allowed_models
                    )
                },
            )
        return None

    @staticmethod
    async def check_rbac_role(
        jwt_handler: JWTHandler,
        jwt_valid_token: dict,
        general_settings: dict,
        request_data: dict,
        route: str,
        rbac_role: Optional[RBAC_ROLES],
    ) -> None:
        """
        Validate RBAC role and model/route access permissions.

        Only enforced when `litellm_jwtauth.enforce_rbac` is True.

        Raises:
            HTTPException(403): if no role was resolved, or the role may not
                call the requested model/route.
        """
        if jwt_handler.litellm_jwtauth.enforce_rbac is True:
            if rbac_role is None:
                raise HTTPException(
                    status_code=403,
                    detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
                )
            JWTAuthManager.can_rbac_role_call_model(
                rbac_role=rbac_role,
                general_settings=general_settings,
                model=request_data.get("model"),
            )
            JWTAuthManager.can_rbac_role_call_route(
                rbac_role=rbac_role,
                general_settings=general_settings,
                route=route,
            )

    @staticmethod
    async def check_admin_access(
        jwt_handler: JWTHandler,
        scopes: list,
        route: str,
        user_id: Optional[str],
        org_id: Optional[str],
        api_key: str,
    ) -> Optional[JWTAuthBuilderResult]:
        """
        Check admin status and route access permissions.

        Returns:
            An admin `JWTAuthBuilderResult` if the scopes contain the admin
            scope and the route is allowed; None if the token is not admin.

        Raises:
            Exception: if the token is admin but the route is not allowed.
        """
        if not jwt_handler.is_admin(scopes=scopes):
            return None

        is_allowed = allowed_routes_check(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            user_route=route,
            litellm_proxy_roles=jwt_handler.litellm_jwtauth,
        )
        if not is_allowed:
            allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
            actual_routes = get_actual_routes(allowed_routes=allowed_routes)
            raise Exception(
                f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
            )

        return JWTAuthBuilderResult(
            is_proxy_admin=True,
            team_object=None,
            user_object=None,
            end_user_object=None,
            org_object=None,
            token=api_key,
            team_id=None,
            user_id=user_id,
            end_user_id=None,
            org_id=org_id,
        )

    @staticmethod
    async def find_and_validate_specific_team_id(
        jwt_handler: JWTHandler,
        jwt_valid_token: dict,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
        """
        Find and validate the single team id carried by the token.

        Raises:
            Exception: if `team_id_jwt_field` is configured but no team id was found.
        """
        individual_team_id = jwt_handler.get_team_id(
            token=jwt_valid_token, default_value=None
        )

        if not individual_team_id and jwt_handler.is_required_team_id() is True:
            raise Exception(
                f"No team id found in token. Checked team_id field '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
            )

        ## VALIDATE TEAM OBJECT ###
        team_object: Optional[LiteLLM_TeamTable] = None
        if individual_team_id:
            team_object = await get_team_object(
                team_id=individual_team_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
                team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert,
            )

        return individual_team_id, team_object

    @staticmethod
    def get_all_team_ids(jwt_handler: JWTHandler, jwt_valid_token: dict) -> Set[str]:
        """
        Get the set of team ids listed in the token's `team_ids_jwt_field`.

        A single team id given as a plain string yields a one-element set.
        """
        team_ids_from_groups = jwt_handler.get_team_ids_from_jwt(token=jwt_valid_token)
        if isinstance(team_ids_from_groups, str):
            # set("abc") would split the id into characters.
            return {team_ids_from_groups}
        return set(team_ids_from_groups)

    @staticmethod
    async def find_team_with_model_access(
        team_ids: Set[str],
        requested_model: Optional[str],
        route: str,
        jwt_handler: JWTHandler,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
        """
        Find the first team (in sorted id order) with access to the requested model and route.

        Teams whose lookup fails are skipped (logged at debug level).

        Raises:
            HTTPException(403): if there are no team ids and
                `enforce_team_based_model_access` is set, or a model was
                requested and no team grants it.
        """

        if not team_ids:
            if jwt_handler.litellm_jwtauth.enforce_team_based_model_access:
                raise HTTPException(
                    status_code=403,
                    detail="No teams found in token. `enforce_team_based_model_access` is set to True. Token must belong to a team.",
                )
            return None, None

        # Sets iterate in hash order, which varies between processes; sort so
        # the chosen team (and thus spend attribution) is deterministic.
        for team_id in sorted(team_ids, key=str):
            try:
                team_object = await get_team_object(
                    team_id=team_id,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    parent_otel_span=parent_otel_span,
                    proxy_logging_obj=proxy_logging_obj,
                )

                if team_object and team_object.models is not None:
                    team_models = team_object.models
                    if isinstance(team_models, list) and (
                        not requested_model
                        or can_team_access_model(
                            model=requested_model,
                            team_object=team_object,
                            llm_router=None,
                            team_model_aliases=None,
                        )
                    ):
                        is_allowed = allowed_routes_check(
                            user_role=LitellmUserRoles.TEAM,
                            user_route=route,
                            litellm_proxy_roles=jwt_handler.litellm_jwtauth,
                        )
                        if is_allowed:
                            return team_id, team_object
            except Exception as e:
                # Best-effort: a team that can't be loaded/checked is skipped.
                verbose_proxy_logger.debug(
                    "Skipping team_id=%s while checking model access: %s",
                    team_id,
                    e,
                )
                continue

        if requested_model:
            raise HTTPException(
                status_code=403,
                detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}. Check `/models` to see all available models.",
            )

        return None, None

    @staticmethod
    async def get_user_info(
        jwt_handler: JWTHandler,
        jwt_valid_token: dict,
    ) -> Tuple[Optional[str], Optional[str], Optional[bool]]:
        """
        Get user id, user email and email-domain validation status.

        `valid_user_email` is None unless a domain restriction is configured.
        The user id falls back to the email when no user id claim is found.
        """
        user_email = jwt_handler.get_user_email(
            token=jwt_valid_token, default_value=None
        )
        valid_user_email = None
        if jwt_handler.is_enforced_email_domain():
            valid_user_email = (
                False
                if user_email is None
                else jwt_handler.is_allowed_domain(user_email=user_email)
            )
        user_id = jwt_handler.get_user_id(
            token=jwt_valid_token, default_value=user_email
        )
        return user_id, user_email, valid_user_email

    @staticmethod
    async def get_objects(
        user_id: Optional[str],
        user_email: Optional[str],
        org_id: Optional[str],
        end_user_id: Optional[str],
        valid_user_email: Optional[bool],
        jwt_handler: JWTHandler,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> Tuple[
        Optional[LiteLLM_UserTable],
        Optional[LiteLLM_OrganizationTable],
        Optional[LiteLLM_EndUserTable],
    ]:
        """Get user, org, and end user objects; each is None when its id is falsy."""
        org_object: Optional[LiteLLM_OrganizationTable] = None
        if org_id:
            org_object = await get_org_object(
                org_id=org_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )

        user_object: Optional[LiteLLM_UserTable] = None
        if user_id:
            user_object = await get_user_object(
                user_id=user_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                user_id_upsert=jwt_handler.is_upsert_user_id(
                    valid_user_email=valid_user_email
                ),
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
                user_email=user_email,
                sso_user_id=user_id,
            )

        end_user_object: Optional[LiteLLM_EndUserTable] = None
        if end_user_id:
            end_user_object = await get_end_user_object(
                end_user_id=end_user_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )

        return user_object, org_object, end_user_object

    @staticmethod
    def validate_object_id(
        user_id: Optional[str],
        team_id: Optional[str],
        enforce_rbac: bool,
        is_proxy_admin: bool,
    ) -> Literal[True]:
        """If enforce_rbac is true, validate that a valid rbac id is returned for spend tracking"""
        if enforce_rbac and not is_proxy_admin and not user_id and not team_id:
            raise HTTPException(
                status_code=403,
                detail="No user or team id found in token. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
            )
        return True

    @staticmethod
    async def auth_builder(
        api_key: str,
        jwt_handler: JWTHandler,
        request_data: dict,
        general_settings: dict,
        route: str,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> JWTAuthBuilderResult:
        """
        Main authentication and authorization builder.

        Verifies the JWT, applies custom validation, RBAC and scope checks,
        short-circuits for admins, then resolves the team (explicit id, else
        first team in the token's team list with model access) and the
        user/org/end-user objects.
        """
        jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)

        # Check custom validate
        if jwt_handler.litellm_jwtauth.custom_validate:
            if not jwt_handler.litellm_jwtauth.custom_validate(jwt_valid_token):
                raise HTTPException(
                    status_code=403,
                    detail="Invalid JWT token",
                )

        # Check RBAC
        rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
        await JWTAuthManager.check_rbac_role(
            jwt_handler,
            jwt_valid_token,
            general_settings,
            request_data,
            route,
            rbac_role,
        )

        # Check Scope Based Access
        scopes = jwt_handler.get_scopes(token=jwt_valid_token)
        if (
            jwt_handler.litellm_jwtauth.enforce_scope_based_access
            and jwt_handler.litellm_jwtauth.scope_mappings
        ):
            JWTAuthManager.check_scope_based_access(
                scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings,
                scopes=scopes,
                request_data=request_data,
                general_settings=general_settings,
            )

        # Get basic user info
        user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
            jwt_handler, jwt_valid_token
        )

        # Get IDs
        org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
        end_user_id = jwt_handler.get_end_user_id(
            token=jwt_valid_token, default_value=None
        )
        team_id: Optional[str] = None
        team_object: Optional[LiteLLM_TeamTable] = None
        object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)

        # An explicit object id overrides the team/user id for the resolved role.
        if rbac_role and object_id:
            if rbac_role == LitellmUserRoles.TEAM:
                team_id = object_id
            elif rbac_role == LitellmUserRoles.INTERNAL_USER:
                user_id = object_id

        # Check admin access
        admin_result = await JWTAuthManager.check_admin_access(
            jwt_handler, scopes, route, user_id, org_id, api_key
        )
        if admin_result:
            return admin_result

        # Get team with model access
        ## SPECIFIC TEAM ID
        if not team_id:
            team_id, team_object = (
                await JWTAuthManager.find_and_validate_specific_team_id(
                    jwt_handler,
                    jwt_valid_token,
                    prisma_client,
                    user_api_key_cache,
                    parent_otel_span,
                    proxy_logging_obj,
                )
            )

        if not team_object and not team_id:
            ## CHECK USER GROUP ACCESS
            all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token)
            team_id, team_object = await JWTAuthManager.find_team_with_model_access(
                team_ids=all_team_ids,
                requested_model=request_data.get("model"),
                route=route,
                jwt_handler=jwt_handler,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )

        # Get other objects
        user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
            user_id=user_id,
            user_email=user_email,
            org_id=org_id,
            end_user_id=end_user_id,
            valid_user_email=valid_user_email,
            jwt_handler=jwt_handler,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )

        # Validate that a valid rbac id is returned for spend tracking.
        # NOTE(review): `check_rbac_role` reads `litellm_jwtauth.enforce_rbac`,
        # while this reads a top-level `general_settings["enforce_rbac"]`.
        # Confirm which setting is intended before unifying them.
        JWTAuthManager.validate_object_id(
            user_id=user_id,
            team_id=team_id,
            enforce_rbac=general_settings.get("enforce_rbac", False),
            is_proxy_admin=False,
        )

        return JWTAuthBuilderResult(
            is_proxy_admin=False,
            team_id=team_id,
            team_object=team_object,
            user_id=user_id,
            user_object=user_object,
            org_id=org_id,
            org_object=org_object,
            end_user_id=end_user_id,
            end_user_object=end_user_object,
            token=api_key,
        )