aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py514
1 files changed, 514 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
new file mode 100644
index 00000000..91fcaf7e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
@@ -0,0 +1,514 @@
+import os
+import re
+import sys
+from typing import Any, List, Optional, Tuple
+
+from fastapi import HTTPException, Request, status
+
+from litellm import Router, provider_list
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import *
+from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS
+
+
+def _get_request_ip_address(
+ request: Request, use_x_forwarded_for: Optional[bool] = False
+) -> Optional[str]:
+
+ client_ip = None
+ if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
+ client_ip = request.headers["x-forwarded-for"]
+ elif request.client is not None:
+ client_ip = request.client.host
+ else:
+ client_ip = ""
+
+ return client_ip
+
+
+def _check_valid_ip(
+ allowed_ips: Optional[List[str]],
+ request: Request,
+ use_x_forwarded_for: Optional[bool] = False,
+) -> Tuple[bool, Optional[str]]:
+ """
+ Returns if ip is allowed or not
+ """
+ if allowed_ips is None: # if not set, assume true
+ return True, None
+
+ # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
+ client_ip = _get_request_ip_address(
+ request=request, use_x_forwarded_for=use_x_forwarded_for
+ )
+
+ # Check if IP address is allowed
+ if client_ip not in allowed_ips:
+ return False, client_ip
+
+ return True, client_ip
+
+
+def check_complete_credentials(request_body: dict) -> bool:
+ """
+ if 'api_base' in request body. Check if complete credentials given. Prevent malicious attacks.
+ """
+ given_model: Optional[str] = None
+
+ given_model = request_body.get("model")
+ if given_model is None:
+ return False
+
+ if (
+ "sagemaker" in given_model
+ or "bedrock" in given_model
+ or "vertex_ai" in given_model
+ or "vertex_ai_beta" in given_model
+ ):
+ # complex credentials - easier to make a malicious request
+ return False
+
+ if "api_key" in request_body:
+ return True
+
+ return False
+
+
+def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
+ """
+ Check if request_body_value matches the regex_str or is equal to param
+ """
+ if re.match(regex_str, request_body_value) or regex_str == request_body_value:
+ return True
+ return False
+
+
+def _is_param_allowed(
+ param: str,
+ request_body_value: Any,
+ configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
+) -> bool:
+ """
+ Check if param is a str or dict and if request_body_value is in the list of allowed values
+ """
+ if configurable_clientside_auth_params is None:
+ return False
+
+ for item in configurable_clientside_auth_params:
+ if isinstance(item, str) and param == item:
+ return True
+ elif isinstance(item, Dict):
+ if param == "api_base" and check_regex_or_str_match(
+ request_body_value=request_body_value,
+ regex_str=item["api_base"],
+ ): # assume param is a regex
+ return True
+
+ return False
+
+
+def _allow_model_level_clientside_configurable_parameters(
+ model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
+) -> bool:
+ """
+ Check if model is allowed to use configurable client-side params
+ - get matching model
+ - check if 'clientside_configurable_parameters' is set for model
+ -
+ """
+ if llm_router is None:
+ return False
+ # check if model is set
+ model_info = llm_router.get_model_group_info(model_group=model)
+ if model_info is None:
+ # check if wildcard model is set
+ if model.split("/", 1)[0] in provider_list:
+ model_info = llm_router.get_model_group_info(
+ model_group=model.split("/", 1)[0]
+ )
+
+ if model_info is None:
+ return False
+
+ if model_info is None or model_info.configurable_clientside_auth_params is None:
+ return False
+
+ return _is_param_allowed(
+ param=param,
+ request_body_value=request_body_value,
+ configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
+ )
+
+
+def is_request_body_safe(
+ request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
+) -> bool:
+ """
+ Check if the request body is safe.
+
+ A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
+ Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997
+ """
+ banned_params = ["api_base", "base_url"]
+
+ for param in banned_params:
+ if (
+ param in request_body
+ and not check_complete_credentials( # allow client-credentials to be passed to proxy
+ request_body=request_body
+ )
+ ):
+ if general_settings.get("allow_client_side_credentials") is True:
+ return True
+ elif (
+ _allow_model_level_clientside_configurable_parameters(
+ model=model,
+ param=param,
+ request_body_value=request_body[param],
+ llm_router=llm_router,
+ )
+ is True
+ ):
+ return True
+ raise ValueError(
+ f"Rejected Request: {param} is not allowed in request body. "
+ "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
+ "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
+ )
+
+ return True
+
+
+async def pre_db_read_auth_checks(
+ request: Request,
+ request_data: dict,
+ route: str,
+):
+ """
+ 1. Checks if request size is under max_request_size_mb (if set)
+ 2. Check if request body is safe (example user has not set api_base in request body)
+ 3. Check if IP address is allowed (if set)
+ 4. Check if request route is an allowed route on the proxy (if set)
+
+ Returns:
+ - True
+
+ Raises:
+ - HTTPException if request fails initial auth checks
+ """
+ from litellm.proxy.proxy_server import general_settings, llm_router, premium_user
+
+ # Check 1. request size
+ await check_if_request_size_is_safe(request=request)
+
+ # Check 2. Request body is safe
+ is_request_body_safe(
+ request_body=request_data,
+ general_settings=general_settings,
+ llm_router=llm_router,
+ model=request_data.get(
+ "model", ""
+ ), # [TODO] use model passed in url as well (azure openai routes)
+ )
+
+ # Check 3. Check if IP address is allowed
+ is_valid_ip, passed_in_ip = _check_valid_ip(
+ allowed_ips=general_settings.get("allowed_ips", None),
+ use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
+ request=request,
+ )
+
+ if not is_valid_ip:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
+ )
+
+ # Check 4. Check if request route is an allowed route on the proxy
+ if "allowed_routes" in general_settings:
+ _allowed_routes = general_settings["allowed_routes"]
+ if premium_user is not True:
+ verbose_proxy_logger.error(
+ f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
+ )
+ if route not in _allowed_routes:
+ verbose_proxy_logger.error(
+ f"Route {route} not in allowed_routes={_allowed_routes}"
+ )
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=f"Access forbidden: Route {route} not allowed",
+ )
+
+
+def route_in_additonal_public_routes(current_route: str):
+ """
+ Helper to check if the user defined public_routes on config.yaml
+
+ Parameters:
+ - current_route: str - the route the user is trying to call
+
+ Returns:
+ - bool - True if the route is defined in public_routes
+ - bool - False if the route is not defined in public_routes
+
+
+ In order to use this the litellm config.yaml should have the following in general_settings:
+
+ ```yaml
+ general_settings:
+ master_key: sk-1234
+ public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"]
+ ```
+ """
+
+ # check if user is premium_user - if not do nothing
+ from litellm.proxy.proxy_server import general_settings, premium_user
+
+ try:
+ if premium_user is not True:
+ return False
+ # check if this is defined on the config
+ if general_settings is None:
+ return False
+
+ routes_defined = general_settings.get("public_routes", [])
+ if current_route in routes_defined:
+ return True
+
+ return False
+ except Exception as e:
+ verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
+ return False
+
+
+def get_request_route(request: Request) -> str:
+ """
+ Helper to get the route from the request
+
+ remove base url from path if set e.g. `/genai/chat/completions` -> `/chat/completions
+ """
+ try:
+ if hasattr(request, "base_url") and request.url.path.startswith(
+ request.base_url.path
+ ):
+ # remove base_url from path
+ return request.url.path[len(request.base_url.path) - 1 :]
+ else:
+ return request.url.path
+ except Exception as e:
+ verbose_proxy_logger.debug(
+ f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
+ )
+ return request.url.path
+
+
+async def check_if_request_size_is_safe(request: Request) -> bool:
+ """
+ Enterprise Only:
+ - Checks if the request size is within the limit
+
+ Args:
+ request (Request): The incoming request.
+
+ Returns:
+ bool: True if the request size is within the limit
+
+ Raises:
+ ProxyException: If the request size is too large
+
+ """
+ from litellm.proxy.proxy_server import general_settings, premium_user
+
+ max_request_size_mb = general_settings.get("max_request_size_mb", None)
+ if max_request_size_mb is not None:
+ # Check if premium user
+ if premium_user is not True:
+ verbose_proxy_logger.warning(
+ f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
+ )
+ return True
+
+ # Get the request body
+ content_length = request.headers.get("content-length")
+
+ if content_length:
+ header_size = int(content_length)
+ header_size_mb = bytes_to_mb(bytes_value=header_size)
+ verbose_proxy_logger.debug(
+ f"content_length request size in MB={header_size_mb}"
+ )
+
+ if header_size_mb > max_request_size_mb:
+ raise ProxyException(
+ message=f"Request size is too large. Request size is {header_size_mb} MB. Max size is {max_request_size_mb} MB",
+ type=ProxyErrorTypes.bad_request_error.value,
+ code=400,
+ param="content-length",
+ )
+ else:
+ # If Content-Length is not available, read the body
+ body = await request.body()
+ body_size = len(body)
+ request_size_mb = bytes_to_mb(bytes_value=body_size)
+
+ verbose_proxy_logger.debug(
+ f"request body request size in MB={request_size_mb}"
+ )
+ if request_size_mb > max_request_size_mb:
+ raise ProxyException(
+ message=f"Request size is too large. Request size is {request_size_mb} MB. Max size is {max_request_size_mb} MB",
+ type=ProxyErrorTypes.bad_request_error.value,
+ code=400,
+ param="content-length",
+ )
+
+ return True
+
+
+async def check_response_size_is_safe(response: Any) -> bool:
+ """
+ Enterprise Only:
+ - Checks if the response size is within the limit
+
+ Args:
+ response (Any): The response to check.
+
+ Returns:
+ bool: True if the response size is within the limit
+
+ Raises:
+ ProxyException: If the response size is too large
+
+ """
+
+ from litellm.proxy.proxy_server import general_settings, premium_user
+
+ max_response_size_mb = general_settings.get("max_response_size_mb", None)
+ if max_response_size_mb is not None:
+ # Check if premium user
+ if premium_user is not True:
+ verbose_proxy_logger.warning(
+ f"using max_response_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
+ )
+ return True
+
+ response_size_mb = bytes_to_mb(bytes_value=sys.getsizeof(response))
+ verbose_proxy_logger.debug(f"response size in MB={response_size_mb}")
+ if response_size_mb > max_response_size_mb:
+ raise ProxyException(
+ message=f"Response size is too large. Response size is {response_size_mb} MB. Max size is {max_response_size_mb} MB",
+ type=ProxyErrorTypes.bad_request_error.value,
+ code=400,
+ param="content-length",
+ )
+
+ return True
+
+
+def bytes_to_mb(bytes_value: int):
+ """
+ Helper to convert bytes to MB
+ """
+ return bytes_value / (1024 * 1024)
+
+
+# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
+def get_key_model_rpm_limit(
+ user_api_key_dict: UserAPIKeyAuth,
+) -> Optional[Dict[str, int]]:
+ if user_api_key_dict.metadata:
+ if "model_rpm_limit" in user_api_key_dict.metadata:
+ return user_api_key_dict.metadata["model_rpm_limit"]
+ elif user_api_key_dict.model_max_budget:
+ model_rpm_limit: Dict[str, Any] = {}
+ for model, budget in user_api_key_dict.model_max_budget.items():
+ if "rpm_limit" in budget and budget["rpm_limit"] is not None:
+ model_rpm_limit[model] = budget["rpm_limit"]
+ return model_rpm_limit
+
+ return None
+
+
+def get_key_model_tpm_limit(
+ user_api_key_dict: UserAPIKeyAuth,
+) -> Optional[Dict[str, int]]:
+ if user_api_key_dict.metadata:
+ if "model_tpm_limit" in user_api_key_dict.metadata:
+ return user_api_key_dict.metadata["model_tpm_limit"]
+ elif user_api_key_dict.model_max_budget:
+ if "tpm_limit" in user_api_key_dict.model_max_budget:
+ return user_api_key_dict.model_max_budget["tpm_limit"]
+
+ return None
+
+
+def is_pass_through_provider_route(route: str) -> bool:
+ PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES = [
+ "vertex-ai",
+ ]
+
+ # check if any of the prefixes are in the route
+ for prefix in PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES:
+ if prefix in route:
+ return True
+
+ return False
+
+
+def should_run_auth_on_pass_through_provider_route(route: str) -> bool:
+ """
+ Use this to decide if the rest of the LiteLLM Virtual Key auth checks should run on /vertex-ai/{endpoint} routes
+ Use this to decide if the rest of the LiteLLM Virtual Key auth checks should run on provider pass through routes
+ ex /vertex-ai/{endpoint} routes
+ Run virtual key auth if the following is try:
+ - User is premium_user
+ - User has enabled litellm_setting.use_client_credentials_pass_through_routes
+ """
+ from litellm.proxy.proxy_server import general_settings, premium_user
+
+ if premium_user is not True:
+
+ return False
+
+ # premium use has opted into using client credentials
+ if (
+ general_settings.get("use_client_credentials_pass_through_routes", False)
+ is True
+ ):
+ return False
+
+ # only enabled for LiteLLM Enterprise
+ return True
+
+
+def _has_user_setup_sso():
+ """
+ Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID or generic client ID and UI username environment variables.
+ Returns a boolean indicating whether SSO has been set up.
+ """
+ microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
+ google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
+ generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
+
+ sso_setup = (
+ (microsoft_client_id is not None)
+ or (google_client_id is not None)
+ or (generic_client_id is not None)
+ )
+
+ return sso_setup
+
+
+def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
+ # openai - check 'user'
+ if "user" in request_body and request_body["user"] is not None:
+ return str(request_body["user"])
+ # anthropic - check 'litellm_metadata'
+ end_user_id = request_body.get("litellm_metadata", {}).get("user", None)
+ if end_user_id:
+ return str(end_user_id)
+ metadata = request_body.get("metadata")
+ if metadata and "user_id" in metadata and metadata["user_id"] is not None:
+ return str(metadata["user_id"])
+ return None