diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/litellm_pre_call_utils.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/litellm_pre_call_utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/litellm_pre_call_utils.py | 901 |
1 files changed, 901 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/litellm_pre_call_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/litellm_pre_call_utils.py new file mode 100644 index 00000000..ece5ecf4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/litellm_pre_call_utils.py @@ -0,0 +1,901 @@ +import asyncio +import copy +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from fastapi import Request +from starlette.datastructures import Headers + +import litellm +from litellm._logging import verbose_logger, verbose_proxy_logger +from litellm._service_logger import ServiceLogging +from litellm.proxy._types import ( + AddTeamCallback, + CommonProxyErrors, + LitellmDataForBackendLLMCall, + SpecialHeaders, + TeamCallbackMetadata, + UserAPIKeyAuth, +) +from litellm.proxy.auth.route_checks import RouteChecks +from litellm.router import Router +from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS +from litellm.types.services import ServiceTypes +from litellm.types.utils import ( + ProviderSpecificHeader, + StandardLoggingUserAPIKeyMetadata, + SupportedCacheControls, +) + +service_logger_obj = ServiceLogging() # used for tracking latency on OTEL + + +if TYPE_CHECKING: + from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig + + ProxyConfig = _ProxyConfig +else: + ProxyConfig = Any + + +def parse_cache_control(cache_control): + cache_dict = {} + directives = cache_control.split(", ") + + for directive in directives: + if "=" in directive: + key, value = directive.split("=") + cache_dict[key] = value + else: + cache_dict[directive] = True + + return cache_dict + + +def _get_metadata_variable_name(request: Request) -> str: + """ + Helper to return what the "metadata" field should be called in the request data + + For all /thread or /assistant endpoints we need to call this "litellm_metadata" + + For ALL other endpoints we call this "metadata + """ + if RouteChecks._is_assistants_api_request(request): + return "litellm_metadata" + + LITELLM_METADATA_ROUTES = [ + "batches", + "/v1/messages", + "responses", + ] + if any( + [ + litellm_metadata_route in request.url.path + for litellm_metadata_route in LITELLM_METADATA_ROUTES + ] + ): + return "litellm_metadata" + else: + return "metadata" + + +def safe_add_api_version_from_query_params(data: dict, request: Request): + try: + if hasattr(request, "query_params"): + query_params = dict(request.query_params) + if "api-version" in query_params: + data["api_version"] = query_params["api-version"] + except KeyError: + pass + except Exception as e: + verbose_logger.exception( + "error checking api version in query params: %s", str(e) + ) + + +def convert_key_logging_metadata_to_callback( + data: AddTeamCallback, team_callback_settings_obj: Optional[TeamCallbackMetadata] +) -> TeamCallbackMetadata: + if team_callback_settings_obj is None: + team_callback_settings_obj = TeamCallbackMetadata() + if data.callback_type == "success": + if team_callback_settings_obj.success_callback is None: + team_callback_settings_obj.success_callback = [] + + if data.callback_name not in team_callback_settings_obj.success_callback: + team_callback_settings_obj.success_callback.append(data.callback_name) + elif data.callback_type == "failure": + if team_callback_settings_obj.failure_callback is None: + team_callback_settings_obj.failure_callback = [] + + if data.callback_name not in team_callback_settings_obj.failure_callback: + team_callback_settings_obj.failure_callback.append(data.callback_name) + elif ( + not data.callback_type or data.callback_type == "success_and_failure" + ): # assume 'success_and_failure' = litellm.callbacks + if team_callback_settings_obj.success_callback is None: + team_callback_settings_obj.success_callback = [] + if team_callback_settings_obj.failure_callback is None: + team_callback_settings_obj.failure_callback = [] + if team_callback_settings_obj.callbacks is None: + team_callback_settings_obj.callbacks = [] + + if data.callback_name not in team_callback_settings_obj.success_callback: + team_callback_settings_obj.success_callback.append(data.callback_name) + + if data.callback_name not in team_callback_settings_obj.failure_callback: + team_callback_settings_obj.failure_callback.append(data.callback_name) + + if data.callback_name not in team_callback_settings_obj.callbacks: + team_callback_settings_obj.callbacks.append(data.callback_name) + + for var, value in data.callback_vars.items(): + if team_callback_settings_obj.callback_vars is None: + team_callback_settings_obj.callback_vars = {} + team_callback_settings_obj.callback_vars[var] = str( + litellm.utils.get_secret(value, default_value=value) or value + ) + + return team_callback_settings_obj + + +def _get_dynamic_logging_metadata( + user_api_key_dict: UserAPIKeyAuth, proxy_config: ProxyConfig +) -> Optional[TeamCallbackMetadata]: + callback_settings_obj: Optional[TeamCallbackMetadata] = None + if ( + user_api_key_dict.metadata is not None + and "logging" in user_api_key_dict.metadata + ): + for item in user_api_key_dict.metadata["logging"]: + callback_settings_obj = convert_key_logging_metadata_to_callback( + data=AddTeamCallback(**item), + team_callback_settings_obj=callback_settings_obj, + ) + elif ( + user_api_key_dict.team_metadata is not None + and "callback_settings" in user_api_key_dict.team_metadata + ): + """ + callback_settings = { + { + 'callback_vars': {'langfuse_public_key': 'pk', 'langfuse_secret_key': 'sk_'}, + 'failure_callback': [], + 'success_callback': ['langfuse', 'langfuse'] + } + } + """ + team_metadata = user_api_key_dict.team_metadata + callback_settings = team_metadata.get("callback_settings", None) or {} + callback_settings_obj = TeamCallbackMetadata(**callback_settings) + verbose_proxy_logger.debug( + "Team callback settings activated: %s", callback_settings_obj + ) + elif user_api_key_dict.team_id is not None: + callback_settings_obj = ( + LiteLLMProxyRequestSetup.add_team_based_callbacks_from_config( + team_id=user_api_key_dict.team_id, proxy_config=proxy_config + ) + ) + return callback_settings_obj + + +def clean_headers( + headers: Headers, litellm_key_header_name: Optional[str] = None +) -> dict: + """ + Removes litellm api key from headers + """ + special_headers = [v.value.lower() for v in SpecialHeaders._member_map_.values()] + special_headers = special_headers + if litellm_key_header_name is not None: + special_headers.append(litellm_key_header_name.lower()) + clean_headers = {} + for header, value in headers.items(): + if header.lower() not in special_headers: + clean_headers[header] = value + return clean_headers + + +class LiteLLMProxyRequestSetup: + @staticmethod + def _get_timeout_from_request(headers: dict) -> Optional[float]: + """ + Workaround for client request from Vercel's AI SDK. + + Allow's user to set a timeout in the request headers. + + Example: + + ```js + const openaiProvider = createOpenAI({ + baseURL: liteLLM.baseURL, + apiKey: liteLLM.apiKey, + compatibility: "compatible", + headers: { + "x-litellm-timeout": "90" + }, + }); + ``` + """ + timeout_header = headers.get("x-litellm-timeout", None) + if timeout_header is not None: + return float(timeout_header) + return None + + @staticmethod + def _get_forwardable_headers( + headers: Union[Headers, dict], + ): + """ + Get the headers that should be forwarded to the LLM Provider. + + Looks for any `x-` headers and sends them to the LLM Provider. + """ + forwarded_headers = {} + for header, value in headers.items(): + if header.lower().startswith("x-") and not header.lower().startswith( + "x-stainless" + ): # causes openai sdk to fail + forwarded_headers[header] = value + + return forwarded_headers + + @staticmethod + def get_openai_org_id_from_headers( + headers: dict, general_settings: Optional[Dict] = None + ) -> Optional[str]: + """ + Get the OpenAI Org ID from the headers. + """ + if ( + general_settings is not None + and general_settings.get("forward_openai_org_id") is not True + ): + return None + for header, value in headers.items(): + if header.lower() == "openai-organization": + verbose_logger.info(f"found openai org id: {value}, sending to llm") + return value + return None + + @staticmethod + def add_headers_to_llm_call( + headers: dict, user_api_key_dict: UserAPIKeyAuth + ) -> dict: + """ + Add headers to the LLM call + + - Checks request headers for forwardable headers + - Checks if user information should be added to the headers + """ + + returned_headers = LiteLLMProxyRequestSetup._get_forwardable_headers(headers) + + if litellm.add_user_information_to_llm_headers is True: + litellm_logging_metadata_headers = ( + LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key( + user_api_key_dict=user_api_key_dict + ) + ) + for k, v in litellm_logging_metadata_headers.items(): + if v is not None: + returned_headers["x-litellm-{}".format(k)] = v + + return returned_headers + + @staticmethod + def add_litellm_data_for_backend_llm_call( + *, + headers: dict, + user_api_key_dict: UserAPIKeyAuth, + general_settings: Optional[Dict[str, Any]] = None, + ) -> LitellmDataForBackendLLMCall: + """ + - Adds forwardable headers + - Adds org id + """ + data = LitellmDataForBackendLLMCall() + if ( + general_settings + and general_settings.get("forward_client_headers_to_llm_api") is True + ): + _headers = LiteLLMProxyRequestSetup.add_headers_to_llm_call( + headers, user_api_key_dict + ) + if _headers != {}: + data["headers"] = _headers + _organization = LiteLLMProxyRequestSetup.get_openai_org_id_from_headers( + headers, general_settings + ) + if _organization is not None: + data["organization"] = _organization + + timeout = LiteLLMProxyRequestSetup._get_timeout_from_request(headers) + if timeout is not None: + data["timeout"] = timeout + + return data + + @staticmethod + def get_sanitized_user_information_from_key( + user_api_key_dict: UserAPIKeyAuth, + ) -> StandardLoggingUserAPIKeyMetadata: + user_api_key_logged_metadata = StandardLoggingUserAPIKeyMetadata( + user_api_key_hash=user_api_key_dict.api_key, # just the hashed token + user_api_key_alias=user_api_key_dict.key_alias, + user_api_key_team_id=user_api_key_dict.team_id, + user_api_key_user_id=user_api_key_dict.user_id, + user_api_key_org_id=user_api_key_dict.org_id, + user_api_key_team_alias=user_api_key_dict.team_alias, + user_api_key_end_user_id=user_api_key_dict.end_user_id, + user_api_key_user_email=user_api_key_dict.user_email, + ) + return user_api_key_logged_metadata + + @staticmethod + def add_key_level_controls( + key_metadata: dict, data: dict, _metadata_variable_name: str + ): + if "cache" in key_metadata: + data["cache"] = {} + if isinstance(key_metadata["cache"], dict): + for k, v in key_metadata["cache"].items(): + if k in SupportedCacheControls: + data["cache"][k] = v + + ## KEY-LEVEL SPEND LOGS / TAGS + if "tags" in key_metadata and key_metadata["tags"] is not None: + data[_metadata_variable_name]["tags"] = ( + LiteLLMProxyRequestSetup._merge_tags( + request_tags=data[_metadata_variable_name].get("tags"), + tags_to_add=key_metadata["tags"], + ) + ) + if "spend_logs_metadata" in key_metadata and isinstance( + key_metadata["spend_logs_metadata"], dict + ): + if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance( + data[_metadata_variable_name]["spend_logs_metadata"], dict + ): + for key, value in key_metadata["spend_logs_metadata"].items(): + if ( + key not in data[_metadata_variable_name]["spend_logs_metadata"] + ): # don't override k-v pair sent by request (user request) + data[_metadata_variable_name]["spend_logs_metadata"][ + key + ] = value + else: + data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[ + "spend_logs_metadata" + ] + + ## KEY-LEVEL DISABLE FALLBACKS + if "disable_fallbacks" in key_metadata and isinstance( + key_metadata["disable_fallbacks"], bool + ): + data["disable_fallbacks"] = key_metadata["disable_fallbacks"] + return data + + @staticmethod + def _merge_tags(request_tags: Optional[list], tags_to_add: Optional[list]) -> list: + """ + Helper function to merge two lists of tags, ensuring no duplicates. + + Args: + request_tags (Optional[list]): List of tags from the original request + tags_to_add (Optional[list]): List of tags to add + + Returns: + list: Combined list of unique tags + """ + final_tags = [] + + if request_tags and isinstance(request_tags, list): + final_tags.extend(request_tags) + + if tags_to_add and isinstance(tags_to_add, list): + for tag in tags_to_add: + if tag not in final_tags: + final_tags.append(tag) + + return final_tags + + @staticmethod + def add_team_based_callbacks_from_config( + team_id: str, + proxy_config: ProxyConfig, + ) -> Optional[TeamCallbackMetadata]: + """ + Add team-based callbacks from the config + """ + team_config = proxy_config.load_team_config(team_id=team_id) + if len(team_config.keys()) == 0: + return None + + callback_vars_dict = {**team_config.get("callback_vars", team_config)} + callback_vars_dict.pop("team_id", None) + callback_vars_dict.pop("success_callback", None) + callback_vars_dict.pop("failure_callback", None) + + return TeamCallbackMetadata( + success_callback=team_config.get("success_callback", None), + failure_callback=team_config.get("failure_callback", None), + callback_vars=callback_vars_dict, + ) + + @staticmethod + def add_request_tag_to_metadata( + llm_router: Optional[Router], + headers: dict, + data: dict, + ) -> Optional[List[str]]: + tags = None + + if llm_router and llm_router.enable_tag_filtering is True: + # Check request headers for tags + if "x-litellm-tags" in headers: + if isinstance(headers["x-litellm-tags"], str): + _tags = headers["x-litellm-tags"].split(",") + tags = [tag.strip() for tag in _tags] + elif isinstance(headers["x-litellm-tags"], list): + tags = headers["x-litellm-tags"] + # Check request body for tags + if "tags" in data and isinstance(data["tags"], list): + tags = data["tags"] + + return tags + + +async def add_litellm_data_to_request( # noqa: PLR0915 + data: dict, + request: Request, + user_api_key_dict: UserAPIKeyAuth, + proxy_config: ProxyConfig, + general_settings: Optional[Dict[str, Any]] = None, + version: Optional[str] = None, +): + """ + Adds LiteLLM-specific data to the request. + + Args: + data (dict): The data dictionary to be modified. + request (Request): The incoming request. + user_api_key_dict (UserAPIKeyAuth): The user API key dictionary. + general_settings (Optional[Dict[str, Any]], optional): General settings. Defaults to None. + version (Optional[str], optional): Version. Defaults to None. + + Returns: + dict: The modified data dictionary. + + """ + + from litellm.proxy.proxy_server import llm_router, premium_user + + safe_add_api_version_from_query_params(data, request) + + _headers = clean_headers( + request.headers, + litellm_key_header_name=( + general_settings.get("litellm_key_header_name") + if general_settings is not None + else None + ), + ) + + data.update( + LiteLLMProxyRequestSetup.add_litellm_data_for_backend_llm_call( + headers=_headers, + user_api_key_dict=user_api_key_dict, + general_settings=general_settings, + ) + ) + + # Include original request and headers in the data + data["proxy_server_request"] = { + "url": str(request.url), + "method": request.method, + "headers": _headers, + "body": copy.copy(data), # use copy instead of deepcopy + } + + ## Dynamic api version (Azure OpenAI endpoints) ## + try: + query_params = request.query_params + # Convert query parameters to a dictionary (optional) + query_dict = dict(query_params) + except KeyError: + query_dict = {} + + ## check for api version in query params + dynamic_api_version: Optional[str] = query_dict.get("api-version") + + if dynamic_api_version is not None: # only pass, if set + data["api_version"] = dynamic_api_version + + ## Forward any LLM API Provider specific headers in extra_headers + add_provider_specific_headers_to_request(data=data, headers=_headers) + + ## Cache Controls + headers = request.headers + verbose_proxy_logger.debug("Request Headers: %s", headers) + cache_control_header = headers.get("Cache-Control", None) + if cache_control_header: + cache_dict = parse_cache_control(cache_control_header) + data["ttl"] = cache_dict.get("s-maxage") + + verbose_proxy_logger.debug("receiving data: %s", data) + + _metadata_variable_name = _get_metadata_variable_name(request) + + if _metadata_variable_name not in data: + data[_metadata_variable_name] = {} + + # We want to log the "metadata" from the client side request. Avoid circular reference by not directly assigning metadata to itself. + if "metadata" in data and data["metadata"] is not None: + data[_metadata_variable_name]["requester_metadata"] = copy.deepcopy( + data["metadata"] + ) + + user_api_key_logged_metadata = ( + LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key( + user_api_key_dict=user_api_key_dict + ) + ) + data[_metadata_variable_name].update(user_api_key_logged_metadata) + data[_metadata_variable_name][ + "user_api_key" + ] = ( + user_api_key_dict.api_key + ) # this is just the hashed token. [TODO]: replace variable name in repo. + + data[_metadata_variable_name]["user_api_end_user_max_budget"] = getattr( + user_api_key_dict, "end_user_max_budget", None + ) + + data[_metadata_variable_name]["litellm_api_version"] = version + + if general_settings is not None: + data[_metadata_variable_name]["global_max_parallel_requests"] = ( + general_settings.get("global_max_parallel_requests", None) + ) + + ### KEY-LEVEL Controls + key_metadata = user_api_key_dict.metadata + data = LiteLLMProxyRequestSetup.add_key_level_controls( + key_metadata=key_metadata, + data=data, + _metadata_variable_name=_metadata_variable_name, + ) + ## TEAM-LEVEL SPEND LOGS/TAGS + team_metadata = user_api_key_dict.team_metadata or {} + if "tags" in team_metadata and team_metadata["tags"] is not None: + data[_metadata_variable_name]["tags"] = LiteLLMProxyRequestSetup._merge_tags( + request_tags=data[_metadata_variable_name].get("tags"), + tags_to_add=team_metadata["tags"], + ) + if "spend_logs_metadata" in team_metadata and isinstance( + team_metadata["spend_logs_metadata"], dict + ): + if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance( + data[_metadata_variable_name]["spend_logs_metadata"], dict + ): + for key, value in team_metadata["spend_logs_metadata"].items(): + if ( + key not in data[_metadata_variable_name]["spend_logs_metadata"] + ): # don't override k-v pair sent by request (user request) + data[_metadata_variable_name]["spend_logs_metadata"][key] = value + else: + data[_metadata_variable_name]["spend_logs_metadata"] = team_metadata[ + "spend_logs_metadata" + ] + + # Team spend, budget - used by prometheus.py + data[_metadata_variable_name][ + "user_api_key_team_max_budget" + ] = user_api_key_dict.team_max_budget + data[_metadata_variable_name][ + "user_api_key_team_spend" + ] = user_api_key_dict.team_spend + + # API Key spend, budget - used by prometheus.py + data[_metadata_variable_name]["user_api_key_spend"] = user_api_key_dict.spend + data[_metadata_variable_name][ + "user_api_key_max_budget" + ] = user_api_key_dict.max_budget + data[_metadata_variable_name][ + "user_api_key_model_max_budget" + ] = user_api_key_dict.model_max_budget + + data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data[_metadata_variable_name]["headers"] = _headers + data[_metadata_variable_name]["endpoint"] = str(request.url) + + # OTEL Controls / Tracing + # Add the OTEL Parent Trace before sending it LiteLLM + data[_metadata_variable_name][ + "litellm_parent_otel_span" + ] = user_api_key_dict.parent_otel_span + _add_otel_traceparent_to_data(data, request=request) + + ### END-USER SPECIFIC PARAMS ### + if user_api_key_dict.allowed_model_region is not None: + data["allowed_model_region"] = user_api_key_dict.allowed_model_region + start_time = time.time() + ## [Enterprise Only] + # Add User-IP Address + requester_ip_address = "" + if premium_user is True: + # Only set the IP Address for Enterprise Users + + # logic for tracking IP Address + if ( + general_settings is not None + and general_settings.get("use_x_forwarded_for") is True + and request is not None + and hasattr(request, "headers") + and "x-forwarded-for" in request.headers + ): + requester_ip_address = request.headers["x-forwarded-for"] + elif ( + request is not None + and hasattr(request, "client") + and hasattr(request.client, "host") + and request.client is not None + ): + requester_ip_address = request.client.host + data[_metadata_variable_name]["requester_ip_address"] = requester_ip_address + + # Check if using tag based routing + tags = LiteLLMProxyRequestSetup.add_request_tag_to_metadata( + llm_router=llm_router, + headers=dict(request.headers), + data=data, + ) + + if tags is not None: + data[_metadata_variable_name]["tags"] = tags + + # Team Callbacks controls + callback_settings_obj = _get_dynamic_logging_metadata( + user_api_key_dict=user_api_key_dict, proxy_config=proxy_config + ) + if callback_settings_obj is not None: + data["success_callback"] = callback_settings_obj.success_callback + data["failure_callback"] = callback_settings_obj.failure_callback + + if callback_settings_obj.callback_vars is not None: + # unpack callback_vars in data + for k, v in callback_settings_obj.callback_vars.items(): + data[k] = v + + # Guardrails + move_guardrails_to_metadata( + data=data, + _metadata_variable_name=_metadata_variable_name, + user_api_key_dict=user_api_key_dict, + ) + + # Team Model Aliases + _update_model_if_team_alias_exists( + data=data, + user_api_key_dict=user_api_key_dict, + ) + + verbose_proxy_logger.debug( + "[PROXY] returned data from litellm_pre_call_utils: %s", data + ) + + ## ENFORCED PARAMS CHECK + # loop through each enforced param + # example enforced_params ['user', 'metadata', 'metadata.generation_name'] + _enforced_params_check( + request_body=data, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + premium_user=premium_user, + ) + + end_time = time.time() + asyncio.create_task( + service_logger_obj.async_service_success_hook( + service=ServiceTypes.PROXY_PRE_CALL, + duration=end_time - start_time, + call_type="add_litellm_data_to_request", + start_time=start_time, + end_time=end_time, + parent_otel_span=user_api_key_dict.parent_otel_span, + ) + ) + + return data + + +def _update_model_if_team_alias_exists( + data: dict, + user_api_key_dict: UserAPIKeyAuth, +) -> None: + """ + Update the model if the team alias exists + + If a alias map has been set on a team, then we want to make the request with the model the team alias is pointing to + + eg. + - user calls `gpt-4o` + - team.model_alias_map = { + "gpt-4o": "gpt-4o-team-1" + } + - requested_model = "gpt-4o-team-1" + """ + _model = data.get("model") + if ( + _model + and user_api_key_dict.team_model_aliases + and _model in user_api_key_dict.team_model_aliases + ): + data["model"] = user_api_key_dict.team_model_aliases[_model] + return + + +def _get_enforced_params( + general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth +) -> Optional[list]: + enforced_params: Optional[list] = None + if general_settings is not None: + enforced_params = general_settings.get("enforced_params") + if "service_account_settings" in general_settings: + service_account_settings = general_settings["service_account_settings"] + if "enforced_params" in service_account_settings: + if enforced_params is None: + enforced_params = [] + enforced_params.extend(service_account_settings["enforced_params"]) + if user_api_key_dict.metadata.get("enforced_params", None) is not None: + if enforced_params is None: + enforced_params = [] + enforced_params.extend(user_api_key_dict.metadata["enforced_params"]) + return enforced_params + + +def _enforced_params_check( + request_body: dict, + general_settings: Optional[dict], + user_api_key_dict: UserAPIKeyAuth, + premium_user: bool, +) -> bool: + """ + If enforced params are set, check if the request body contains the enforced params. + """ + enforced_params: Optional[list] = _get_enforced_params( + general_settings=general_settings, user_api_key_dict=user_api_key_dict + ) + if enforced_params is None: + return True + if enforced_params is not None and premium_user is not True: + raise ValueError( + f"Enforced Params is an Enterprise feature. Enforced Params: {enforced_params}. {CommonProxyErrors.not_premium_user.value}" + ) + + for enforced_param in enforced_params: + _enforced_params = enforced_param.split(".") + if len(_enforced_params) == 1: + if _enforced_params[0] not in request_body: + raise ValueError( + f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param" + ) + elif len(_enforced_params) == 2: + # this is a scenario where user requires request['metadata']['generation_name'] to exist + if _enforced_params[0] not in request_body: + raise ValueError( + f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param" + ) + if _enforced_params[1] not in request_body[_enforced_params[0]]: + raise ValueError( + f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param" + ) + return True + + +def _add_guardrails_from_key_or_team_metadata( + key_metadata: Optional[dict], + team_metadata: Optional[dict], + data: dict, + metadata_variable_name: str, +) -> None: + """ + Helper add guardrails from key or team metadata to request data + + Args: + key_metadata: The key metadata dictionary to check for guardrails + team_metadata: The team metadata dictionary to check for guardrails + data: The request data to update + metadata_variable_name: The name of the metadata field in data + + """ + from litellm.proxy.utils import _premium_user_check + + for _management_object_metadata in [key_metadata, team_metadata]: + if _management_object_metadata and "guardrails" in _management_object_metadata: + if len(_management_object_metadata["guardrails"]) > 0: + _premium_user_check() + + data[metadata_variable_name]["guardrails"] = _management_object_metadata[ + "guardrails" + ] + + +def move_guardrails_to_metadata( + data: dict, + _metadata_variable_name: str, + user_api_key_dict: UserAPIKeyAuth, +): + """ + Helper to add guardrails from request to metadata + + - If guardrails set on API Key metadata then sets guardrails on request metadata + - If guardrails not set on API key, then checks request metadata + """ + # Check key-level guardrails + _add_guardrails_from_key_or_team_metadata( + key_metadata=user_api_key_dict.metadata, + team_metadata=user_api_key_dict.team_metadata, + data=data, + metadata_variable_name=_metadata_variable_name, + ) + + # Check request-level guardrails + if "guardrails" in data: + data[_metadata_variable_name]["guardrails"] = data["guardrails"] + del data["guardrails"] + + if "guardrail_config" in data: + data[_metadata_variable_name]["guardrail_config"] = data["guardrail_config"] + del data["guardrail_config"] + + +def add_provider_specific_headers_to_request( + data: dict, + headers: dict, +): + anthropic_headers = {} + # boolean to indicate if a header was added + added_header = False + for header in ANTHROPIC_API_HEADERS: + if header in headers: + header_value = headers[header] + anthropic_headers[header] = header_value + added_header = True + + if added_header is True: + data["provider_specific_header"] = ProviderSpecificHeader( + custom_llm_provider="anthropic", + extra_headers=anthropic_headers, + ) + + return + + +def _add_otel_traceparent_to_data(data: dict, request: Request): + from litellm.proxy.proxy_server import open_telemetry_logger + + if data is None: + return + if open_telemetry_logger is None: + # if user is not use OTEL don't send extra_headers + # relevant issue: https://github.com/BerriAI/litellm/issues/4448 + return + + if litellm.forward_traceparent_to_llm_provider is True: + if request.headers: + if "traceparent" in request.headers: + # we want to forward this to the LLM Provider + # Relevant issue: https://github.com/BerriAI/litellm/issues/4419 + # pass this in extra_headers + if "extra_headers" not in data: + data["extra_headers"] = {} + _exra_headers = data["extra_headers"] + if "traceparent" not in _exra_headers: + _exra_headers["traceparent"] = request.headers["traceparent"] |