author    | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit    | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
parent    | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  | gn-ai-master.tar.gz
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py | 514
1 file changed, 514 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
new file mode 100644
index 00000000..91fcaf7e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
@@ -0,0 +1,514 @@
+import os
+import re
+import sys
+from typing import Any, List, Optional, Tuple
+
+from fastapi import HTTPException, Request, status
+
+from litellm import Router, provider_list
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import *
+from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS
+
+
+def _get_request_ip_address(
+    request: Request, use_x_forwarded_for: Optional[bool] = False
+) -> Optional[str]:
+
+    client_ip = None
+    if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
+        client_ip = request.headers["x-forwarded-for"]
+    elif request.client is not None:
+        client_ip = request.client.host
+    else:
+        client_ip = ""
+
+    return client_ip
+
+
+def _check_valid_ip(
+    allowed_ips: Optional[List[str]],
+    request: Request,
+    use_x_forwarded_for: Optional[bool] = False,
+) -> Tuple[bool, Optional[str]]:
+    """
+    Returns whether the client IP is allowed, along with the IP that was checked.
+    """
+    if allowed_ips is None:  # if not set, assume true
+        return True, None
+
+    # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
+    client_ip = _get_request_ip_address(
+        request=request, use_x_forwarded_for=use_x_forwarded_for
+    )
+
+    # Check if IP address is allowed
+    if client_ip not in allowed_ips:
+        return False, client_ip
+
+    return True, client_ip
+
+
+def check_complete_credentials(request_body: dict) -> bool:
+    """
+    If 'api_base' is in the request body, check whether complete credentials were given. Prevents malicious requests.
+    """
+    given_model: Optional[str] = None
+
+    given_model = request_body.get("model")
+    if given_model is None:
+        return False
+
+    if (
+        "sagemaker" in given_model
+        or "bedrock" in given_model
+        or "vertex_ai" in given_model
+        or "vertex_ai_beta" in given_model
+    ):
+        # complex credentials - easier to make a malicious request
+        return False
+
+    if "api_key" in request_body:
+        return True
+
+    return False
+
+
+def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
+    """
+    Check if request_body_value matches regex_str or is equal to it.
+    """
+    if re.match(regex_str, request_body_value) or regex_str == request_body_value:
+        return True
+    return False
+
+
+def _is_param_allowed(
+    param: str,
+    request_body_value: Any,
+    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
+) -> bool:
+    """
+    Check if param is a str or dict and if request_body_value is in the list of allowed values.
+    """
+    if configurable_clientside_auth_params is None:
+        return False
+
+    for item in configurable_clientside_auth_params:
+        if isinstance(item, str) and param == item:
+            return True
+        elif isinstance(item, Dict):
+            if param == "api_base" and check_regex_or_str_match(
+                request_body_value=request_body_value,
+                regex_str=item["api_base"],
+            ):  # assume param is a regex
+                return True
+
+    return False
+
+
+def _allow_model_level_clientside_configurable_parameters(
+    model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
+) -> bool:
+    """
+    Check if model is allowed to use configurable client-side params
+    - get matching model
+    - check if 'clientside_configurable_parameters' is set for model
+    - check if the given param/value is allowed for that model
+    """
+    if llm_router is None:
+        return False
+    # check if model is set
+    model_info = llm_router.get_model_group_info(model_group=model)
+    if model_info is None:
+        # check if wildcard model is set
+        if model.split("/", 1)[0] in provider_list:
+            model_info = llm_router.get_model_group_info(
+                model_group=model.split("/", 1)[0]
+            )
+
+    if model_info is None:
+        return False
+
+    if model_info is None or model_info.configurable_clientside_auth_params is None:
+        return False
+
+    return _is_param_allowed(
+        param=param,
+        request_body_value=request_body_value,
+        configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
+    )
+
+
+def is_request_body_safe(
+    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
+) -> bool:
+    """
+    Check if the request body is safe.
+
+    A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
+    Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997
+    """
+    banned_params = ["api_base", "base_url"]
+
+    for param in banned_params:
+        if (
+            param in request_body
+            and not check_complete_credentials(  # allow client-credentials to be passed to proxy
+                request_body=request_body
+            )
+        ):
+            if general_settings.get("allow_client_side_credentials") is True:
+                return True
+            elif (
+                _allow_model_level_clientside_configurable_parameters(
+                    model=model,
+                    param=param,
+                    request_body_value=request_body[param],
+                    llm_router=llm_router,
+                )
+                is True
+            ):
+                return True
+            raise ValueError(
+                f"Rejected Request: {param} is not allowed in request body. "
+                "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
+                "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
+            )
+
+    return True
+
+
+async def pre_db_read_auth_checks(
+    request: Request,
+    request_data: dict,
+    route: str,
+):
+    """
+    1. Checks if request size is under max_request_size_mb (if set)
+    2. Checks if request body is safe (e.g. user has not set api_base in request body)
+    3. Checks if IP address is allowed (if set)
+    4. Checks if request route is an allowed route on the proxy (if set)
+
+    Returns:
+    - True
+
+    Raises:
+    - HTTPException if request fails initial auth checks
+    """
+    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user
+
+    # Check 1. request size
+    await check_if_request_size_is_safe(request=request)
+
+    # Check 2. Request body is safe
+    is_request_body_safe(
+        request_body=request_data,
+        general_settings=general_settings,
+        llm_router=llm_router,
+        model=request_data.get(
+            "model", ""
+        ),  # [TODO] use model passed in url as well (azure openai routes)
+    )
+
+    # Check 3. Check if IP address is allowed
+    is_valid_ip, passed_in_ip = _check_valid_ip(
+        allowed_ips=general_settings.get("allowed_ips", None),
+        use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
+        request=request,
+    )
+
+    if not is_valid_ip:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
+        )
+
+    # Check 4. Check if request route is an allowed route on the proxy
+    if "allowed_routes" in general_settings:
+        _allowed_routes = general_settings["allowed_routes"]
+        if premium_user is not True:
+            verbose_proxy_logger.error(
+                f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
+            )
+        if route not in _allowed_routes:
+            verbose_proxy_logger.error(
+                f"Route {route} not in allowed_routes={_allowed_routes}"
+            )
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail=f"Access forbidden: Route {route} not allowed",
+            )
+
+
+def route_in_additonal_public_routes(current_route: str):
+    """
+    Helper to check if the user defined public_routes on config.yaml
+
+    Parameters:
+    - current_route: str - the route the user is trying to call
+
+    Returns:
+    - bool - True if the route is defined in public_routes
+    - bool - False if the route is not defined in public_routes
+
+    In order to use this the litellm config.yaml should have the following in general_settings:
+
+    ```yaml
+    general_settings:
+      master_key: sk-1234
+      public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"]
+    ```
+    """
+
+    # check if user is premium_user - if not do nothing
+    from litellm.proxy.proxy_server import general_settings, premium_user
+
+    try:
+        if premium_user is not True:
+            return False
+        # check if this is defined on the config
+        if general_settings is None:
+            return False
+
+        routes_defined = general_settings.get("public_routes", [])
+        if current_route in routes_defined:
+            return True
+
+        return False
+    except Exception as e:
+        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
+        return False
+
+
+def get_request_route(request: Request) -> str:
+    """
+    Helper to get the route from the request
+
+    Removes the base url from the path if set, e.g. `/genai/chat/completions` -> `/chat/completions`
+    """
+    try:
+        if hasattr(request, "base_url") and request.url.path.startswith(
+            request.base_url.path
+        ):
+            # remove base_url from path
+            return request.url.path[len(request.base_url.path) - 1 :]
+        else:
+            return request.url.path
+    except Exception as e:
+        verbose_proxy_logger.debug(
+            f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
+        )
+        return request.url.path
+
+
+async def check_if_request_size_is_safe(request: Request) -> bool:
+    """
+    Enterprise Only:
+    - Checks if the request size is within the limit
+
+    Args:
+        request (Request): The incoming request.
+
+    Returns:
+        bool: True if the request size is within the limit
+
+    Raises:
+        ProxyException: If the request size is too large
+    """
+    from litellm.proxy.proxy_server import general_settings, premium_user
+
+    max_request_size_mb = general_settings.get("max_request_size_mb", None)
+    if max_request_size_mb is not None:
+        # Check if premium user
+        if premium_user is not True:
+            verbose_proxy_logger.warning(
+                f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
+            )
+            return True
+
+        # Get the request body
+        content_length = request.headers.get("content-length")
+
+        if content_length:
+            header_size = int(content_length)
+            header_size_mb = bytes_to_mb(bytes_value=header_size)
+            verbose_proxy_logger.debug(
+                f"content_length request size in MB={header_size_mb}"
+            )
+
+            if header_size_mb > max_request_size_mb:
+                raise ProxyException(
+                    message=f"Request size is too large. Request size is {header_size_mb} MB. Max size is {max_request_size_mb} MB",
+                    type=ProxyErrorTypes.bad_request_error.value,
+                    code=400,
+                    param="content-length",
+                )
+        else:
+            # If Content-Length is not available, read the body
+            body = await request.body()
+            body_size = len(body)
+            request_size_mb = bytes_to_mb(bytes_value=body_size)
+
+            verbose_proxy_logger.debug(
+                f"request body request size in MB={request_size_mb}"
+            )
+            if request_size_mb > max_request_size_mb:
+                raise ProxyException(
+                    message=f"Request size is too large. Request size is {request_size_mb} MB. Max size is {max_request_size_mb} MB",
+                    type=ProxyErrorTypes.bad_request_error.value,
+                    code=400,
+                    param="content-length",
+                )
+
+    return True
+
+
+async def check_response_size_is_safe(response: Any) -> bool:
+    """
+    Enterprise Only:
+    - Checks if the response size is within the limit
+
+    Args:
+        response (Any): The response to check.
+
+    Returns:
+        bool: True if the response size is within the limit
+
+    Raises:
+        ProxyException: If the response size is too large
+    """
+    from litellm.proxy.proxy_server import general_settings, premium_user
+
+    max_response_size_mb = general_settings.get("max_response_size_mb", None)
+    if max_response_size_mb is not None:
+        # Check if premium user
+        if premium_user is not True:
+            verbose_proxy_logger.warning(
+                f"using max_response_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
+            )
+            return True
+
+        response_size_mb = bytes_to_mb(bytes_value=sys.getsizeof(response))
+        verbose_proxy_logger.debug(f"response size in MB={response_size_mb}")
+        if response_size_mb > max_response_size_mb:
+            raise ProxyException(
+                message=f"Response size is too large. Response size is {response_size_mb} MB. Max size is {max_response_size_mb} MB",
+                type=ProxyErrorTypes.bad_request_error.value,
+                code=400,
+                param="content-length",
+            )
+
+    return True
+
+
+def bytes_to_mb(bytes_value: int):
+    """
+    Helper to convert bytes to MB
+    """
+    return bytes_value / (1024 * 1024)
+
+
+# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
+def get_key_model_rpm_limit(
+    user_api_key_dict: UserAPIKeyAuth,
+) -> Optional[Dict[str, int]]:
+    if user_api_key_dict.metadata:
+        if "model_rpm_limit" in user_api_key_dict.metadata:
+            return user_api_key_dict.metadata["model_rpm_limit"]
+    elif user_api_key_dict.model_max_budget:
+        model_rpm_limit: Dict[str, Any] = {}
+        for model, budget in user_api_key_dict.model_max_budget.items():
+            if "rpm_limit" in budget and budget["rpm_limit"] is not None:
+                model_rpm_limit[model] = budget["rpm_limit"]
+        return model_rpm_limit
+
+    return None
+
+
+def get_key_model_tpm_limit(
+    user_api_key_dict: UserAPIKeyAuth,
+) -> Optional[Dict[str, int]]:
+    if user_api_key_dict.metadata:
+        if "model_tpm_limit" in user_api_key_dict.metadata:
+            return user_api_key_dict.metadata["model_tpm_limit"]
+    elif user_api_key_dict.model_max_budget:
+        if "tpm_limit" in user_api_key_dict.model_max_budget:
+            return user_api_key_dict.model_max_budget["tpm_limit"]
+
+    return None
+
+
+def is_pass_through_provider_route(route: str) -> bool:
+    PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES = [
+        "vertex-ai",
+    ]
+
+    # check if any of the prefixes are in the route
+    for prefix in PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES:
+        if prefix in route:
+            return True
+
+    return False
+
+
+def should_run_auth_on_pass_through_provider_route(route: str) -> bool:
+    """
+    Use this to decide if the rest of the LiteLLM Virtual Key auth checks should run on provider pass-through routes,
+    e.g. /vertex-ai/{endpoint} routes.
+
+    Run virtual key auth if the following is true:
+    - User is premium_user
+    - User has enabled litellm_setting.use_client_credentials_pass_through_routes
+    """
+    from litellm.proxy.proxy_server import general_settings, premium_user
+
+    if premium_user is not True:
+        return False
+
+    # premium user has opted into using client credentials
+    if (
+        general_settings.get("use_client_credentials_pass_through_routes", False)
+        is True
+    ):
+        return False
+
+    # only enabled for LiteLLM Enterprise
+    return True
+
+
+def _has_user_setup_sso():
+    """
+    Check if the user has set up single sign-on (SSO) by verifying the presence of a Microsoft, Google, or generic client ID environment variable.
+    Returns a boolean indicating whether SSO has been set up.
+    """
+    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
+    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
+    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
+
+    sso_setup = (
+        (microsoft_client_id is not None)
+        or (google_client_id is not None)
+        or (generic_client_id is not None)
+    )
+
+    return sso_setup
+
+
+def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
+    # openai - check 'user'
+    if "user" in request_body and request_body["user"] is not None:
+        return str(request_body["user"])
+    # anthropic - check 'litellm_metadata'
+    end_user_id = request_body.get("litellm_metadata", {}).get("user", None)
+    if end_user_id:
+        return str(end_user_id)
+    metadata = request_body.get("metadata")
+    if metadata and "user_id" in metadata and metadata["user_id"] is not None:
+        return str(metadata["user_id"])
+    return None