about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py514
1 files changed, 514 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
new file mode 100644
index 00000000..91fcaf7e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
@@ -0,0 +1,514 @@
+import os
+import re
+import sys
+from typing import Any, List, Optional, Tuple
+
+from fastapi import HTTPException, Request, status
+
+from litellm import Router, provider_list
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import *
+from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS
+
+
+def _get_request_ip_address(
+    request: Request, use_x_forwarded_for: Optional[bool] = False
+) -> Optional[str]:
+
+    client_ip = None
+    if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
+        client_ip = request.headers["x-forwarded-for"]
+    elif request.client is not None:
+        client_ip = request.client.host
+    else:
+        client_ip = ""
+
+    return client_ip
+
+
+def _check_valid_ip(
+    allowed_ips: Optional[List[str]],
+    request: Request,
+    use_x_forwarded_for: Optional[bool] = False,
+) -> Tuple[bool, Optional[str]]:
+    """
+    Returns if ip is allowed or not
+    """
+    if allowed_ips is None:  # if not set, assume true
+        return True, None
+
+    # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
+    client_ip = _get_request_ip_address(
+        request=request, use_x_forwarded_for=use_x_forwarded_for
+    )
+
+    # Check if IP address is allowed
+    if client_ip not in allowed_ips:
+        return False, client_ip
+
+    return True, client_ip
+
+
+def check_complete_credentials(request_body: dict) -> bool:
+    """
+    if 'api_base' in request body. Check if complete credentials given. Prevent malicious attacks.
+    """
+    given_model: Optional[str] = None
+
+    given_model = request_body.get("model")
+    if given_model is None:
+        return False
+
+    if (
+        "sagemaker" in given_model
+        or "bedrock" in given_model
+        or "vertex_ai" in given_model
+        or "vertex_ai_beta" in given_model
+    ):
+        # complex credentials - easier to make a malicious request
+        return False
+
+    if "api_key" in request_body:
+        return True
+
+    return False
+
+
+def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
+    """
+    Check if request_body_value matches the regex_str or is equal to param
+    """
+    if re.match(regex_str, request_body_value) or regex_str == request_body_value:
+        return True
+    return False
+
+
+def _is_param_allowed(
+    param: str,
+    request_body_value: Any,
+    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
+) -> bool:
+    """
+    Check if param is a str or dict and if request_body_value is in the list of allowed values
+    """
+    if configurable_clientside_auth_params is None:
+        return False
+
+    for item in configurable_clientside_auth_params:
+        if isinstance(item, str) and param == item:
+            return True
+        elif isinstance(item, Dict):
+            if param == "api_base" and check_regex_or_str_match(
+                request_body_value=request_body_value,
+                regex_str=item["api_base"],
+            ):  # assume param is a regex
+                return True
+
+    return False
+
+
+def _allow_model_level_clientside_configurable_parameters(
+    model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
+) -> bool:
+    """
+    Check if model is allowed to use configurable client-side params
+    - get matching model
+    - check if 'clientside_configurable_parameters' is set for model
+    -
+    """
+    if llm_router is None:
+        return False
+    # check if model is set
+    model_info = llm_router.get_model_group_info(model_group=model)
+    if model_info is None:
+        # check if wildcard model is set
+        if model.split("/", 1)[0] in provider_list:
+            model_info = llm_router.get_model_group_info(
+                model_group=model.split("/", 1)[0]
+            )
+
+    if model_info is None:
+        return False
+
+    if model_info is None or model_info.configurable_clientside_auth_params is None:
+        return False
+
+    return _is_param_allowed(
+        param=param,
+        request_body_value=request_body_value,
+        configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
+    )
+
+
+def is_request_body_safe(
+    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
+) -> bool:
+    """
+    Check if the request body is safe.
+
+    A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
+    Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997
+    """
+    banned_params = ["api_base", "base_url"]
+
+    for param in banned_params:
+        if (
+            param in request_body
+            and not check_complete_credentials(  # allow client-credentials to be passed to proxy
+                request_body=request_body
+            )
+        ):
+            if general_settings.get("allow_client_side_credentials") is True:
+                return True
+            elif (
+                _allow_model_level_clientside_configurable_parameters(
+                    model=model,
+                    param=param,
+                    request_body_value=request_body[param],
+                    llm_router=llm_router,
+                )
+                is True
+            ):
+                return True
+            raise ValueError(
+                f"Rejected Request: {param} is not allowed in request body. "
+                "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
+                "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
+            )
+
+    return True
+
+
+async def pre_db_read_auth_checks(
+    request: Request,
+    request_data: dict,
+    route: str,
+):
+    """
+    1. Checks if request size is under max_request_size_mb (if set)
+    2. Check if request body is safe (example user has not set api_base in request body)
+    3. Check if IP address is allowed (if set)
+    4. Check if request route is an allowed route on the proxy (if set)
+
+    Returns:
+    - True
+
+    Raises:
+    - HTTPException if request fails initial auth checks
+    """
+    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user
+
+    # Check 1. request size
+    await check_if_request_size_is_safe(request=request)
+
+    # Check 2. Request body is safe
+    is_request_body_safe(
+        request_body=request_data,
+        general_settings=general_settings,
+        llm_router=llm_router,
+        model=request_data.get(
+            "model", ""
+        ),  # [TODO] use model passed in url as well (azure openai routes)
+    )
+
+    # Check 3. Check if IP address is allowed
+    is_valid_ip, passed_in_ip = _check_valid_ip(
+        allowed_ips=general_settings.get("allowed_ips", None),
+        use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
+        request=request,
+    )
+
+    if not is_valid_ip:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
+        )
+
+    # Check 4. Check if request route is an allowed route on the proxy
+    if "allowed_routes" in general_settings:
+        _allowed_routes = general_settings["allowed_routes"]
+        if premium_user is not True:
+            verbose_proxy_logger.error(
+                f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
+            )
+        if route not in _allowed_routes:
+            verbose_proxy_logger.error(
+                f"Route {route} not in allowed_routes={_allowed_routes}"
+            )
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail=f"Access forbidden: Route {route} not allowed",
+            )
+
+
+def route_in_additonal_public_routes(current_route: str):
+    """
+    Helper to check if the user defined public_routes on config.yaml
+
+    Parameters:
+    - current_route: str - the route the user is trying to call
+
+    Returns:
+    - bool - True if the route is defined in public_routes
+    - bool - False if the route is not defined in public_routes
+
+
+    In order to use this the litellm config.yaml should have the following in general_settings:
+
+    ```yaml
+    general_settings:
+        master_key: sk-1234
+        public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"]
+    ```
+    """
+
+    # check if user is premium_user - if not do nothing
+    from litellm.proxy.proxy_server import general_settings, premium_user
+
+    try:
+        if premium_user is not True:
+            return False
+        # check if this is defined on the config
+        if general_settings is None:
+            return False
+
+        routes_defined = general_settings.get("public_routes", [])
+        if current_route in routes_defined:
+            return True
+
+        return False
+    except Exception as e:
+        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
+        return False
+
+
+def get_request_route(request: Request) -> str:
+    """
+    Helper to get the route from the request
+
+    remove base url from path if set e.g. `/genai/chat/completions` -> `/chat/completions
+    """
+    try:
+        if hasattr(request, "base_url") and request.url.path.startswith(
+            request.base_url.path
+        ):
+            # remove base_url from path
+            return request.url.path[len(request.base_url.path) - 1 :]
+        else:
+            return request.url.path
+    except Exception as e:
+        verbose_proxy_logger.debug(
+            f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
+        )
+        return request.url.path
+
+
+async def check_if_request_size_is_safe(request: Request) -> bool:
+    """
+    Enterprise Only:
+        - Checks if the request size is within the limit
+
+    Args:
+        request (Request): The incoming request.
+
+    Returns:
+        bool: True if the request size is within the limit
+
+    Raises:
+        ProxyException: If the request size is too large
+
+    """
+    from litellm.proxy.proxy_server import general_settings, premium_user
+
+    max_request_size_mb = general_settings.get("max_request_size_mb", None)
+    if max_request_size_mb is not None:
+        # Check if premium user
+        if premium_user is not True:
+            verbose_proxy_logger.warning(
+                f"using max_request_size_mb - not checking -  this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
+            )
+            return True
+
+        # Get the request body
+        content_length = request.headers.get("content-length")
+
+        if content_length:
+            header_size = int(content_length)
+            header_size_mb = bytes_to_mb(bytes_value=header_size)
+            verbose_proxy_logger.debug(
+                f"content_length request size in MB={header_size_mb}"
+            )
+
+            if header_size_mb > max_request_size_mb:
+                raise ProxyException(
+                    message=f"Request size is too large. Request size is {header_size_mb} MB. Max size is {max_request_size_mb} MB",
+                    type=ProxyErrorTypes.bad_request_error.value,
+                    code=400,
+                    param="content-length",
+                )
+        else:
+            # If Content-Length is not available, read the body
+            body = await request.body()
+            body_size = len(body)
+            request_size_mb = bytes_to_mb(bytes_value=body_size)
+
+            verbose_proxy_logger.debug(
+                f"request body request size in MB={request_size_mb}"
+            )
+            if request_size_mb > max_request_size_mb:
+                raise ProxyException(
+                    message=f"Request size is too large. Request size is {request_size_mb} MB. Max size is {max_request_size_mb} MB",
+                    type=ProxyErrorTypes.bad_request_error.value,
+                    code=400,
+                    param="content-length",
+                )
+
+    return True
+
+
+async def check_response_size_is_safe(response: Any) -> bool:
+    """
+    Enterprise Only:
+        - Checks if the response size is within the limit
+
+    Args:
+        response (Any): The response to check.
+
+    Returns:
+        bool: True if the response size is within the limit
+
+    Raises:
+        ProxyException: If the response size is too large
+
+    """
+
+    from litellm.proxy.proxy_server import general_settings, premium_user
+
+    max_response_size_mb = general_settings.get("max_response_size_mb", None)
+    if max_response_size_mb is not None:
+        # Check if premium user
+        if premium_user is not True:
+            verbose_proxy_logger.warning(
+                f"using max_response_size_mb - not checking -  this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
+            )
+            return True
+
+        response_size_mb = bytes_to_mb(bytes_value=sys.getsizeof(response))
+        verbose_proxy_logger.debug(f"response size in MB={response_size_mb}")
+        if response_size_mb > max_response_size_mb:
+            raise ProxyException(
+                message=f"Response size is too large. Response size is {response_size_mb} MB. Max size is {max_response_size_mb} MB",
+                type=ProxyErrorTypes.bad_request_error.value,
+                code=400,
+                param="content-length",
+            )
+
+    return True
+
+
+def bytes_to_mb(bytes_value: int):
+    """
+    Helper to convert bytes to MB
+    """
+    return bytes_value / (1024 * 1024)
+
+
+# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
+def get_key_model_rpm_limit(
+    user_api_key_dict: UserAPIKeyAuth,
+) -> Optional[Dict[str, int]]:
+    if user_api_key_dict.metadata:
+        if "model_rpm_limit" in user_api_key_dict.metadata:
+            return user_api_key_dict.metadata["model_rpm_limit"]
+    elif user_api_key_dict.model_max_budget:
+        model_rpm_limit: Dict[str, Any] = {}
+        for model, budget in user_api_key_dict.model_max_budget.items():
+            if "rpm_limit" in budget and budget["rpm_limit"] is not None:
+                model_rpm_limit[model] = budget["rpm_limit"]
+        return model_rpm_limit
+
+    return None
+
+
+def get_key_model_tpm_limit(
+    user_api_key_dict: UserAPIKeyAuth,
+) -> Optional[Dict[str, int]]:
+    if user_api_key_dict.metadata:
+        if "model_tpm_limit" in user_api_key_dict.metadata:
+            return user_api_key_dict.metadata["model_tpm_limit"]
+    elif user_api_key_dict.model_max_budget:
+        if "tpm_limit" in user_api_key_dict.model_max_budget:
+            return user_api_key_dict.model_max_budget["tpm_limit"]
+
+    return None
+
+
+def is_pass_through_provider_route(route: str) -> bool:
+    PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES = [
+        "vertex-ai",
+    ]
+
+    # check if any of the prefixes are in the route
+    for prefix in PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES:
+        if prefix in route:
+            return True
+
+    return False
+
+
+def should_run_auth_on_pass_through_provider_route(route: str) -> bool:
+    """
+    Use this to decide if the rest of the LiteLLM Virtual Key auth checks should run on /vertex-ai/{endpoint} routes
+    Use this to decide if the rest of the LiteLLM Virtual Key auth checks should run on provider pass through routes
+    ex /vertex-ai/{endpoint} routes
+    Run virtual key auth if the following is try:
+    - User is premium_user
+    - User has enabled litellm_setting.use_client_credentials_pass_through_routes
+    """
+    from litellm.proxy.proxy_server import general_settings, premium_user
+
+    if premium_user is not True:
+
+        return False
+
+    # premium use has opted into using client credentials
+    if (
+        general_settings.get("use_client_credentials_pass_through_routes", False)
+        is True
+    ):
+        return False
+
+    # only enabled for LiteLLM Enterprise
+    return True
+
+
+def _has_user_setup_sso():
+    """
+    Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID or generic client ID and UI username environment variables.
+    Returns a boolean indicating whether SSO has been set up.
+    """
+    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
+    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
+    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
+
+    sso_setup = (
+        (microsoft_client_id is not None)
+        or (google_client_id is not None)
+        or (generic_client_id is not None)
+    )
+
+    return sso_setup
+
+
+def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
+    # openai - check 'user'
+    if "user" in request_body and request_body["user"] is not None:
+        return str(request_body["user"])
+    # anthropic - check 'litellm_metadata'
+    end_user_id = request_body.get("litellm_metadata", {}).get("user", None)
+    if end_user_id:
+        return str(end_user_id)
+    metadata = request_body.get("metadata")
+    if metadata and "user_id" in metadata and metadata["user_id"] is not None:
+        return str(metadata["user_id"])
+    return None