Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks')
12 files changed, 2647 insertions, 0 deletions
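Editor's note: every hook added in this diff follows the same shape — subclass litellm's CustomLogger and override the async hook methods the proxy calls around a request. The following is a minimal sketch of that shared pattern, not part of the diff; the class name _PROXY_ExampleHook and the model check are illustrative, while the imports and the async_pre_call_hook signature mirror the hooks shown below.

# Illustrative sketch (not part of the diff): the common shape shared by the hooks below.
from typing import Optional

from fastapi import HTTPException

from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_ExampleHook(CustomLogger):  # hypothetical hook name
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ) -> Optional[dict]:
        # Reject the request by raising; allow it by returning None
        # (or a modified `data` dict, as several hooks below do).
        if data.get("model") is None:
            raise HTTPException(status_code=400, detail={"error": "model is required"})
        return data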
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/__init__.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/__init__.py new file mode 100644 index 00000000..b6e690fd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/__init__.py @@ -0,0 +1 @@ +from . import * diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py new file mode 100644 index 00000000..b35d6711 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py @@ -0,0 +1,156 @@ +import traceback +from typing import Optional + +from fastapi import HTTPException + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth + + +class _PROXY_AzureContentSafety( + CustomLogger +): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class + # Class variables or attributes + + def __init__(self, endpoint, api_key, thresholds=None): + try: + from azure.ai.contentsafety.aio import ContentSafetyClient + from azure.ai.contentsafety.models import ( + AnalyzeTextOptions, + AnalyzeTextOutputType, + TextCategory, + ) + from azure.core.credentials import AzureKeyCredential + from azure.core.exceptions import HttpResponseError + except Exception as e: + raise Exception( + f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m" + ) + self.endpoint = endpoint + self.api_key = api_key + self.text_category = TextCategory + self.analyze_text_options = AnalyzeTextOptions + self.analyze_text_output_type = AnalyzeTextOutputType + self.azure_http_error = HttpResponseError + + self.thresholds = self._configure_thresholds(thresholds) + + self.client = ContentSafetyClient( + self.endpoint, AzureKeyCredential(self.api_key) + ) + + def _configure_thresholds(self, thresholds=None): + default_thresholds = { + self.text_category.HATE: 4, + self.text_category.SELF_HARM: 4, + self.text_category.SEXUAL: 4, + self.text_category.VIOLENCE: 4, + } + + if thresholds is None: + return default_thresholds + + for key, default in default_thresholds.items(): + if key not in thresholds: + thresholds[key] = default + + return thresholds + + def _compute_result(self, response): + result = {} + + category_severity = { + item.category: item.severity for item in response.categories_analysis + } + for category in self.text_category: + severity = category_severity.get(category) + if severity is not None: + result[category] = { + "filtered": severity >= self.thresholds[category], + "severity": severity, + } + + return result + + async def test_violation(self, content: str, source: Optional[str] = None): + verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content) + + # Construct a request + request = self.analyze_text_options( + text=content, + output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS, + ) + + # Analyze text + try: + response = await self.client.analyze_text(request) + except self.azure_http_error: + verbose_proxy_logger.debug( + "Error in Azure Content-Safety: %s", traceback.format_exc() + ) + verbose_proxy_logger.debug(traceback.format_exc()) + raise + + result = self._compute_result(response) + verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result) + 
+ for key, value in result.items(): + if value["filtered"]: + raise HTTPException( + status_code=400, + detail={ + "error": "Violated content safety policy", + "source": source, + "category": key, + "severity": value["severity"], + }, + ) + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, # "completion", "embeddings", "image_generation", "moderation" + ): + verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook") + try: + if call_type == "completion" and "messages" in data: + for m in data["messages"]: + if "content" in m and isinstance(m["content"], str): + await self.test_violation(content=m["content"], source="input") + + except HTTPException as e: + raise e + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + + async def async_post_call_success_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response, + ): + verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook") + if isinstance(response, litellm.ModelResponse) and isinstance( + response.choices[0], litellm.utils.Choices + ): + await self.test_violation( + content=response.choices[0].message.content or "", source="output" + ) + + # async def async_post_call_streaming_hook( + # self, + # user_api_key_dict: UserAPIKeyAuth, + # response: str, + # ): + # verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook") + # await self.test_violation(content=response, source="output") diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/batch_redis_get.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/batch_redis_get.py new file mode 100644 index 00000000..c608317f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/batch_redis_get.py @@ -0,0 +1,149 @@ +# What this does? +## Gets a key's redis cache, and store it in memory for 1 minute. +## This reduces the number of REDIS GET requests made during high-traffic by the proxy. +### [BETA] this is in Beta. And might change. 
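Editor's note: the Azure Content-Safety hook above marks a category as "filtered" when its analyzed severity meets or exceeds the per-category threshold (default 4 on the eight-level scale). A standalone sketch of that decision follows; plain strings stand in for the azure.ai.contentsafety TextCategory values, and the example input is made up.

# Illustrative sketch of the threshold check in _PROXY_AzureContentSafety._compute_result.
DEFAULT_THRESHOLDS = {"Hate": 4, "SelfHarm": 4, "Sexual": 4, "Violence": 4}


def compute_result(category_severity: dict, thresholds: dict = DEFAULT_THRESHOLDS) -> dict:
    # A category is "filtered" when its severity meets or exceeds the threshold;
    # categories missing from the analysis response are skipped.
    result = {}
    for category, threshold in thresholds.items():
        severity = category_severity.get(category)
        if severity is not None:
            result[category] = {"filtered": severity >= threshold, "severity": severity}
    return result


print(compute_result({"Hate": 2, "Violence": 6}))
# {'Hate': {'filtered': False, 'severity': 2}, 'Violence': {'filtered': True, 'severity': 6}}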
+ +import traceback +from typing import Literal, Optional + +from fastapi import HTTPException + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache, InMemoryCache, RedisCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth + + +class _PROXY_BatchRedisRequests(CustomLogger): + # Class variables or attributes + in_memory_cache: Optional[InMemoryCache] = None + + def __init__(self): + if litellm.cache is not None: + litellm.cache.async_get_cache = ( + self.async_get_cache + ) # map the litellm 'get_cache' function to our custom function + + def print_verbose( + self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG" + ): + if debug_level == "DEBUG": + verbose_proxy_logger.debug(print_statement) + elif debug_level == "INFO": + verbose_proxy_logger.debug(print_statement) + if litellm.set_verbose is True: + print(print_statement) # noqa + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, + ): + try: + """ + Get the user key + + Check if a key starting with `litellm:<api_key>:<call_type:` exists in-memory + + If no, then get relevant cache from redis + """ + api_key = user_api_key_dict.api_key + + cache_key_name = f"litellm:{api_key}:{call_type}" + self.in_memory_cache = cache.in_memory_cache + + key_value_dict = {} + in_memory_cache_exists = False + for key in cache.in_memory_cache.cache_dict.keys(): + if isinstance(key, str) and key.startswith(cache_key_name): + in_memory_cache_exists = True + + if in_memory_cache_exists is False and litellm.cache is not None: + """ + - Check if `litellm.Cache` is redis + - Get the relevant values + """ + if litellm.cache.type is not None and isinstance( + litellm.cache.cache, RedisCache + ): + # Initialize an empty list to store the keys + keys = [] + self.print_verbose(f"cache_key_name: {cache_key_name}") + # Use the SCAN iterator to fetch keys matching the pattern + keys = await litellm.cache.cache.async_scan_iter( + pattern=cache_key_name, count=100 + ) + # If you need the truly "last" based on time or another criteria, + # ensure your key naming or storage strategy allows this determination + # Here you would sort or filter the keys as needed based on your strategy + self.print_verbose(f"redis keys: {keys}") + if len(keys) > 0: + key_value_dict = ( + await litellm.cache.cache.async_batch_get_cache( + key_list=keys + ) + ) + + ## Add to cache + if len(key_value_dict.items()) > 0: + await cache.in_memory_cache.async_set_cache_pipeline( + cache_list=list(key_value_dict.items()), ttl=60 + ) + ## Set cache namespace if it's a miss + data["metadata"]["redis_namespace"] = cache_key_name + except HTTPException as e: + raise e + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + + async def async_get_cache(self, *args, **kwargs): + """ + - Check if the cache key is in-memory + + - Else: + - add missing cache key from REDIS + - update in-memory cache + - return redis cache request + """ + try: # never block execution + cache_key: Optional[str] = None + if "cache_key" in kwargs: + cache_key = kwargs["cache_key"] + elif litellm.cache is not None: + cache_key = litellm.cache.get_cache_key( + *args, **kwargs + ) # returns "<cache_key_name>:<hash>" - we pass redis_namespace in 
async_pre_call_hook. Done to avoid rewriting the async_set_cache logic + + if ( + cache_key is not None + and self.in_memory_cache is not None + and litellm.cache is not None + ): + cache_control_args = kwargs.get("cache", {}) + max_age = cache_control_args.get( + "s-max-age", cache_control_args.get("s-maxage", float("inf")) + ) + cached_result = self.in_memory_cache.get_cache( + cache_key, *args, **kwargs + ) + if cached_result is None: + cached_result = await litellm.cache.cache.async_get_cache( + cache_key, *args, **kwargs + ) + if cached_result is not None: + await self.in_memory_cache.async_set_cache( + cache_key, cached_result, ttl=60 + ) + return litellm.cache._get_cache_logic( + cached_result=cached_result, max_age=max_age + ) + except Exception: + return None diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/cache_control_check.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/cache_control_check.py new file mode 100644 index 00000000..6e3fbf84 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/cache_control_check.py @@ -0,0 +1,58 @@ +# What this does? +## Checks if key is allowed to use the cache controls passed in to the completion() call + + +from fastapi import HTTPException + +from litellm import verbose_logger +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth + + +class _PROXY_CacheControlCheck(CustomLogger): + # Class variables or attributes + def __init__(self): + pass + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, + ): + try: + verbose_proxy_logger.debug("Inside Cache Control Check Pre-Call Hook") + allowed_cache_controls = user_api_key_dict.allowed_cache_controls + + if data.get("cache", None) is None: + return + + cache_args = data.get("cache", None) + if isinstance(cache_args, dict): + for k, v in cache_args.items(): + if ( + (allowed_cache_controls is not None) + and (isinstance(allowed_cache_controls, list)) + and ( + len(allowed_cache_controls) > 0 + ) # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663 + and k not in allowed_cache_controls + ): + raise HTTPException( + status_code=403, + detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.", + ) + else: # invalid cache + return + + except HTTPException as e: + raise e + except Exception as e: + verbose_logger.exception( + "litellm.proxy.hooks.cache_control_check.py::async_pre_call_hook(): Exception occured - {}".format( + str(e) + ) + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/dynamic_rate_limiter.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/dynamic_rate_limiter.py new file mode 100644 index 00000000..15a9bc1b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/dynamic_rate_limiter.py @@ -0,0 +1,298 @@ +# What is this? 
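Editor's note: the cache_control_check hook above only blocks a request when the key has a non-empty allowed_cache_controls list and the request passes a cache control outside it; an empty or missing list imposes no restriction. A standalone sketch of that check follows, not part of the diff; it raises PermissionError where the real hook raises a 403 HTTPException.

# Illustrative sketch of the permission check in _PROXY_CacheControlCheck.
from typing import Optional


def check_cache_controls(cache_args: dict, allowed_cache_controls: Optional[list]) -> None:
    if not allowed_cache_controls:  # None or empty list -> no restriction
        return
    for k in cache_args:
        if k not in allowed_cache_controls:
            raise PermissionError(
                f"Not allowed to set {k} as a cache control. Contact admin to change permissions."
            )


check_cache_controls({"ttl": 60}, ["ttl", "s-maxage"])   # allowed
# check_cache_controls({"no-cache": True}, ["ttl"])      # would raise PermissionError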
+## Allocates dynamic tpm/rpm quota for a project based on current traffic +## Tracks num active projects per minute + +import asyncio +import os +from typing import List, Literal, Optional, Tuple, Union + +from fastapi import HTTPException + +import litellm +from litellm import ModelResponse, Router +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth +from litellm.types.router import ModelGroupInfo +from litellm.utils import get_utc_datetime + + +class DynamicRateLimiterCache: + """ + Thin wrapper on DualCache for this file. + + Track number of active projects calling a model. + """ + + def __init__(self, cache: DualCache) -> None: + self.cache = cache + self.ttl = 60 # 1 min ttl + + async def async_get_cache(self, model: str) -> Optional[int]: + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + key_name = "{}:{}".format(current_minute, model) + _response = await self.cache.async_get_cache(key=key_name) + response: Optional[int] = None + if _response is not None: + response = len(_response) + return response + + async def async_set_cache_sadd(self, model: str, value: List): + """ + Add value to set. + + Parameters: + - model: str, the name of the model group + - value: str, the team id + + Returns: + - None + + Raises: + - Exception, if unable to connect to cache client (if redis caching enabled) + """ + try: + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + + key_name = "{}:{}".format(current_minute, model) + await self.cache.async_set_cache_sadd( + key=key_name, value=value, ttl=self.ttl + ) + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occured - {}".format( + str(e) + ) + ) + raise e + + +class _PROXY_DynamicRateLimitHandler(CustomLogger): + + # Class variables or attributes + def __init__(self, internal_usage_cache: DualCache): + self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache) + + def update_variables(self, llm_router: Router): + self.llm_router = llm_router + + async def check_available_usage( + self, model: str, priority: Optional[str] = None + ) -> Tuple[ + Optional[int], Optional[int], Optional[int], Optional[int], Optional[int] + ]: + """ + For a given model, get its available tpm + + Params: + - model: str, the name of the model in the router model_list + - priority: Optional[str], the priority for the request. + + Returns + - Tuple[available_tpm, available_tpm, model_tpm, model_rpm, active_projects] + - available_tpm: int or null - always 0 or positive. + - available_tpm: int or null - always 0 or positive. + - remaining_model_tpm: int or null. If available tpm is int, then this will be too. + - remaining_model_rpm: int or null. If available rpm is int, then this will be too. + - active_projects: int or null + """ + try: + weight: float = 1 + if ( + litellm.priority_reservation is None + or priority not in litellm.priority_reservation + ): + verbose_proxy_logger.error( + "Priority Reservation not set. priority={}, but litellm.priority_reservation is {}.".format( + priority, litellm.priority_reservation + ) + ) + elif priority is not None and litellm.priority_reservation is not None: + if os.getenv("LITELLM_LICENSE", None) is None: + verbose_proxy_logger.error( + "PREMIUM FEATURE: Reserving tpm/rpm by priority is a premium feature. 
Please add a 'LITELLM_LICENSE' to your .env to enable this.\nGet a license: https://docs.litellm.ai/docs/proxy/enterprise." + ) + else: + weight = litellm.priority_reservation[priority] + + active_projects = await self.internal_usage_cache.async_get_cache( + model=model + ) + current_model_tpm, current_model_rpm = ( + await self.llm_router.get_model_group_usage(model_group=model) + ) + model_group_info: Optional[ModelGroupInfo] = ( + self.llm_router.get_model_group_info(model_group=model) + ) + total_model_tpm: Optional[int] = None + total_model_rpm: Optional[int] = None + if model_group_info is not None: + if model_group_info.tpm is not None: + total_model_tpm = model_group_info.tpm + if model_group_info.rpm is not None: + total_model_rpm = model_group_info.rpm + + remaining_model_tpm: Optional[int] = None + if total_model_tpm is not None and current_model_tpm is not None: + remaining_model_tpm = total_model_tpm - current_model_tpm + elif total_model_tpm is not None: + remaining_model_tpm = total_model_tpm + + remaining_model_rpm: Optional[int] = None + if total_model_rpm is not None and current_model_rpm is not None: + remaining_model_rpm = total_model_rpm - current_model_rpm + elif total_model_rpm is not None: + remaining_model_rpm = total_model_rpm + + available_tpm: Optional[int] = None + + if remaining_model_tpm is not None: + if active_projects is not None: + available_tpm = int(remaining_model_tpm * weight / active_projects) + else: + available_tpm = int(remaining_model_tpm * weight) + + if available_tpm is not None and available_tpm < 0: + available_tpm = 0 + + available_rpm: Optional[int] = None + + if remaining_model_rpm is not None: + if active_projects is not None: + available_rpm = int(remaining_model_rpm * weight / active_projects) + else: + available_rpm = int(remaining_model_rpm * weight) + + if available_rpm is not None and available_rpm < 0: + available_rpm = 0 + return ( + available_tpm, + available_rpm, + remaining_model_tpm, + remaining_model_rpm, + active_projects, + ) + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.hooks.dynamic_rate_limiter.py::check_available_usage: Exception occurred - {}".format( + str(e) + ) + ) + return None, None, None, None, None + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: Literal[ + "completion", + "text_completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + "pass_through_endpoint", + "rerank", + ], + ) -> Optional[ + Union[Exception, str, dict] + ]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm + """ + - For a model group + - Check if tpm/rpm available + - Raise RateLimitError if no tpm/rpm available + """ + if "model" in data: + key_priority: Optional[str] = user_api_key_dict.metadata.get( + "priority", None + ) + available_tpm, available_rpm, model_tpm, model_rpm, active_projects = ( + await self.check_available_usage( + model=data["model"], priority=key_priority + ) + ) + ### CHECK TPM ### + if available_tpm is not None and available_tpm == 0: + raise HTTPException( + status_code=429, + detail={ + "error": "Key={} over available TPM={}. 
Model TPM={}, Active keys={}".format( + user_api_key_dict.api_key, + available_tpm, + model_tpm, + active_projects, + ) + }, + ) + ### CHECK RPM ### + elif available_rpm is not None and available_rpm == 0: + raise HTTPException( + status_code=429, + detail={ + "error": "Key={} over available RPM={}. Model RPM={}, Active keys={}".format( + user_api_key_dict.api_key, + available_rpm, + model_rpm, + active_projects, + ) + }, + ) + elif available_rpm is not None or available_tpm is not None: + ## UPDATE CACHE WITH ACTIVE PROJECT + asyncio.create_task( + self.internal_usage_cache.async_set_cache_sadd( # this is a set + model=data["model"], # type: ignore + value=[user_api_key_dict.token or "default_key"], + ) + ) + return None + + async def async_post_call_success_hook( + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response + ): + try: + if isinstance(response, ModelResponse): + model_info = self.llm_router.get_model_info( + id=response._hidden_params["model_id"] + ) + assert ( + model_info is not None + ), "Model info for model with id={} is None".format( + response._hidden_params["model_id"] + ) + key_priority: Optional[str] = user_api_key_dict.metadata.get( + "priority", None + ) + available_tpm, available_rpm, model_tpm, model_rpm, active_projects = ( + await self.check_available_usage( + model=model_info["model_name"], priority=key_priority + ) + ) + response._hidden_params["additional_headers"] = ( + { # Add additional response headers - easier debugging + "x-litellm-model_group": model_info["model_name"], + "x-ratelimit-remaining-litellm-project-tokens": available_tpm, + "x-ratelimit-remaining-litellm-project-requests": available_rpm, + "x-ratelimit-remaining-model-tokens": model_tpm, + "x-ratelimit-remaining-model-requests": model_rpm, + "x-ratelimit-current-active-projects": active_projects, + } + ) + + return response + return await super().async_post_call_success_hook( + data=data, + user_api_key_dict=user_api_key_dict, + response=response, + ) + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occured - {}".format( + str(e) + ) + ) + return response diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/example_presidio_ad_hoc_recognizer.json b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/example_presidio_ad_hoc_recognizer.json new file mode 100644 index 00000000..6a94d8de --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/example_presidio_ad_hoc_recognizer.json @@ -0,0 +1,28 @@ +[ + { + "name": "Zip code Recognizer", + "supported_language": "en", + "patterns": [ + { + "name": "zip code (weak)", + "regex": "(\\b\\d{5}(?:\\-\\d{4})?\\b)", + "score": 0.01 + } + ], + "context": ["zip", "code"], + "supported_entity": "ZIP" + }, + { + "name": "Swiss AHV Number Recognizer", + "supported_language": "en", + "patterns": [ + { + "name": "AHV number (strong)", + "regex": "(756\\.\\d{4}\\.\\d{4}\\.\\d{2})|(756\\d{10})", + "score": 0.95 + } + ], + "context": ["AHV", "social security", "Swiss"], + "supported_entity": "AHV_NUMBER" + } +]
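Editor's note: check_available_usage in the dynamic rate limiter above scales the model group's remaining TPM/RPM by the priority weight, splits it across the projects active in the current minute, and clamps the result at zero. A standalone sketch of that arithmetic follows, not part of the diff; the numbers in the example are made up.

# Illustrative sketch of the quota arithmetic in check_available_usage.
from typing import Optional


def available_quota(
    remaining: Optional[int], weight: float = 1.0, active_projects: Optional[int] = None
) -> Optional[int]:
    if remaining is None:
        return None
    if active_projects:
        available = int(remaining * weight / active_projects)
    else:
        available = int(remaining * weight)
    return max(available, 0)  # never report negative capacity


# e.g. 10,000 TPM left on the model group, 4 active projects, 25% priority reservation
print(available_quota(10_000, weight=0.25, active_projects=4))  # 625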
\ No newline at end of file diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py new file mode 100644 index 00000000..2030cb2a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py @@ -0,0 +1,324 @@ +import asyncio +import json +import uuid +from datetime import datetime, timezone +from typing import Any, List, Optional + +from fastapi import status + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ( + GenerateKeyRequest, + GenerateKeyResponse, + KeyRequest, + LiteLLM_AuditLogs, + LiteLLM_VerificationToken, + LitellmTableNames, + ProxyErrorTypes, + ProxyException, + RegenerateKeyRequest, + UpdateKeyRequest, + UserAPIKeyAuth, + WebhookEvent, +) + +# NOTE: This is the prefix for all virtual keys stored in AWS Secrets Manager +LITELLM_PREFIX_STORED_VIRTUAL_KEYS = "litellm/" + + +class KeyManagementEventHooks: + + @staticmethod + async def async_key_generated_hook( + data: GenerateKeyRequest, + response: GenerateKeyResponse, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, + ): + """ + Hook that runs after a successful /key/generate request + + Handles the following: + - Sending Email with Key Details + - Storing Audit Logs for key generation + - Storing Generated Key in DB + """ + from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + ) + from litellm.proxy.proxy_server import litellm_proxy_admin_name + + if data.send_invite_email is True: + await KeyManagementEventHooks._send_key_created_email( + response.model_dump(exclude_none=True) + ) + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _updated_values = response.model_dump_json(exclude_none=True) + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=response.token_id or "", + action="created", + updated_values=_updated_values, + before_value=None, + ) + ) + ) + # store the generated key in the secret manager + await KeyManagementEventHooks._store_virtual_key_in_secret_manager( + secret_name=data.key_alias or f"virtual-key-{response.token_id}", + secret_token=response.key, + ) + + @staticmethod + async def async_key_updated_hook( + data: UpdateKeyRequest, + existing_key_row: Any, + response: Any, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, + ): + """ + Post /key/update processing hook + + Handles the following: + - Storing Audit Logs for key update + """ + from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + ) + from litellm.proxy.proxy_server import litellm_proxy_admin_name + + # Enterprise Feature - Audit Logging. 
Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _updated_values = json.dumps(data.json(exclude_none=True), default=str) + + _before_value = existing_key_row.json(exclude_none=True) + _before_value = json.dumps(_before_value, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=data.key, + action="updated", + updated_values=_updated_values, + before_value=_before_value, + ) + ) + ) + + @staticmethod + async def async_key_rotated_hook( + data: Optional[RegenerateKeyRequest], + existing_key_row: Any, + response: GenerateKeyResponse, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, + ): + # store the generated key in the secret manager + if data is not None and response.token_id is not None: + initial_secret_name = ( + existing_key_row.key_alias or f"virtual-key-{existing_key_row.token}" + ) + await KeyManagementEventHooks._rotate_virtual_key_in_secret_manager( + current_secret_name=initial_secret_name, + new_secret_name=data.key_alias or f"virtual-key-{response.token_id}", + new_secret_value=response.key, + ) + + @staticmethod + async def async_key_deleted_hook( + data: KeyRequest, + keys_being_deleted: List[LiteLLM_VerificationToken], + response: dict, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, + ): + """ + Post /key/delete processing hook + + Handles the following: + - Storing Audit Logs for key deletion + """ + from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + ) + from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + # we do this after the first for loop, since first for loop is for validation. 
we only want this inserted after validation passes + if litellm.store_audit_logs is True and data.keys is not None: + # make an audit log for each team deleted + for key in data.keys: + key_row = await prisma_client.get_data( # type: ignore + token=key, table_name="key", query_type="find_unique" + ) + + if key_row is None: + raise ProxyException( + message=f"Key {key} not found", + type=ProxyErrorTypes.bad_request_error, + param="key", + code=status.HTTP_404_NOT_FOUND, + ) + + key_row = key_row.json(exclude_none=True) + _key_row = json.dumps(key_row, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=key, + action="deleted", + updated_values="{}", + before_value=_key_row, + ) + ) + ) + # delete the keys from the secret manager + await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager( + keys_being_deleted=keys_being_deleted + ) + pass + + @staticmethod + async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str): + """ + Store a virtual key in the secret manager + + Args: + secret_name: Name of the virtual key + secret_token: Value of the virtual key (example: sk-1234) + """ + if litellm._key_management_settings is not None: + if litellm._key_management_settings.store_virtual_keys is True: + from litellm.secret_managers.base_secret_manager import ( + BaseSecretManager, + ) + + # store the key in the secret manager + if isinstance(litellm.secret_manager_client, BaseSecretManager): + await litellm.secret_manager_client.async_write_secret( + secret_name=KeyManagementEventHooks._get_secret_name( + secret_name + ), + secret_value=secret_token, + ) + + @staticmethod + async def _rotate_virtual_key_in_secret_manager( + current_secret_name: str, new_secret_name: str, new_secret_value: str + ): + """ + Update a virtual key in the secret manager + + Args: + secret_name: Name of the virtual key + secret_token: Value of the virtual key (example: sk-1234) + """ + if litellm._key_management_settings is not None: + if litellm._key_management_settings.store_virtual_keys is True: + from litellm.secret_managers.base_secret_manager import ( + BaseSecretManager, + ) + + # store the key in the secret manager + if isinstance(litellm.secret_manager_client, BaseSecretManager): + await litellm.secret_manager_client.async_rotate_secret( + current_secret_name=KeyManagementEventHooks._get_secret_name( + current_secret_name + ), + new_secret_name=KeyManagementEventHooks._get_secret_name( + new_secret_name + ), + new_secret_value=new_secret_value, + ) + + @staticmethod + def _get_secret_name(secret_name: str) -> str: + if litellm._key_management_settings.prefix_for_stored_virtual_keys.endswith( + "/" + ): + return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{secret_name}" + else: + return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}/{secret_name}" + + @staticmethod + async def _delete_virtual_keys_from_secret_manager( + keys_being_deleted: List[LiteLLM_VerificationToken], + ): + """ + Deletes virtual keys from the secret manager + + Args: + keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation + """ + if litellm._key_management_settings is not None: + if 
litellm._key_management_settings.store_virtual_keys is True: + from litellm.secret_managers.base_secret_manager import ( + BaseSecretManager, + ) + + if isinstance(litellm.secret_manager_client, BaseSecretManager): + for key in keys_being_deleted: + if key.key_alias is not None: + await litellm.secret_manager_client.async_delete_secret( + secret_name=KeyManagementEventHooks._get_secret_name( + key.key_alias + ) + ) + else: + verbose_proxy_logger.warning( + f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager." + ) + + @staticmethod + async def _send_key_created_email(response: dict): + from litellm.proxy.proxy_server import general_settings, proxy_logging_obj + + if "email" not in general_settings.get("alerting", []): + raise ValueError( + "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`" + ) + event = WebhookEvent( + event="key_created", + event_group="key", + event_message="API Key Created", + token=response.get("token", ""), + spend=response.get("spend", 0.0), + max_budget=response.get("max_budget", 0.0), + user_id=response.get("user_id", None), + team_id=response.get("team_id", "Default Team"), + key_alias=response.get("key_alias", None), + ) + + # If user configured email alerting - send an Email letting their end-user know the key was created + asyncio.create_task( + proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email( + webhook_event=event, + ) + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/max_budget_limiter.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/max_budget_limiter.py new file mode 100644 index 00000000..4b59f603 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/max_budget_limiter.py @@ -0,0 +1,49 @@ +from fastapi import HTTPException + +from litellm import verbose_logger +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth + + +class _PROXY_MaxBudgetLimiter(CustomLogger): + # Class variables or attributes + def __init__(self): + pass + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, + ): + try: + verbose_proxy_logger.debug("Inside Max Budget Limiter Pre-Call Hook") + cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id" + user_row = await cache.async_get_cache( + cache_key, parent_otel_span=user_api_key_dict.parent_otel_span + ) + if user_row is None: # value not yet cached + return + max_budget = user_row["max_budget"] + curr_spend = user_row["spend"] + + if max_budget is None: + return + + if curr_spend is None: + return + + # CHECK IF REQUEST ALLOWED + if curr_spend >= max_budget: + raise HTTPException(status_code=429, detail="Max budget limit reached.") + except HTTPException as e: + raise e + except Exception as e: + verbose_logger.exception( + "litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occured - {}".format( + str(e) + ) + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/model_max_budget_limiter.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/model_max_budget_limiter.py new file mode 100644 index 00000000..ac02c915 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/model_max_budget_limiter.py @@ -0,0 +1,192 @@ +import json +from typing import List, Optional + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import Span +from litellm.proxy._types import UserAPIKeyAuth +from litellm.router_strategy.budget_limiter import RouterBudgetLimiting +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ( + BudgetConfig, + GenericBudgetConfigType, + StandardLoggingPayload, +) + +VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend" + + +class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting): + """ + Handles budgets for model + virtual key + + Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d + """ + + def __init__(self, dual_cache: DualCache): + self.dual_cache = dual_cache + self.redis_increment_operation_queue = [] + + async def is_key_within_model_budget( + self, + user_api_key_dict: UserAPIKeyAuth, + model: str, + ) -> bool: + """ + Check if the user_api_key_dict is within the model budget + + Raises: + BudgetExceededError: If the user_api_key_dict has exceeded the model budget + """ + _model_max_budget = user_api_key_dict.model_max_budget + internal_model_max_budget: GenericBudgetConfigType = {} + + for _model, _budget_info in _model_max_budget.items(): + internal_model_max_budget[_model] = BudgetConfig(**_budget_info) + + verbose_proxy_logger.debug( + "internal_model_max_budget %s", + json.dumps(internal_model_max_budget, indent=4, default=str), + ) + + # check if current model is in internal_model_max_budget + _current_model_budget_info = self._get_request_model_budget_config( + model=model, internal_model_max_budget=internal_model_max_budget + ) + if _current_model_budget_info is None: + verbose_proxy_logger.debug( + f"Model {model} not found in internal_model_max_budget" + ) + return True + + # check if current model is within budget + if ( + _current_model_budget_info.max_budget + and _current_model_budget_info.max_budget > 0 + ): + _current_spend = await self._get_virtual_key_spend_for_model( + user_api_key_hash=user_api_key_dict.token, + model=model, + key_budget_config=_current_model_budget_info, + ) + if ( + _current_spend is not None + and _current_model_budget_info.max_budget is not None + and _current_spend > _current_model_budget_info.max_budget + ): + raise litellm.BudgetExceededError( + message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}", + current_cost=_current_spend, + max_budget=_current_model_budget_info.max_budget, + ) + + return True + + async def _get_virtual_key_spend_for_model( + self, + user_api_key_hash: Optional[str], + model: str, + key_budget_config: BudgetConfig, + ) -> Optional[float]: + """ + Get the current spend for a virtual key for a model + + Lookup model in this order: + 1. model: directly look up `model` + 2. If 1, does not exist, check if passed as {custom_llm_provider}/model + """ + + # 1. model: directly look up `model` + virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.budget_duration}" + _current_spend = await self.dual_cache.async_get_cache( + key=virtual_key_model_spend_cache_key, + ) + + if _current_spend is None: + # 2. 
If 1, does not exist, check if passed as {custom_llm_provider}/model + # if "/" in model, remove first part before "/" - eg. openai/o1-preview -> o1-preview + virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.budget_duration}" + _current_spend = await self.dual_cache.async_get_cache( + key=virtual_key_model_spend_cache_key, + ) + return _current_spend + + def _get_request_model_budget_config( + self, model: str, internal_model_max_budget: GenericBudgetConfigType + ) -> Optional[BudgetConfig]: + """ + Get the budget config for the request model + + 1. Check if `model` is in `internal_model_max_budget` + 2. If not, check if `model` without custom llm provider is in `internal_model_max_budget` + """ + return internal_model_max_budget.get( + model, None + ) or internal_model_max_budget.get( + self._get_model_without_custom_llm_provider(model), None + ) + + def _get_model_without_custom_llm_provider(self, model: str) -> str: + if "/" in model: + return model.split("/")[-1] + return model + + async def async_filter_deployments( + self, + model: str, + healthy_deployments: List, + messages: Optional[List[AllMessageValues]], + request_kwargs: Optional[dict] = None, + parent_otel_span: Optional[Span] = None, # type: ignore + ) -> List[dict]: + return healthy_deployments + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + """ + Track spend for virtual key + model in DualCache + + Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d + """ + verbose_proxy_logger.debug("in RouterBudgetLimiting.async_log_success_event") + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + if standard_logging_payload is None: + raise ValueError("standard_logging_payload is required") + + _litellm_params: dict = kwargs.get("litellm_params", {}) or {} + _metadata: dict = _litellm_params.get("metadata", {}) or {} + user_api_key_model_max_budget: Optional[dict] = _metadata.get( + "user_api_key_model_max_budget", None + ) + if ( + user_api_key_model_max_budget is None + or len(user_api_key_model_max_budget) == 0 + ): + verbose_proxy_logger.debug( + "Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget is None or empty. 
`user_api_key_model_max_budget`=%s", + user_api_key_model_max_budget, + ) + return + response_cost: float = standard_logging_payload.get("response_cost", 0) + model = standard_logging_payload.get("model") + + virtual_key = standard_logging_payload.get("metadata").get("user_api_key_hash") + model = standard_logging_payload.get("model") + if virtual_key is not None: + budget_config = BudgetConfig(time_period="1d", budget_limit=0.1) + virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_config.budget_duration}" + virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}" + await self._increment_spend_for_key( + budget_config=budget_config, + spend_key=virtual_spend_key, + start_time_key=virtual_start_time_key, + response_cost=response_cost, + ) + verbose_proxy_logger.debug( + "current state of in memory cache %s", + json.dumps( + self.dual_cache.in_memory_cache.cache_dict, indent=4, default=str + ), + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/parallel_request_limiter.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/parallel_request_limiter.py new file mode 100644 index 00000000..06f3b6af --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/parallel_request_limiter.py @@ -0,0 +1,866 @@ +import asyncio +import sys +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union + +from fastapi import HTTPException +from pydantic import BaseModel + +import litellm +from litellm import DualCache, ModelResponse +from litellm._logging import verbose_proxy_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs +from litellm.proxy._types import CommonProxyErrors, CurrentItemRateLimit, UserAPIKeyAuth +from litellm.proxy.auth.auth_utils import ( + get_key_model_rpm_limit, + get_key_model_tpm_limit, +) + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache + + Span = _Span + InternalUsageCache = _InternalUsageCache +else: + Span = Any + InternalUsageCache = Any + + +class CacheObject(TypedDict): + current_global_requests: Optional[dict] + request_count_api_key: Optional[dict] + request_count_api_key_model: Optional[dict] + request_count_user_id: Optional[dict] + request_count_team_id: Optional[dict] + request_count_end_user_id: Optional[dict] + + +class _PROXY_MaxParallelRequestsHandler(CustomLogger): + # Class variables or attributes + def __init__(self, internal_usage_cache: InternalUsageCache): + self.internal_usage_cache = internal_usage_cache + + def print_verbose(self, print_statement): + try: + verbose_proxy_logger.debug(print_statement) + if litellm.set_verbose: + print(print_statement) # noqa + except Exception: + pass + + async def check_key_in_limits( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, + max_parallel_requests: int, + tpm_limit: int, + rpm_limit: int, + current: Optional[dict], + request_count_api_key: str, + rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"], + values_to_update_in_cache: List[Tuple[Any, Any]], + ) -> dict: + verbose_proxy_logger.info( + f"Current Usage of {rate_limit_type} in this minute: {current}" + ) + if current is None: + if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0: + # base case + raise 
self.raise_rate_limit_error( + additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}" + ) + new_val = { + "current_requests": 1, + "current_tpm": 0, + "current_rpm": 1, + } + values_to_update_in_cache.append((request_count_api_key, new_val)) + elif ( + int(current["current_requests"]) < max_parallel_requests + and current["current_tpm"] < tpm_limit + and current["current_rpm"] < rpm_limit + ): + # Increase count for this token + new_val = { + "current_requests": current["current_requests"] + 1, + "current_tpm": current["current_tpm"], + "current_rpm": current["current_rpm"] + 1, + } + values_to_update_in_cache.append((request_count_api_key, new_val)) + + else: + raise HTTPException( + status_code=429, + detail=f"LiteLLM Rate Limit Handler for rate limit type = {rate_limit_type}. {CommonProxyErrors.max_parallel_request_limit_reached.value}. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}, current max_parallel_requests: {current['current_requests']}, max_parallel_requests: {max_parallel_requests}", + headers={"retry-after": str(self.time_to_next_minute())}, + ) + + await self.internal_usage_cache.async_batch_set_cache( + cache_list=values_to_update_in_cache, + ttl=60, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, + local_only=True, + ) + return new_val + + def time_to_next_minute(self) -> float: + # Get the current time + now = datetime.now() + + # Calculate the next minute + next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0) + + # Calculate the difference in seconds + seconds_to_next_minute = (next_minute - now).total_seconds() + + return seconds_to_next_minute + + def raise_rate_limit_error( + self, additional_details: Optional[str] = None + ) -> HTTPException: + """ + Raise an HTTPException with a 429 status code and a retry-after header + """ + error_message = "Max parallel request limit reached" + if additional_details is not None: + error_message = error_message + " " + additional_details + raise HTTPException( + status_code=429, + detail=f"Max parallel request limit reached {additional_details}", + headers={"retry-after": str(self.time_to_next_minute())}, + ) + + async def get_all_cache_objects( + self, + current_global_requests: Optional[str], + request_count_api_key: Optional[str], + request_count_api_key_model: Optional[str], + request_count_user_id: Optional[str], + request_count_team_id: Optional[str], + request_count_end_user_id: Optional[str], + parent_otel_span: Optional[Span] = None, + ) -> CacheObject: + keys = [ + current_global_requests, + request_count_api_key, + request_count_api_key_model, + request_count_user_id, + request_count_team_id, + request_count_end_user_id, + ] + results = await self.internal_usage_cache.async_batch_get_cache( + keys=keys, + parent_otel_span=parent_otel_span, + ) + + if results is None: + return CacheObject( + current_global_requests=None, + request_count_api_key=None, + request_count_api_key_model=None, + request_count_user_id=None, + request_count_team_id=None, + request_count_end_user_id=None, + ) + + return CacheObject( + current_global_requests=results[0], + request_count_api_key=results[1], + request_count_api_key_model=results[2], + request_count_user_id=results[3], + request_count_team_id=results[4], + request_count_end_user_id=results[5], + ) + + 
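Editor's note: the parallel request limiter keys its per-minute counters as "{identifier}::{YYYY-MM-DD-HH-MM}::request_count", each value holding current_requests / current_tpm / current_rpm, and uses the seconds remaining in the current minute as the retry-after header. A standalone sketch of the key naming and retry-after computation follows, not part of the diff; the fixed timestamp is only for a deterministic example.

# Illustrative sketch of the per-minute counter key and retry-after logic
# used by _PROXY_MaxParallelRequestsHandler.
from datetime import datetime, timedelta


def request_count_key(identifier: str, now: datetime) -> str:
    # e.g. "sk-123::2024-01-31-17-05::request_count"
    precise_minute = now.strftime("%Y-%m-%d-%H-%M")
    return f"{identifier}::{precise_minute}::request_count"


def time_to_next_minute(now: datetime) -> float:
    # Seconds until the counters roll over; returned as the "retry-after" header.
    next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
    return (next_minute - now).total_seconds()


now = datetime(2024, 1, 31, 17, 5, 42)
print(request_count_key("sk-123", now))  # sk-123::2024-01-31-17-05::request_count
print(time_to_next_minute(now))          # 18.0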
async def async_pre_call_hook( # noqa: PLR0915 + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, + ): + self.print_verbose("Inside Max Parallel Request Pre-Call Hook") + api_key = user_api_key_dict.api_key + max_parallel_requests = user_api_key_dict.max_parallel_requests + if max_parallel_requests is None: + max_parallel_requests = sys.maxsize + if data is None: + data = {} + global_max_parallel_requests = data.get("metadata", {}).get( + "global_max_parallel_requests", None + ) + tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize) + if tpm_limit is None: + tpm_limit = sys.maxsize + rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize) + if rpm_limit is None: + rpm_limit = sys.maxsize + + values_to_update_in_cache: List[Tuple[Any, Any]] = ( + [] + ) # values that need to get updated in cache, will run a batch_set_cache after this function + + # ------------ + # Setup values + # ------------ + new_val: Optional[dict] = None + + if global_max_parallel_requests is not None: + # get value from cache + _key = "global_max_parallel_requests" + current_global_requests = await self.internal_usage_cache.async_get_cache( + key=_key, + local_only=True, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, + ) + # check if below limit + if current_global_requests is None: + current_global_requests = 1 + # if above -> raise error + if current_global_requests >= global_max_parallel_requests: + return self.raise_rate_limit_error( + additional_details=f"Hit Global Limit: Limit={global_max_parallel_requests}, current: {current_global_requests}" + ) + # if below -> increment + else: + await self.internal_usage_cache.async_increment_cache( + key=_key, + value=1, + local_only=True, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, + ) + _model = data.get("model", None) + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + + cache_objects: CacheObject = await self.get_all_cache_objects( + current_global_requests=( + "global_max_parallel_requests" + if global_max_parallel_requests is not None + else None + ), + request_count_api_key=( + f"{api_key}::{precise_minute}::request_count" + if api_key is not None + else None + ), + request_count_api_key_model=( + f"{api_key}::{_model}::{precise_minute}::request_count" + if api_key is not None and _model is not None + else None + ), + request_count_user_id=( + f"{user_api_key_dict.user_id}::{precise_minute}::request_count" + if user_api_key_dict.user_id is not None + else None + ), + request_count_team_id=( + f"{user_api_key_dict.team_id}::{precise_minute}::request_count" + if user_api_key_dict.team_id is not None + else None + ), + request_count_end_user_id=( + f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count" + if user_api_key_dict.end_user_id is not None + else None + ), + parent_otel_span=user_api_key_dict.parent_otel_span, + ) + if api_key is not None: + request_count_api_key = f"{api_key}::{precise_minute}::request_count" + # CHECK IF REQUEST ALLOWED for key + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type=call_type, + max_parallel_requests=max_parallel_requests, + current=cache_objects["request_count_api_key"], + request_count_api_key=request_count_api_key, + tpm_limit=tpm_limit, + rpm_limit=rpm_limit, + rate_limit_type="key", 
+ values_to_update_in_cache=values_to_update_in_cache, + ) + + # Check if request under RPM/TPM per model for a given API Key + if ( + get_key_model_tpm_limit(user_api_key_dict) is not None + or get_key_model_rpm_limit(user_api_key_dict) is not None + ): + _model = data.get("model", None) + request_count_api_key = ( + f"{api_key}::{_model}::{precise_minute}::request_count" + ) + _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict) + _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict) + tpm_limit_for_model = None + rpm_limit_for_model = None + + if _model is not None: + if _tpm_limit_for_key_model: + tpm_limit_for_model = _tpm_limit_for_key_model.get(_model) + + if _rpm_limit_for_key_model: + rpm_limit_for_model = _rpm_limit_for_key_model.get(_model) + + new_val = await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type=call_type, + max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a model + current=cache_objects["request_count_api_key_model"], + request_count_api_key=request_count_api_key, + tpm_limit=tpm_limit_for_model or sys.maxsize, + rpm_limit=rpm_limit_for_model or sys.maxsize, + rate_limit_type="model_per_key", + values_to_update_in_cache=values_to_update_in_cache, + ) + _remaining_tokens = None + _remaining_requests = None + # Add remaining tokens, requests to metadata + if new_val: + if tpm_limit_for_model is not None: + _remaining_tokens = tpm_limit_for_model - new_val["current_tpm"] + if rpm_limit_for_model is not None: + _remaining_requests = rpm_limit_for_model - new_val["current_rpm"] + + _remaining_limits_data = { + f"litellm-key-remaining-tokens-{_model}": _remaining_tokens, + f"litellm-key-remaining-requests-{_model}": _remaining_requests, + } + + if "metadata" not in data: + data["metadata"] = {} + data["metadata"].update(_remaining_limits_data) + + # check if REQUEST ALLOWED for user_id + user_id = user_api_key_dict.user_id + if user_id is not None: + user_tpm_limit = user_api_key_dict.user_tpm_limit + user_rpm_limit = user_api_key_dict.user_rpm_limit + if user_tpm_limit is None: + user_tpm_limit = sys.maxsize + if user_rpm_limit is None: + user_rpm_limit = sys.maxsize + + request_count_api_key = f"{user_id}::{precise_minute}::request_count" + # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type=call_type, + max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user + current=cache_objects["request_count_user_id"], + request_count_api_key=request_count_api_key, + tpm_limit=user_tpm_limit, + rpm_limit=user_rpm_limit, + rate_limit_type="user", + values_to_update_in_cache=values_to_update_in_cache, + ) + + # TEAM RATE LIMITS + ## get team tpm/rpm limits + team_id = user_api_key_dict.team_id + if team_id is not None: + team_tpm_limit = user_api_key_dict.team_tpm_limit + team_rpm_limit = user_api_key_dict.team_rpm_limit + + if team_tpm_limit is None: + team_tpm_limit = sys.maxsize + if team_rpm_limit is None: + team_rpm_limit = sys.maxsize + + request_count_api_key = f"{team_id}::{precise_minute}::request_count" + # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type=call_type, + max_parallel_requests=sys.maxsize, # 
TODO: Support max parallel requests for a team + current=cache_objects["request_count_team_id"], + request_count_api_key=request_count_api_key, + tpm_limit=team_tpm_limit, + rpm_limit=team_rpm_limit, + rate_limit_type="team", + values_to_update_in_cache=values_to_update_in_cache, + ) + + # End-User Rate Limits + # Only enforce if user passed `user` to /chat, /completions, /embeddings + if user_api_key_dict.end_user_id: + end_user_tpm_limit = getattr( + user_api_key_dict, "end_user_tpm_limit", sys.maxsize + ) + end_user_rpm_limit = getattr( + user_api_key_dict, "end_user_rpm_limit", sys.maxsize + ) + + if end_user_tpm_limit is None: + end_user_tpm_limit = sys.maxsize + if end_user_rpm_limit is None: + end_user_rpm_limit = sys.maxsize + + # now do the same tpm/rpm checks + request_count_api_key = ( + f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count" + ) + + # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type=call_type, + max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User + request_count_api_key=request_count_api_key, + current=cache_objects["request_count_end_user_id"], + tpm_limit=end_user_tpm_limit, + rpm_limit=end_user_rpm_limit, + rate_limit_type="customer", + values_to_update_in_cache=values_to_update_in_cache, + ) + + asyncio.create_task( + self.internal_usage_cache.async_batch_set_cache( + cache_list=values_to_update_in_cache, + ttl=60, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, + ) # don't block execution for cache updates + ) + + return + + async def async_log_success_event( # noqa: PLR0915 + self, kwargs, response_obj, start_time, end_time + ): + from litellm.proxy.common_utils.callback_utils import ( + get_model_group_from_litellm_kwargs, + ) + + litellm_parent_otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs( + kwargs=kwargs + ) + try: + self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING") + + global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get( + "global_max_parallel_requests", None + ) + user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"] + user_api_key_user_id = kwargs["litellm_params"]["metadata"].get( + "user_api_key_user_id", None + ) + user_api_key_team_id = kwargs["litellm_params"]["metadata"].get( + "user_api_key_team_id", None + ) + user_api_key_model_max_budget = kwargs["litellm_params"]["metadata"].get( + "user_api_key_model_max_budget", None + ) + user_api_key_end_user_id = kwargs.get("user") + + user_api_key_metadata = ( + kwargs["litellm_params"]["metadata"].get("user_api_key_metadata", {}) + or {} + ) + + # ------------ + # Setup values + # ------------ + + if global_max_parallel_requests is not None: + # get value from cache + _key = "global_max_parallel_requests" + # decrement + await self.internal_usage_cache.async_increment_cache( + key=_key, + value=-1, + local_only=True, + litellm_parent_otel_span=litellm_parent_otel_span, + ) + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens # type: ignore + + # ------------ + # Update usage - API Key + # ------------ + + 
values_to_update_in_cache = [] + + if user_api_key is not None: + request_count_api_key = ( + f"{user_api_key}::{precise_minute}::request_count" + ) + + current = await self.internal_usage_cache.async_get_cache( + key=request_count_api_key, + litellm_parent_otel_span=litellm_parent_otel_span, + ) or { + "current_requests": 1, + "current_tpm": 0, + "current_rpm": 0, + } + + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"] + total_tokens, + "current_rpm": current["current_rpm"], + } + + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + values_to_update_in_cache.append((request_count_api_key, new_val)) + + # ------------ + # Update usage - model group + API Key + # ------------ + model_group = get_model_group_from_litellm_kwargs(kwargs) + if ( + user_api_key is not None + and model_group is not None + and ( + "model_rpm_limit" in user_api_key_metadata + or "model_tpm_limit" in user_api_key_metadata + or user_api_key_model_max_budget is not None + ) + ): + request_count_api_key = ( + f"{user_api_key}::{model_group}::{precise_minute}::request_count" + ) + + current = await self.internal_usage_cache.async_get_cache( + key=request_count_api_key, + litellm_parent_otel_span=litellm_parent_otel_span, + ) or { + "current_requests": 1, + "current_tpm": 0, + "current_rpm": 0, + } + + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"] + total_tokens, + "current_rpm": current["current_rpm"], + } + + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + values_to_update_in_cache.append((request_count_api_key, new_val)) + + # ------------ + # Update usage - User + # ------------ + if user_api_key_user_id is not None: + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens # type: ignore + + request_count_api_key = ( + f"{user_api_key_user_id}::{precise_minute}::request_count" + ) + + current = await self.internal_usage_cache.async_get_cache( + key=request_count_api_key, + litellm_parent_otel_span=litellm_parent_otel_span, + ) or { + "current_requests": 1, + "current_tpm": total_tokens, + "current_rpm": 1, + } + + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"] + total_tokens, + "current_rpm": current["current_rpm"], + } + + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + values_to_update_in_cache.append((request_count_api_key, new_val)) + + # ------------ + # Update usage - Team + # ------------ + if user_api_key_team_id is not None: + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens # type: ignore + + request_count_api_key = ( + f"{user_api_key_team_id}::{precise_minute}::request_count" + ) + + current = await self.internal_usage_cache.async_get_cache( + key=request_count_api_key, + litellm_parent_otel_span=litellm_parent_otel_span, + ) or { + "current_requests": 1, + "current_tpm": total_tokens, + "current_rpm": 1, + } + + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"] + total_tokens, + "current_rpm": current["current_rpm"], + } + + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + 
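Every counter in this handler lives under the same per-minute cache-key scheme, "<scope id>::<YYYY-MM-DD-HH-MM>::request_count", with the model name folded in for per-model key limits. A small sketch of that naming convention, assuming a hypothetical request_count_key helper (the file itself builds these strings inline):

from datetime import datetime
from typing import Optional


def request_count_key(scope_id: str, model: Optional[str] = None) -> str:
    """Build the per-minute counter key used for keys, users, teams and end users."""
    now = datetime.now()
    precise_minute = f"{now.strftime('%Y-%m-%d')}-{now.strftime('%H')}-{now.strftime('%M')}"
    if model is not None:
        # per-model limits on an API key get the model name folded into the key
        return f"{scope_id}::{model}::{precise_minute}::request_count"
    return f"{scope_id}::{precise_minute}::request_count"


# e.g. request_count_key("hashed-api-key", "gpt-4o")
#   -> "hashed-api-key::gpt-4o::2024-01-01-13-37::request_count"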
values_to_update_in_cache.append((request_count_api_key, new_val)) + + # ------------ + # Update usage - End User + # ------------ + if user_api_key_end_user_id is not None: + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens # type: ignore + + request_count_api_key = ( + f"{user_api_key_end_user_id}::{precise_minute}::request_count" + ) + + current = await self.internal_usage_cache.async_get_cache( + key=request_count_api_key, + litellm_parent_otel_span=litellm_parent_otel_span, + ) or { + "current_requests": 1, + "current_tpm": total_tokens, + "current_rpm": 1, + } + + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"] + total_tokens, + "current_rpm": current["current_rpm"], + } + + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + values_to_update_in_cache.append((request_count_api_key, new_val)) + + await self.internal_usage_cache.async_batch_set_cache( + cache_list=values_to_update_in_cache, + ttl=60, + litellm_parent_otel_span=litellm_parent_otel_span, + ) + except Exception as e: + self.print_verbose(e) # noqa + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + try: + self.print_verbose("Inside Max Parallel Request Failure Hook") + litellm_parent_otel_span: Union[Span, None] = ( + _get_parent_otel_span_from_kwargs(kwargs=kwargs) + ) + _metadata = kwargs["litellm_params"].get("metadata", {}) or {} + global_max_parallel_requests = _metadata.get( + "global_max_parallel_requests", None + ) + user_api_key = _metadata.get("user_api_key", None) + self.print_verbose(f"user_api_key: {user_api_key}") + if user_api_key is None: + return + + ## decrement call count if call failed + if CommonProxyErrors.max_parallel_request_limit_reached.value in str( + kwargs["exception"] + ): + pass # ignore failed calls due to max limit being reached + else: + # ------------ + # Setup values + # ------------ + + if global_max_parallel_requests is not None: + # get value from cache + _key = "global_max_parallel_requests" + ( + await self.internal_usage_cache.async_get_cache( + key=_key, + local_only=True, + litellm_parent_otel_span=litellm_parent_otel_span, + ) + ) + # decrement + await self.internal_usage_cache.async_increment_cache( + key=_key, + value=-1, + local_only=True, + litellm_parent_otel_span=litellm_parent_otel_span, + ) + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + + request_count_api_key = ( + f"{user_api_key}::{precise_minute}::request_count" + ) + + # ------------ + # Update usage + # ------------ + current = await self.internal_usage_cache.async_get_cache( + key=request_count_api_key, + litellm_parent_otel_span=litellm_parent_otel_span, + ) or { + "current_requests": 1, + "current_tpm": 0, + "current_rpm": 0, + } + + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"], + "current_rpm": current["current_rpm"], + } + + self.print_verbose(f"updated_value in failure call: {new_val}") + await self.internal_usage_cache.async_set_cache( + request_count_api_key, + new_val, + ttl=60, + litellm_parent_otel_span=litellm_parent_otel_span, + ) # save in cache for up to 1 min. 
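Both the pre-call hook and the logging callbacks funnel their updated counters through one batched cache write with a 60 second TTL, so per-minute counters expire on their own and only one round trip is issued per hook invocation; the pre-call hook additionally wraps the write in asyncio.create_task so the request is not blocked on it. A sketch of that pattern, with flush_counters as a hypothetical wrapper around the cache call used above:

import asyncio


def flush_counters(internal_usage_cache, values_to_update_in_cache, parent_otel_span=None):
    # Must be called from inside a running event loop, as the hooks are.
    # Schedules the batched write and returns without awaiting it, mirroring
    # how the pre-call hook keeps cache updates off the request path.
    return asyncio.create_task(
        internal_usage_cache.async_batch_set_cache(
            cache_list=values_to_update_in_cache,
            ttl=60,  # per-minute counters only need to survive ~1 minute
            litellm_parent_otel_span=parent_otel_span,
        )
    )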
+ except Exception as e: + verbose_proxy_logger.exception( + "Inside Parallel Request Limiter: An exception occurred - {}".format( + str(e) + ) + ) + + async def get_internal_user_object( + self, + user_id: str, + user_api_key_dict: UserAPIKeyAuth, + ) -> Optional[dict]: + """ + Helper to get the 'Internal User Object' + + It uses the `get_user_object` function from `litellm.proxy.auth.auth_checks` + + We need this because the UserApiKeyAuth object does not contain the rpm/tpm limits for a User AND there could be a perf impact by additionally reading the UserTable. + """ + from litellm._logging import verbose_proxy_logger + from litellm.proxy.auth.auth_checks import get_user_object + from litellm.proxy.proxy_server import prisma_client + + try: + _user_id_rate_limits = await get_user_object( + user_id=user_id, + prisma_client=prisma_client, + user_api_key_cache=self.internal_usage_cache.dual_cache, + user_id_upsert=False, + parent_otel_span=user_api_key_dict.parent_otel_span, + proxy_logging_obj=None, + ) + + if _user_id_rate_limits is None: + return None + + return _user_id_rate_limits.model_dump() + except Exception as e: + verbose_proxy_logger.debug( + "Parallel Request Limiter: Error getting user object", str(e) + ) + return None + + async def async_post_call_success_hook( + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response + ): + """ + Retrieve the key's remaining rate limits. + """ + api_key = user_api_key_dict.api_key + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + request_count_api_key = f"{api_key}::{precise_minute}::request_count" + current: Optional[CurrentItemRateLimit] = ( + await self.internal_usage_cache.async_get_cache( + key=request_count_api_key, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, + ) + ) + + key_remaining_rpm_limit: Optional[int] = None + key_rpm_limit: Optional[int] = None + key_remaining_tpm_limit: Optional[int] = None + key_tpm_limit: Optional[int] = None + if current is not None: + if user_api_key_dict.rpm_limit is not None: + key_remaining_rpm_limit = ( + user_api_key_dict.rpm_limit - current["current_rpm"] + ) + key_rpm_limit = user_api_key_dict.rpm_limit + if user_api_key_dict.tpm_limit is not None: + key_remaining_tpm_limit = ( + user_api_key_dict.tpm_limit - current["current_tpm"] + ) + key_tpm_limit = user_api_key_dict.tpm_limit + + if hasattr(response, "_hidden_params"): + _hidden_params = getattr(response, "_hidden_params") + else: + _hidden_params = None + if _hidden_params is not None and ( + isinstance(_hidden_params, BaseModel) or isinstance(_hidden_params, dict) + ): + if isinstance(_hidden_params, BaseModel): + _hidden_params = _hidden_params.model_dump() + + _additional_headers = _hidden_params.get("additional_headers", {}) or {} + + if key_remaining_rpm_limit is not None: + _additional_headers["x-ratelimit-remaining-requests"] = ( + key_remaining_rpm_limit + ) + if key_rpm_limit is not None: + _additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit + if key_remaining_tpm_limit is not None: + _additional_headers["x-ratelimit-remaining-tokens"] = ( + key_remaining_tpm_limit + ) + if key_tpm_limit is not None: + _additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit + + setattr( + response, + "_hidden_params", + {**_hidden_params, "additional_headers": _additional_headers}, + ) + + return await 
super().async_post_call_success_hook( + data, user_api_key_dict, response + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py new file mode 100644 index 00000000..b1b2bbee --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py @@ -0,0 +1,280 @@ +# +------------------------------------+ +# +# Prompt Injection Detection +# +# +------------------------------------+ +# Thank you users! We ❤️ you! - Krrish & Ishaan +## Reject a call if it contains a prompt injection attack. + + +from difflib import SequenceMatcher +from typing import List, Literal, Optional + +from fastapi import HTTPException + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.prompt_templates.factory import ( + prompt_injection_detection_default_pt, +) +from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth +from litellm.router import Router +from litellm.utils import get_formatted_prompt + + +class _OPTIONAL_PromptInjectionDetection(CustomLogger): + # Class variables or attributes + def __init__( + self, + prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None, + ): + self.prompt_injection_params = prompt_injection_params + self.llm_router: Optional[Router] = None + + self.verbs = [ + "Ignore", + "Disregard", + "Skip", + "Forget", + "Neglect", + "Overlook", + "Omit", + "Bypass", + "Pay no attention to", + "Do not follow", + "Do not obey", + ] + self.adjectives = [ + "", + "prior", + "previous", + "preceding", + "above", + "foregoing", + "earlier", + "initial", + ] + self.prepositions = [ + "", + "and start over", + "and start anew", + "and begin afresh", + "and start from scratch", + ] + + def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"): + if level == "INFO": + verbose_proxy_logger.info(print_statement) + elif level == "DEBUG": + verbose_proxy_logger.debug(print_statement) + + if litellm.set_verbose is True: + print(print_statement) # noqa + + def update_environment(self, router: Optional[Router] = None): + self.llm_router = router + + if ( + self.prompt_injection_params is not None + and self.prompt_injection_params.llm_api_check is True + ): + if self.llm_router is None: + raise Exception( + "PromptInjectionDetection: Model List not set. Required for Prompt Injection detection." + ) + + self.print_verbose( + f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}" + ) + if ( + self.prompt_injection_params.llm_api_name is None + or self.prompt_injection_params.llm_api_name + not in self.llm_router.model_names + ): + raise Exception( + "PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'." 
+ ) + + def generate_injection_keywords(self) -> List[str]: + combinations = [] + for verb in self.verbs: + for adj in self.adjectives: + for prep in self.prepositions: + phrase = " ".join(filter(None, [verb, adj, prep])).strip() + if ( + len(phrase.split()) > 2 + ): # additional check to ensure more than 2 words + combinations.append(phrase.lower()) + return combinations + + def check_user_input_similarity( + self, user_input: str, similarity_threshold: float = 0.7 + ) -> bool: + user_input_lower = user_input.lower() + keywords = self.generate_injection_keywords() + + for keyword in keywords: + # Calculate the length of the keyword to extract substrings of the same length from user input + keyword_length = len(keyword) + + for i in range(len(user_input_lower) - keyword_length + 1): + # Extract a substring of the same length as the keyword + substring = user_input_lower[i : i + keyword_length] + + # Calculate similarity + match_ratio = SequenceMatcher(None, substring, keyword).ratio() + if match_ratio > similarity_threshold: + self.print_verbose( + print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}", + level="INFO", + ) + return True # Found a highly similar substring + return False # No substring crossed the threshold + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, # "completion", "embeddings", "image_generation", "moderation" + ): + try: + """ + - check if user id part of call + - check if user id part of blocked list + """ + self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook") + try: + assert call_type in [ + "completion", + "text_completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + ] + except Exception: + self.print_verbose( + f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']" + ) + return data + formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore + + is_prompt_attack = False + + if self.prompt_injection_params is not None: + # 1. check if heuristics check turned on + if self.prompt_injection_params.heuristics_check is True: + is_prompt_attack = self.check_user_input_similarity( + user_input=formatted_prompt + ) + if is_prompt_attack is True: + raise HTTPException( + status_code=400, + detail={ + "error": "Rejected message. This is a prompt injection attack." + }, + ) + # 2. check if vector db similarity check turned on [TODO] Not Implemented yet + if self.prompt_injection_params.vector_db_check is True: + pass + else: + is_prompt_attack = self.check_user_input_similarity( + user_input=formatted_prompt + ) + + if is_prompt_attack is True: + raise HTTPException( + status_code=400, + detail={ + "error": "Rejected message. This is a prompt injection attack." 
+ }, + ) + + return data + + except HTTPException as e: + + if ( + e.status_code == 400 + and isinstance(e.detail, dict) + and "error" in e.detail # type: ignore + and self.prompt_injection_params is not None + and self.prompt_injection_params.reject_as_response + ): + return e.detail.get("error") + raise e + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format( + str(e) + ) + ) + + async def async_moderation_hook( # type: ignore + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + call_type: Literal[ + "completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + ], + ) -> Optional[bool]: + self.print_verbose( + f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}" + ) + + if self.prompt_injection_params is None: + return None + + formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore + is_prompt_attack = False + + prompt_injection_system_prompt = getattr( + self.prompt_injection_params, + "llm_api_system_prompt", + prompt_injection_detection_default_pt(), + ) + + # 3. check if llm api check turned on + if ( + self.prompt_injection_params.llm_api_check is True + and self.prompt_injection_params.llm_api_name is not None + and self.llm_router is not None + ): + # make a call to the llm api + response = await self.llm_router.acompletion( + model=self.prompt_injection_params.llm_api_name, + messages=[ + { + "role": "system", + "content": prompt_injection_system_prompt, + }, + {"role": "user", "content": formatted_prompt}, + ], + ) + + self.print_verbose(f"Received LLM Moderation response: {response}") + self.print_verbose( + f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}" + ) + if isinstance(response, litellm.ModelResponse) and isinstance( + response.choices[0], litellm.Choices + ): + if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content: # type: ignore + is_prompt_attack = True + + if is_prompt_attack is True: + raise HTTPException( + status_code=400, + detail={ + "error": "Rejected message. This is a prompt injection attack." 
+ }, + ) + + return is_prompt_attack diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py new file mode 100644 index 00000000..e8a94732 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -0,0 +1,246 @@ +import asyncio +import traceback +from datetime import datetime +from typing import Any, Optional, Union, cast + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.core_helpers import ( + _get_parent_otel_span_from_kwargs, + get_litellm_metadata_from_kwargs, +) +from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.auth.auth_checks import log_db_metrics +from litellm.proxy.utils import ProxyUpdateSpend +from litellm.types.utils import ( + StandardLoggingPayload, + StandardLoggingUserAPIKeyMetadata, +) +from litellm.utils import get_end_user_id_for_cost_tracking + + +class _ProxyDBLogger(CustomLogger): + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + await self._PROXY_track_cost_callback( + kwargs, response_obj, start_time, end_time + ) + + async def async_post_call_failure_hook( + self, + request_data: dict, + original_exception: Exception, + user_api_key_dict: UserAPIKeyAuth, + ): + from litellm.proxy.proxy_server import update_database + + if _ProxyDBLogger._should_track_errors_in_db() is False: + return + + _metadata = dict( + StandardLoggingUserAPIKeyMetadata( + user_api_key_hash=user_api_key_dict.api_key, + user_api_key_alias=user_api_key_dict.key_alias, + user_api_key_user_email=user_api_key_dict.user_email, + user_api_key_user_id=user_api_key_dict.user_id, + user_api_key_team_id=user_api_key_dict.team_id, + user_api_key_org_id=user_api_key_dict.org_id, + user_api_key_team_alias=user_api_key_dict.team_alias, + user_api_key_end_user_id=user_api_key_dict.end_user_id, + ) + ) + _metadata["user_api_key"] = user_api_key_dict.api_key + _metadata["status"] = "failure" + _metadata["error_information"] = ( + StandardLoggingPayloadSetup.get_error_information( + original_exception=original_exception, + ) + ) + + existing_metadata: dict = request_data.get("metadata", None) or {} + existing_metadata.update(_metadata) + + if "litellm_params" not in request_data: + request_data["litellm_params"] = {} + request_data["litellm_params"]["proxy_server_request"] = ( + request_data.get("proxy_server_request") or {} + ) + request_data["litellm_params"]["metadata"] = existing_metadata + await update_database( + token=user_api_key_dict.api_key, + response_cost=0.0, + user_id=user_api_key_dict.user_id, + end_user_id=user_api_key_dict.end_user_id, + team_id=user_api_key_dict.team_id, + kwargs=request_data, + completion_response=original_exception, + start_time=datetime.now(), + end_time=datetime.now(), + org_id=user_api_key_dict.org_id, + ) + + @log_db_metrics + async def _PROXY_track_cost_callback( + self, + kwargs, # kwargs to completion + completion_response: Optional[ + Union[litellm.ModelResponse, Any] + ], # response from completion + start_time=None, + end_time=None, # start/end time for completion + ): + from litellm.proxy.proxy_server import ( + prisma_client, + proxy_logging_obj, + update_cache, + update_database, + ) + + verbose_proxy_logger.debug("INSIDE 
_PROXY_track_cost_callback") + try: + verbose_proxy_logger.debug( + f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" + ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs) + litellm_params = kwargs.get("litellm_params", {}) or {} + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) + metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) + user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None)) + team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None)) + org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None)) + key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None)) + end_user_max_budget = metadata.get("user_api_end_user_max_budget", None) + sl_object: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + response_cost = ( + sl_object.get("response_cost", None) + if sl_object is not None + else kwargs.get("response_cost", None) + ) + + if response_cost is not None: + user_api_key = metadata.get("user_api_key", None) + if kwargs.get("cache_hit", False) is True: + response_cost = 0.0 + verbose_proxy_logger.info( + f"Cache Hit: response_cost {response_cost}, for user_id {user_id}" + ) + + verbose_proxy_logger.debug( + f"user_api_key {user_api_key}, prisma_client: {prisma_client}" + ) + if _should_track_cost_callback( + user_api_key=user_api_key, + user_id=user_id, + team_id=team_id, + end_user_id=end_user_id, + ): + ## UPDATE DATABASE + await update_database( + token=user_api_key, + response_cost=response_cost, + user_id=user_id, + end_user_id=end_user_id, + team_id=team_id, + kwargs=kwargs, + completion_response=completion_response, + start_time=start_time, + end_time=end_time, + org_id=org_id, + ) + + # update cache + asyncio.create_task( + update_cache( + token=user_api_key, + user_id=user_id, + end_user_id=end_user_id, + response_cost=response_cost, + team_id=team_id, + parent_otel_span=parent_otel_span, + ) + ) + + await proxy_logging_obj.slack_alerting_instance.customer_spend_alert( + token=user_api_key, + key_alias=key_alias, + end_user_id=end_user_id, + response_cost=response_cost, + max_budget=end_user_max_budget, + ) + else: + raise Exception( + "User API key and team id and user id missing from custom callback." 
+ ) + else: + if kwargs["stream"] is not True or ( + kwargs["stream"] is True and "complete_streaming_response" in kwargs + ): + if sl_object is not None: + cost_tracking_failure_debug_info: Union[dict, str] = ( + sl_object["response_cost_failure_debug_info"] # type: ignore + or "response_cost_failure_debug_info is None in standard_logging_object" + ) + else: + cost_tracking_failure_debug_info = ( + "standard_logging_object not found" + ) + model = kwargs.get("model") + raise Exception( + f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing" + ) + except Exception as e: + error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}" + model = kwargs.get("model", "") + metadata = kwargs.get("litellm_params", {}).get("metadata", {}) + error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n" + asyncio.create_task( + proxy_logging_obj.failed_tracking_alert( + error_message=error_msg, + failing_model=model, + ) + ) + verbose_proxy_logger.exception( + "Error in tracking cost callback - %s", str(e) + ) + + @staticmethod + def _should_track_errors_in_db(): + """ + Returns True if errors should be tracked in the database + + By default, errors are tracked in the database + + If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings + """ + from litellm.proxy.proxy_server import general_settings + + if general_settings.get("disable_error_logs") is True: + return False + return + + +def _should_track_cost_callback( + user_api_key: Optional[str], + user_id: Optional[str], + team_id: Optional[str], + end_user_id: Optional[str], +) -> bool: + """ + Determine if the cost callback should be tracked based on the kwargs + """ + + # don't run track cost callback if user opted into disabling spend + if ProxyUpdateSpend.disable_spend_updates() is True: + return False + + if ( + user_api_key is not None + or user_id is not None + or team_id is not None + or end_user_id is not None + ): + return True + return False |
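The module above ends with the two gates that decide whether anything is written to the database at all: error logs can be switched off via the disable_error_logs general setting, and spend is only tracked when the request can be attributed to at least one identity and spend updates are not globally disabled. A compact sketch of that second rule, using a hypothetical should_track helper rather than the module's own function:

from typing import Optional


def should_track(
    user_api_key: Optional[str],
    user_id: Optional[str],
    team_id: Optional[str],
    end_user_id: Optional[str],
    disable_spend_updates: bool = False,
) -> bool:
    """Track cost only for attributable requests, unless spend updates are disabled."""
    if disable_spend_updates:
        return False
    # at least one of key / user / team / end-user must be present
    return any(
        value is not None
        for value in (user_api_key, user_id, team_id, end_user_id)
    )


assert should_track("hashed-key", None, None, None) is True
assert should_track(None, None, None, None) is False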