path: root/.venv/lib/python3.12/site-packages/litellm/proxy/hooks
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/__init__.py                                 1
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py                   156
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/batch_redis_get.py                        149
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/cache_control_check.py                     58
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/dynamic_rate_limiter.py                   298
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/example_presidio_ad_hoc_recognizer.json    28
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py             324
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/max_budget_limiter.py                      49
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/model_max_budget_limiter.py               192
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/parallel_request_limiter.py               866
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py             280
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py              246
12 files changed, 2647 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/__init__.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/__init__.py
new file mode 100644
index 00000000..b6e690fd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/__init__.py
@@ -0,0 +1 @@
+from . import *
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py
new file mode 100644
index 00000000..b35d6711
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py
@@ -0,0 +1,156 @@
+import traceback
+from typing import Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+class _PROXY_AzureContentSafety(
+ CustomLogger
+): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
+ # Class variables or attributes
+
+ def __init__(self, endpoint, api_key, thresholds=None):
+ try:
+ from azure.ai.contentsafety.aio import ContentSafetyClient
+ from azure.ai.contentsafety.models import (
+ AnalyzeTextOptions,
+ AnalyzeTextOutputType,
+ TextCategory,
+ )
+ from azure.core.credentials import AzureKeyCredential
+ from azure.core.exceptions import HttpResponseError
+ except Exception as e:
+ raise Exception(
+ f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
+ )
+ self.endpoint = endpoint
+ self.api_key = api_key
+ self.text_category = TextCategory
+ self.analyze_text_options = AnalyzeTextOptions
+ self.analyze_text_output_type = AnalyzeTextOutputType
+ self.azure_http_error = HttpResponseError
+
+ self.thresholds = self._configure_thresholds(thresholds)
+
+ self.client = ContentSafetyClient(
+ self.endpoint, AzureKeyCredential(self.api_key)
+ )
+
+ def _configure_thresholds(self, thresholds=None):
+ default_thresholds = {
+ self.text_category.HATE: 4,
+ self.text_category.SELF_HARM: 4,
+ self.text_category.SEXUAL: 4,
+ self.text_category.VIOLENCE: 4,
+ }
+
+ if thresholds is None:
+ return default_thresholds
+
+ for key, default in default_thresholds.items():
+ if key not in thresholds:
+ thresholds[key] = default
+
+ return thresholds
+
+ def _compute_result(self, response):
+ result = {}
+
+ category_severity = {
+ item.category: item.severity for item in response.categories_analysis
+ }
+ for category in self.text_category:
+ severity = category_severity.get(category)
+ if severity is not None:
+ result[category] = {
+ "filtered": severity >= self.thresholds[category],
+ "severity": severity,
+ }
+
+ return result
+
+ async def test_violation(self, content: str, source: Optional[str] = None):
+ verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
+
+ # Construct a request
+ request = self.analyze_text_options(
+ text=content,
+ output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
+ )
+
+ # Analyze text
+ try:
+ response = await self.client.analyze_text(request)
+ except self.azure_http_error:
+ verbose_proxy_logger.debug(
+ "Error in Azure Content-Safety: %s", traceback.format_exc()
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ raise
+
+ result = self._compute_result(response)
+ verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)
+
+ for key, value in result.items():
+ if value["filtered"]:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Violated content safety policy",
+ "source": source,
+ "category": key,
+ "severity": value["severity"],
+ },
+ )
+
+ async def async_pre_call_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: str, # "completion", "embeddings", "image_generation", "moderation"
+ ):
+ verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
+ try:
+ if call_type == "completion" and "messages" in data:
+ for m in data["messages"]:
+ if "content" in m and isinstance(m["content"], str):
+ await self.test_violation(content=m["content"], source="input")
+
+ except HTTPException as e:
+ raise e
+ except Exception as e:
+ verbose_proxy_logger.error(
+ "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occurred - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+
+ async def async_post_call_success_hook(
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ response,
+ ):
+ verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
+ if isinstance(response, litellm.ModelResponse) and isinstance(
+ response.choices[0], litellm.utils.Choices
+ ):
+ await self.test_violation(
+ content=response.choices[0].message.content or "", source="output"
+ )
+
+ # async def async_post_call_streaming_hook(
+ # self,
+ # user_api_key_dict: UserAPIKeyAuth,
+ # response: str,
+ # ):
+ # verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
+ # await self.test_violation(content=response, source="output")
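Not part of the diff above: a minimal usage sketch for this hook. It assumes azure-ai-contentsafety is installed and uses placeholder endpoint/key values; per _compute_result(), a category is flagged once its severity meets or exceeds its threshold (default 4 for every category), and test_violation() then raises an HTTPException with status 400.

import asyncio

from azure.ai.contentsafety.models import TextCategory

from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety

# Hypothetical endpoint/key; thresholds override HATE only, the other categories default to 4.
hook = _PROXY_AzureContentSafety(
    endpoint="https://my-resource.cognitiveservices.azure.com/",
    api_key="my-content-safety-key",
    thresholds={TextCategory.HATE: 2},
)

# Raises fastapi.HTTPException(status_code=400, ...) if any category is at or above its threshold.
asyncio.run(hook.test_violation(content="some user input", source="input"))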
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/batch_redis_get.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/batch_redis_get.py
new file mode 100644
index 00000000..c608317f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/batch_redis_get.py
@@ -0,0 +1,149 @@
+# What this does:
+## Gets a key's Redis cache and stores it in memory for 1 minute.
+## This reduces the number of Redis GET requests the proxy makes during high traffic.
+### [BETA] This is in beta and might change.
+
+import traceback
+from typing import Literal, Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+class _PROXY_BatchRedisRequests(CustomLogger):
+ # Class variables or attributes
+ in_memory_cache: Optional[InMemoryCache] = None
+
+ def __init__(self):
+ if litellm.cache is not None:
+ litellm.cache.async_get_cache = (
+ self.async_get_cache
+ ) # map the litellm 'get_cache' function to our custom function
+
+ def print_verbose(
+ self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
+ ):
+ if debug_level == "DEBUG":
+ verbose_proxy_logger.debug(print_statement)
+ elif debug_level == "INFO":
+ verbose_proxy_logger.debug(print_statement)
+ if litellm.set_verbose is True:
+ print(print_statement) # noqa
+
+ async def async_pre_call_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: str,
+ ):
+ try:
+ """
+ Get the user key
+
+ Check if a key starting with `litellm:<api_key>:<call_type>:` exists in-memory
+
+ If no, then get relevant cache from redis
+ """
+ api_key = user_api_key_dict.api_key
+
+ cache_key_name = f"litellm:{api_key}:{call_type}"
+ self.in_memory_cache = cache.in_memory_cache
+
+ key_value_dict = {}
+ in_memory_cache_exists = False
+ for key in cache.in_memory_cache.cache_dict.keys():
+ if isinstance(key, str) and key.startswith(cache_key_name):
+ in_memory_cache_exists = True
+
+ if in_memory_cache_exists is False and litellm.cache is not None:
+ """
+ - Check if `litellm.Cache` is redis
+ - Get the relevant values
+ """
+ if litellm.cache.type is not None and isinstance(
+ litellm.cache.cache, RedisCache
+ ):
+ # Initialize an empty list to store the keys
+ keys = []
+ self.print_verbose(f"cache_key_name: {cache_key_name}")
+ # Use the SCAN iterator to fetch keys matching the pattern
+ keys = await litellm.cache.cache.async_scan_iter(
+ pattern=cache_key_name, count=100
+ )
+ # If you need the truly "last" based on time or another criteria,
+ # ensure your key naming or storage strategy allows this determination
+ # Here you would sort or filter the keys as needed based on your strategy
+ self.print_verbose(f"redis keys: {keys}")
+ if len(keys) > 0:
+ key_value_dict = (
+ await litellm.cache.cache.async_batch_get_cache(
+ key_list=keys
+ )
+ )
+
+ ## Add to cache
+ if len(key_value_dict.items()) > 0:
+ await cache.in_memory_cache.async_set_cache_pipeline(
+ cache_list=list(key_value_dict.items()), ttl=60
+ )
+ ## Set cache namespace if it's a miss
+ data["metadata"]["redis_namespace"] = cache_key_name
+ except HTTPException as e:
+ raise e
+ except Exception as e:
+ verbose_proxy_logger.error(
+ "litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occurred - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+
+ async def async_get_cache(self, *args, **kwargs):
+ """
+ - Check if the cache key is in-memory
+
+ - Else:
+ - add missing cache key from REDIS
+ - update in-memory cache
+ - return redis cache request
+ """
+ try: # never block execution
+ cache_key: Optional[str] = None
+ if "cache_key" in kwargs:
+ cache_key = kwargs["cache_key"]
+ elif litellm.cache is not None:
+ cache_key = litellm.cache.get_cache_key(
+ *args, **kwargs
+ ) # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic
+
+ if (
+ cache_key is not None
+ and self.in_memory_cache is not None
+ and litellm.cache is not None
+ ):
+ cache_control_args = kwargs.get("cache", {})
+ max_age = cache_control_args.get(
+ "s-max-age", cache_control_args.get("s-maxage", float("inf"))
+ )
+ cached_result = self.in_memory_cache.get_cache(
+ cache_key, *args, **kwargs
+ )
+ if cached_result is None:
+ cached_result = await litellm.cache.cache.async_get_cache(
+ cache_key, *args, **kwargs
+ )
+ if cached_result is not None:
+ await self.in_memory_cache.async_set_cache(
+ cache_key, cached_result, ttl=60
+ )
+ return litellm.cache._get_cache_logic(
+ cached_result=cached_result, max_age=max_age
+ )
+ except Exception:
+ return None
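Not part of the diff above: a short illustration of the cache-key namespace this hook works with. The hashed key value is hypothetical; only the f-string format comes from async_pre_call_hook().

api_key = "88d3c7...hashed-virtual-key"   # hypothetical hashed token
call_type = "completion"

# Pattern scanned for in Redis, and the value set as data["metadata"]["redis_namespace"] on a miss.
cache_key_name = f"litellm:{api_key}:{call_type}"
print(cache_key_name)   # litellm:88d3c7...hashed-virtual-key:completion

# Matching Redis entries are batch-fetched and pipelined into the in-memory cache with ttl=60,
# so repeated lookups within the next minute skip Redis entirely.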
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/cache_control_check.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/cache_control_check.py
new file mode 100644
index 00000000..6e3fbf84
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/cache_control_check.py
@@ -0,0 +1,58 @@
+# What this does:
+## Checks if a key is allowed to use the cache controls passed in to the completion() call
+
+
+from fastapi import HTTPException
+
+from litellm import verbose_logger
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+class _PROXY_CacheControlCheck(CustomLogger):
+ # Class variables or attributes
+ def __init__(self):
+ pass
+
+ async def async_pre_call_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: str,
+ ):
+ try:
+ verbose_proxy_logger.debug("Inside Cache Control Check Pre-Call Hook")
+ allowed_cache_controls = user_api_key_dict.allowed_cache_controls
+
+ if data.get("cache", None) is None:
+ return
+
+ cache_args = data.get("cache", None)
+ if isinstance(cache_args, dict):
+ for k, v in cache_args.items():
+ if (
+ (allowed_cache_controls is not None)
+ and (isinstance(allowed_cache_controls, list))
+ and (
+ len(allowed_cache_controls) > 0
+ ) # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
+ and k not in allowed_cache_controls
+ ):
+ raise HTTPException(
+ status_code=403,
+ detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.",
+ )
+ else: # invalid cache
+ return
+
+ except HTTPException as e:
+ raise e
+ except Exception as e:
+ verbose_logger.exception(
+ "litellm.proxy.hooks.cache_control_check.py::async_pre_call_hook(): Exception occurred - {}".format(
+ str(e)
+ )
+ )
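Not part of the diff above: a sketch of the 403 path, assuming UserAPIKeyAuth and DualCache can be constructed directly with the keyword arguments shown (the field names are taken from the hook code itself).

import asyncio

from fastapi import HTTPException

from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck


async def demo():
    hook = _PROXY_CacheControlCheck()
    # Hypothetical key that is only allowed to set "s-maxage" as a cache control.
    key = UserAPIKeyAuth(api_key="sk-1234", allowed_cache_controls=["s-maxage"])
    try:
        await hook.async_pre_call_hook(
            user_api_key_dict=key,
            cache=DualCache(),
            data={"cache": {"no-cache": True}},  # "no-cache" is not in the allow-list
            call_type="completion",
        )
    except HTTPException as e:
        print(e.status_code, e.detail)  # 403, "Not allowed to set no-cache as a cache control. ..."


asyncio.run(demo())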
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/dynamic_rate_limiter.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/dynamic_rate_limiter.py
new file mode 100644
index 00000000..15a9bc1b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/dynamic_rate_limiter.py
@@ -0,0 +1,298 @@
+# What is this?
+## Allocates dynamic tpm/rpm quota for a project based on current traffic
+## Tracks the number of active projects per minute
+
+import asyncio
+import os
+from typing import List, Literal, Optional, Tuple, Union
+
+from fastapi import HTTPException
+
+import litellm
+from litellm import ModelResponse, Router
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.types.router import ModelGroupInfo
+from litellm.utils import get_utc_datetime
+
+
+class DynamicRateLimiterCache:
+ """
+ Thin wrapper on DualCache for this file.
+
+ Track number of active projects calling a model.
+ """
+
+ def __init__(self, cache: DualCache) -> None:
+ self.cache = cache
+ self.ttl = 60 # 1 min ttl
+
+ async def async_get_cache(self, model: str) -> Optional[int]:
+ dt = get_utc_datetime()
+ current_minute = dt.strftime("%H-%M")
+ key_name = "{}:{}".format(current_minute, model)
+ _response = await self.cache.async_get_cache(key=key_name)
+ response: Optional[int] = None
+ if _response is not None:
+ response = len(_response)
+ return response
+
+ async def async_set_cache_sadd(self, model: str, value: List):
+ """
+ Add value to set.
+
+ Parameters:
+ - model: str, the name of the model group
+ - value: str, the team id
+
+ Returns:
+ - None
+
+ Raises:
+ - Exception, if unable to connect to cache client (if redis caching enabled)
+ """
+ try:
+ dt = get_utc_datetime()
+ current_minute = dt.strftime("%H-%M")
+
+ key_name = "{}:{}".format(current_minute, model)
+ await self.cache.async_set_cache_sadd(
+ key=key_name, value=value, ttl=self.ttl
+ )
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occurred - {}".format(
+ str(e)
+ )
+ )
+ raise e
+
+
+class _PROXY_DynamicRateLimitHandler(CustomLogger):
+
+ # Class variables or attributes
+ def __init__(self, internal_usage_cache: DualCache):
+ self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache)
+
+ def update_variables(self, llm_router: Router):
+ self.llm_router = llm_router
+
+ async def check_available_usage(
+ self, model: str, priority: Optional[str] = None
+ ) -> Tuple[
+ Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]
+ ]:
+ """
+ For a given model, get its available tpm
+
+ Params:
+ - model: str, the name of the model in the router model_list
+ - priority: Optional[str], the priority for the request.
+
+ Returns:
+ - Tuple[available_tpm, available_rpm, remaining_model_tpm, remaining_model_rpm, active_projects]
+ - available_tpm: int or null - always 0 or positive.
+ - available_rpm: int or null - always 0 or positive.
+ - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
+ - remaining_model_rpm: int or null. If available rpm is int, then this will be too.
+ - active_projects: int or null
+ """
+ try:
+ weight: float = 1
+ if (
+ litellm.priority_reservation is None
+ or priority not in litellm.priority_reservation
+ ):
+ verbose_proxy_logger.error(
+ "Priority Reservation not set. priority={}, but litellm.priority_reservation is {}.".format(
+ priority, litellm.priority_reservation
+ )
+ )
+ elif priority is not None and litellm.priority_reservation is not None:
+ if os.getenv("LITELLM_LICENSE", None) is None:
+ verbose_proxy_logger.error(
+ "PREMIUM FEATURE: Reserving tpm/rpm by priority is a premium feature. Please add a 'LITELLM_LICENSE' to your .env to enable this.\nGet a license: https://docs.litellm.ai/docs/proxy/enterprise."
+ )
+ else:
+ weight = litellm.priority_reservation[priority]
+
+ active_projects = await self.internal_usage_cache.async_get_cache(
+ model=model
+ )
+ current_model_tpm, current_model_rpm = (
+ await self.llm_router.get_model_group_usage(model_group=model)
+ )
+ model_group_info: Optional[ModelGroupInfo] = (
+ self.llm_router.get_model_group_info(model_group=model)
+ )
+ total_model_tpm: Optional[int] = None
+ total_model_rpm: Optional[int] = None
+ if model_group_info is not None:
+ if model_group_info.tpm is not None:
+ total_model_tpm = model_group_info.tpm
+ if model_group_info.rpm is not None:
+ total_model_rpm = model_group_info.rpm
+
+ remaining_model_tpm: Optional[int] = None
+ if total_model_tpm is not None and current_model_tpm is not None:
+ remaining_model_tpm = total_model_tpm - current_model_tpm
+ elif total_model_tpm is not None:
+ remaining_model_tpm = total_model_tpm
+
+ remaining_model_rpm: Optional[int] = None
+ if total_model_rpm is not None and current_model_rpm is not None:
+ remaining_model_rpm = total_model_rpm - current_model_rpm
+ elif total_model_rpm is not None:
+ remaining_model_rpm = total_model_rpm
+
+ available_tpm: Optional[int] = None
+
+ if remaining_model_tpm is not None:
+ if active_projects is not None:
+ available_tpm = int(remaining_model_tpm * weight / active_projects)
+ else:
+ available_tpm = int(remaining_model_tpm * weight)
+
+ if available_tpm is not None and available_tpm < 0:
+ available_tpm = 0
+
+ available_rpm: Optional[int] = None
+
+ if remaining_model_rpm is not None:
+ if active_projects is not None:
+ available_rpm = int(remaining_model_rpm * weight / active_projects)
+ else:
+ available_rpm = int(remaining_model_rpm * weight)
+
+ if available_rpm is not None and available_rpm < 0:
+ available_rpm = 0
+ return (
+ available_tpm,
+ available_rpm,
+ remaining_model_tpm,
+ remaining_model_rpm,
+ active_projects,
+ )
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "litellm.proxy.hooks.dynamic_rate_limiter.py::check_available_usage: Exception occurred - {}".format(
+ str(e)
+ )
+ )
+ return None, None, None, None, None
+
+ async def async_pre_call_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: Literal[
+ "completion",
+ "text_completion",
+ "embeddings",
+ "image_generation",
+ "moderation",
+ "audio_transcription",
+ "pass_through_endpoint",
+ "rerank",
+ ],
+ ) -> Optional[
+ Union[Exception, str, dict]
+ ]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
+ """
+ - For a model group
+ - Check if tpm/rpm available
+ - Raise RateLimitError if no tpm/rpm available
+ """
+ if "model" in data:
+ key_priority: Optional[str] = user_api_key_dict.metadata.get(
+ "priority", None
+ )
+ available_tpm, available_rpm, model_tpm, model_rpm, active_projects = (
+ await self.check_available_usage(
+ model=data["model"], priority=key_priority
+ )
+ )
+ ### CHECK TPM ###
+ if available_tpm is not None and available_tpm == 0:
+ raise HTTPException(
+ status_code=429,
+ detail={
+ "error": "Key={} over available TPM={}. Model TPM={}, Active keys={}".format(
+ user_api_key_dict.api_key,
+ available_tpm,
+ model_tpm,
+ active_projects,
+ )
+ },
+ )
+ ### CHECK RPM ###
+ elif available_rpm is not None and available_rpm == 0:
+ raise HTTPException(
+ status_code=429,
+ detail={
+ "error": "Key={} over available RPM={}. Model RPM={}, Active keys={}".format(
+ user_api_key_dict.api_key,
+ available_rpm,
+ model_rpm,
+ active_projects,
+ )
+ },
+ )
+ elif available_rpm is not None or available_tpm is not None:
+ ## UPDATE CACHE WITH ACTIVE PROJECT
+ asyncio.create_task(
+ self.internal_usage_cache.async_set_cache_sadd( # this is a set
+ model=data["model"], # type: ignore
+ value=[user_api_key_dict.token or "default_key"],
+ )
+ )
+ return None
+
+ async def async_post_call_success_hook(
+ self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
+ ):
+ try:
+ if isinstance(response, ModelResponse):
+ model_info = self.llm_router.get_model_info(
+ id=response._hidden_params["model_id"]
+ )
+ assert (
+ model_info is not None
+ ), "Model info for model with id={} is None".format(
+ response._hidden_params["model_id"]
+ )
+ key_priority: Optional[str] = user_api_key_dict.metadata.get(
+ "priority", None
+ )
+ available_tpm, available_rpm, model_tpm, model_rpm, active_projects = (
+ await self.check_available_usage(
+ model=model_info["model_name"], priority=key_priority
+ )
+ )
+ response._hidden_params["additional_headers"] = (
+ { # Add additional response headers - easier debugging
+ "x-litellm-model_group": model_info["model_name"],
+ "x-ratelimit-remaining-litellm-project-tokens": available_tpm,
+ "x-ratelimit-remaining-litellm-project-requests": available_rpm,
+ "x-ratelimit-remaining-model-tokens": model_tpm,
+ "x-ratelimit-remaining-model-requests": model_rpm,
+ "x-ratelimit-current-active-projects": active_projects,
+ }
+ )
+
+ return response
+ return await super().async_post_call_success_hook(
+ data=data,
+ user_api_key_dict=user_api_key_dict,
+ response=response,
+ )
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occurred - {}".format(
+ str(e)
+ )
+ )
+ return response
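Not part of the diff above: a worked example of the quota split performed in check_available_usage(); all numbers and the "prod" priority name are illustrative.

total_model_tpm = 100_000      # ModelGroupInfo.tpm for the model group
current_model_tpm = 40_000     # usage reported by Router.get_model_group_usage()
remaining_model_tpm = total_model_tpm - current_model_tpm   # 60_000

weight = 0.75                  # e.g. litellm.priority_reservation["prod"] (premium feature)
active_projects = 3            # distinct keys that called this model group this minute

available_tpm = max(int(remaining_model_tpm * weight / active_projects), 0)
print(available_tpm)           # 15000 tokens available to this project for the current minute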
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/example_presidio_ad_hoc_recognizer.json b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/example_presidio_ad_hoc_recognizer.json
new file mode 100644
index 00000000..6a94d8de
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/example_presidio_ad_hoc_recognizer.json
@@ -0,0 +1,28 @@
+[
+ {
+ "name": "Zip code Recognizer",
+ "supported_language": "en",
+ "patterns": [
+ {
+ "name": "zip code (weak)",
+ "regex": "(\\b\\d{5}(?:\\-\\d{4})?\\b)",
+ "score": 0.01
+ }
+ ],
+ "context": ["zip", "code"],
+ "supported_entity": "ZIP"
+ },
+ {
+ "name": "Swiss AHV Number Recognizer",
+ "supported_language": "en",
+ "patterns": [
+ {
+ "name": "AHV number (strong)",
+ "regex": "(756\\.\\d{4}\\.\\d{4}\\.\\d{2})|(756\\d{10})",
+ "score": 0.95
+ }
+ ],
+ "context": ["AHV", "social security", "Swiss"],
+ "supported_entity": "AHV_NUMBER"
+ }
+]
\ No newline at end of file
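Not part of the diff above: a quick sanity check of the two example recognizer patterns using Python's re module; the sample strings are made up.

import re

zip_pattern = r"(\b\d{5}(?:\-\d{4})?\b)"                      # "zip code (weak)", score 0.01
ahv_pattern = r"(756\.\d{4}\.\d{4}\.\d{2})|(756\d{10})"       # "AHV number (strong)", score 0.95

print(bool(re.search(zip_pattern, "Ship to 94105-1234")))     # True
print(bool(re.search(ahv_pattern, "AHV: 756.1234.5678.90")))  # True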
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py
new file mode 100644
index 00000000..2030cb2a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py
@@ -0,0 +1,324 @@
+import asyncio
+import json
+import uuid
+from datetime import datetime, timezone
+from typing import Any, List, Optional
+
+from fastapi import status
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import (
+ GenerateKeyRequest,
+ GenerateKeyResponse,
+ KeyRequest,
+ LiteLLM_AuditLogs,
+ LiteLLM_VerificationToken,
+ LitellmTableNames,
+ ProxyErrorTypes,
+ ProxyException,
+ RegenerateKeyRequest,
+ UpdateKeyRequest,
+ UserAPIKeyAuth,
+ WebhookEvent,
+)
+
+# NOTE: This is the prefix for all virtual keys stored in AWS Secrets Manager
+LITELLM_PREFIX_STORED_VIRTUAL_KEYS = "litellm/"
+
+
+class KeyManagementEventHooks:
+
+ @staticmethod
+ async def async_key_generated_hook(
+ data: GenerateKeyRequest,
+ response: GenerateKeyResponse,
+ user_api_key_dict: UserAPIKeyAuth,
+ litellm_changed_by: Optional[str] = None,
+ ):
+ """
+ Hook that runs after a successful /key/generate request
+
+ Handles the following:
+ - Sending Email with Key Details
+ - Storing Audit Logs for key generation
+ - Storing Generated Key in DB
+ """
+ from litellm.proxy.management_helpers.audit_logs import (
+ create_audit_log_for_update,
+ )
+ from litellm.proxy.proxy_server import litellm_proxy_admin_name
+
+ if data.send_invite_email is True:
+ await KeyManagementEventHooks._send_key_created_email(
+ response.model_dump(exclude_none=True)
+ )
+
+ # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+ if litellm.store_audit_logs is True:
+ _updated_values = response.model_dump_json(exclude_none=True)
+ asyncio.create_task(
+ create_audit_log_for_update(
+ request_data=LiteLLM_AuditLogs(
+ id=str(uuid.uuid4()),
+ updated_at=datetime.now(timezone.utc),
+ changed_by=litellm_changed_by
+ or user_api_key_dict.user_id
+ or litellm_proxy_admin_name,
+ changed_by_api_key=user_api_key_dict.api_key,
+ table_name=LitellmTableNames.KEY_TABLE_NAME,
+ object_id=response.token_id or "",
+ action="created",
+ updated_values=_updated_values,
+ before_value=None,
+ )
+ )
+ )
+ # store the generated key in the secret manager
+ await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
+ secret_name=data.key_alias or f"virtual-key-{response.token_id}",
+ secret_token=response.key,
+ )
+
+ @staticmethod
+ async def async_key_updated_hook(
+ data: UpdateKeyRequest,
+ existing_key_row: Any,
+ response: Any,
+ user_api_key_dict: UserAPIKeyAuth,
+ litellm_changed_by: Optional[str] = None,
+ ):
+ """
+ Post /key/update processing hook
+
+ Handles the following:
+ - Storing Audit Logs for key update
+ """
+ from litellm.proxy.management_helpers.audit_logs import (
+ create_audit_log_for_update,
+ )
+ from litellm.proxy.proxy_server import litellm_proxy_admin_name
+
+ # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+ if litellm.store_audit_logs is True:
+ _updated_values = json.dumps(data.json(exclude_none=True), default=str)
+
+ _before_value = existing_key_row.json(exclude_none=True)
+ _before_value = json.dumps(_before_value, default=str)
+
+ asyncio.create_task(
+ create_audit_log_for_update(
+ request_data=LiteLLM_AuditLogs(
+ id=str(uuid.uuid4()),
+ updated_at=datetime.now(timezone.utc),
+ changed_by=litellm_changed_by
+ or user_api_key_dict.user_id
+ or litellm_proxy_admin_name,
+ changed_by_api_key=user_api_key_dict.api_key,
+ table_name=LitellmTableNames.KEY_TABLE_NAME,
+ object_id=data.key,
+ action="updated",
+ updated_values=_updated_values,
+ before_value=_before_value,
+ )
+ )
+ )
+
+ @staticmethod
+ async def async_key_rotated_hook(
+ data: Optional[RegenerateKeyRequest],
+ existing_key_row: Any,
+ response: GenerateKeyResponse,
+ user_api_key_dict: UserAPIKeyAuth,
+ litellm_changed_by: Optional[str] = None,
+ ):
+ # store the generated key in the secret manager
+ if data is not None and response.token_id is not None:
+ initial_secret_name = (
+ existing_key_row.key_alias or f"virtual-key-{existing_key_row.token}"
+ )
+ await KeyManagementEventHooks._rotate_virtual_key_in_secret_manager(
+ current_secret_name=initial_secret_name,
+ new_secret_name=data.key_alias or f"virtual-key-{response.token_id}",
+ new_secret_value=response.key,
+ )
+
+ @staticmethod
+ async def async_key_deleted_hook(
+ data: KeyRequest,
+ keys_being_deleted: List[LiteLLM_VerificationToken],
+ response: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ litellm_changed_by: Optional[str] = None,
+ ):
+ """
+ Post /key/delete processing hook
+
+ Handles the following:
+ - Storing Audit Logs for key deletion
+ """
+ from litellm.proxy.management_helpers.audit_logs import (
+ create_audit_log_for_update,
+ )
+ from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
+
+ # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+ # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
+ if litellm.store_audit_logs is True and data.keys is not None:
+ # make an audit log for each key deleted
+ for key in data.keys:
+ key_row = await prisma_client.get_data( # type: ignore
+ token=key, table_name="key", query_type="find_unique"
+ )
+
+ if key_row is None:
+ raise ProxyException(
+ message=f"Key {key} not found",
+ type=ProxyErrorTypes.bad_request_error,
+ param="key",
+ code=status.HTTP_404_NOT_FOUND,
+ )
+
+ key_row = key_row.json(exclude_none=True)
+ _key_row = json.dumps(key_row, default=str)
+
+ asyncio.create_task(
+ create_audit_log_for_update(
+ request_data=LiteLLM_AuditLogs(
+ id=str(uuid.uuid4()),
+ updated_at=datetime.now(timezone.utc),
+ changed_by=litellm_changed_by
+ or user_api_key_dict.user_id
+ or litellm_proxy_admin_name,
+ changed_by_api_key=user_api_key_dict.api_key,
+ table_name=LitellmTableNames.KEY_TABLE_NAME,
+ object_id=key,
+ action="deleted",
+ updated_values="{}",
+ before_value=_key_row,
+ )
+ )
+ )
+ # delete the keys from the secret manager
+ await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
+ keys_being_deleted=keys_being_deleted
+ )
+ pass
+
+ @staticmethod
+ async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
+ """
+ Store a virtual key in the secret manager
+
+ Args:
+ secret_name: Name of the virtual key
+ secret_token: Value of the virtual key (example: sk-1234)
+ """
+ if litellm._key_management_settings is not None:
+ if litellm._key_management_settings.store_virtual_keys is True:
+ from litellm.secret_managers.base_secret_manager import (
+ BaseSecretManager,
+ )
+
+ # store the key in the secret manager
+ if isinstance(litellm.secret_manager_client, BaseSecretManager):
+ await litellm.secret_manager_client.async_write_secret(
+ secret_name=KeyManagementEventHooks._get_secret_name(
+ secret_name
+ ),
+ secret_value=secret_token,
+ )
+
+ @staticmethod
+ async def _rotate_virtual_key_in_secret_manager(
+ current_secret_name: str, new_secret_name: str, new_secret_value: str
+ ):
+ """
+ Update a virtual key in the secret manager
+
+ Args:
+ current_secret_name: Name of the existing virtual key secret
+ new_secret_name: New name for the virtual key secret
+ new_secret_value: New value of the virtual key (example: sk-1234)
+ """
+ if litellm._key_management_settings is not None:
+ if litellm._key_management_settings.store_virtual_keys is True:
+ from litellm.secret_managers.base_secret_manager import (
+ BaseSecretManager,
+ )
+
+ # store the key in the secret manager
+ if isinstance(litellm.secret_manager_client, BaseSecretManager):
+ await litellm.secret_manager_client.async_rotate_secret(
+ current_secret_name=KeyManagementEventHooks._get_secret_name(
+ current_secret_name
+ ),
+ new_secret_name=KeyManagementEventHooks._get_secret_name(
+ new_secret_name
+ ),
+ new_secret_value=new_secret_value,
+ )
+
+ @staticmethod
+ def _get_secret_name(secret_name: str) -> str:
+ if litellm._key_management_settings.prefix_for_stored_virtual_keys.endswith(
+ "/"
+ ):
+ return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{secret_name}"
+ else:
+ return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}/{secret_name}"
+
+ @staticmethod
+ async def _delete_virtual_keys_from_secret_manager(
+ keys_being_deleted: List[LiteLLM_VerificationToken],
+ ):
+ """
+ Deletes virtual keys from the secret manager
+
+ Args:
+ keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
+ """
+ if litellm._key_management_settings is not None:
+ if litellm._key_management_settings.store_virtual_keys is True:
+ from litellm.secret_managers.base_secret_manager import (
+ BaseSecretManager,
+ )
+
+ if isinstance(litellm.secret_manager_client, BaseSecretManager):
+ for key in keys_being_deleted:
+ if key.key_alias is not None:
+ await litellm.secret_manager_client.async_delete_secret(
+ secret_name=KeyManagementEventHooks._get_secret_name(
+ key.key_alias
+ )
+ )
+ else:
+ verbose_proxy_logger.warning(
+ f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
+ )
+
+ @staticmethod
+ async def _send_key_created_email(response: dict):
+ from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
+
+ if "email" not in general_settings.get("alerting", []):
+ raise ValueError(
+ "Email alerting is not set up on config.yaml. Please set `alerting=['email']`. \nDocs: https://docs.litellm.ai/docs/proxy/email"
+ )
+ event = WebhookEvent(
+ event="key_created",
+ event_group="key",
+ event_message="API Key Created",
+ token=response.get("token", ""),
+ spend=response.get("spend", 0.0),
+ max_budget=response.get("max_budget", 0.0),
+ user_id=response.get("user_id", None),
+ team_id=response.get("team_id", "Default Team"),
+ key_alias=response.get("key_alias", None),
+ )
+
+ # If user configured email alerting - send an Email letting their end-user know the key was created
+ asyncio.create_task(
+ proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
+ webhook_event=event,
+ )
+ )
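Not part of the diff above: an illustration of the secret-name prefixing done by _get_secret_name(), assuming the configured prefix matches the LITELLM_PREFIX_STORED_VIRTUAL_KEYS constant defined at the top of this file; the alias is hypothetical.

prefix = "litellm/"                 # litellm._key_management_settings.prefix_for_stored_virtual_keys (assumed)
secret_name = "virtual-key-3f9c2b"  # hypothetical key alias / token id

full_name = f"{prefix}{secret_name}" if prefix.endswith("/") else f"{prefix}/{secret_name}"
print(full_name)                    # litellm/virtual-key-3f9c2b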
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/max_budget_limiter.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/max_budget_limiter.py
new file mode 100644
index 00000000..4b59f603
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/max_budget_limiter.py
@@ -0,0 +1,49 @@
+from fastapi import HTTPException
+
+from litellm import verbose_logger
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+class _PROXY_MaxBudgetLimiter(CustomLogger):
+ # Class variables or attributes
+ def __init__(self):
+ pass
+
+ async def async_pre_call_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: str,
+ ):
+ try:
+ verbose_proxy_logger.debug("Inside Max Budget Limiter Pre-Call Hook")
+ cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
+ user_row = await cache.async_get_cache(
+ cache_key, parent_otel_span=user_api_key_dict.parent_otel_span
+ )
+ if user_row is None: # value not yet cached
+ return
+ max_budget = user_row["max_budget"]
+ curr_spend = user_row["spend"]
+
+ if max_budget is None:
+ return
+
+ if curr_spend is None:
+ return
+
+ # CHECK IF REQUEST ALLOWED
+ if curr_spend >= max_budget:
+ raise HTTPException(status_code=429, detail="Max budget limit reached.")
+ except HTTPException as e:
+ raise e
+ except Exception as e:
+ verbose_logger.exception(
+ "litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occurred - {}".format(
+ str(e)
+ )
+ )
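Not part of the diff above: a sketch of the budget check, assuming another part of the proxy has already cached the user row under the `{user_id}_user_api_key_user_id` key; the figures are illustrative.

user_id = "user-42"                             # hypothetical internal user id
cache_key = f"{user_id}_user_api_key_user_id"   # key read by async_pre_call_hook

user_row = {"max_budget": 10.0, "spend": 12.5}  # illustrative cached values

if user_row["max_budget"] is not None and user_row["spend"] is not None:
    if user_row["spend"] >= user_row["max_budget"]:
        print("429: Max budget limit reached.")  # the hook raises HTTPException(429) here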
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/model_max_budget_limiter.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/model_max_budget_limiter.py
new file mode 100644
index 00000000..ac02c915
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/model_max_budget_limiter.py
@@ -0,0 +1,192 @@
+import json
+from typing import List, Optional
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import Span
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import (
+ BudgetConfig,
+ GenericBudgetConfigType,
+ StandardLoggingPayload,
+)
+
+VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"
+
+
+class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
+ """
+ Handles budgets for model + virtual key
+
+ Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
+ """
+
+ def __init__(self, dual_cache: DualCache):
+ self.dual_cache = dual_cache
+ self.redis_increment_operation_queue = []
+
+ async def is_key_within_model_budget(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ model: str,
+ ) -> bool:
+ """
+ Check if the user_api_key_dict is within the model budget
+
+ Raises:
+ BudgetExceededError: If the user_api_key_dict has exceeded the model budget
+ """
+ _model_max_budget = user_api_key_dict.model_max_budget
+ internal_model_max_budget: GenericBudgetConfigType = {}
+
+ for _model, _budget_info in _model_max_budget.items():
+ internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
+
+ verbose_proxy_logger.debug(
+ "internal_model_max_budget %s",
+ json.dumps(internal_model_max_budget, indent=4, default=str),
+ )
+
+ # check if current model is in internal_model_max_budget
+ _current_model_budget_info = self._get_request_model_budget_config(
+ model=model, internal_model_max_budget=internal_model_max_budget
+ )
+ if _current_model_budget_info is None:
+ verbose_proxy_logger.debug(
+ f"Model {model} not found in internal_model_max_budget"
+ )
+ return True
+
+ # check if current model is within budget
+ if (
+ _current_model_budget_info.max_budget
+ and _current_model_budget_info.max_budget > 0
+ ):
+ _current_spend = await self._get_virtual_key_spend_for_model(
+ user_api_key_hash=user_api_key_dict.token,
+ model=model,
+ key_budget_config=_current_model_budget_info,
+ )
+ if (
+ _current_spend is not None
+ and _current_model_budget_info.max_budget is not None
+ and _current_spend > _current_model_budget_info.max_budget
+ ):
+ raise litellm.BudgetExceededError(
+ message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}",
+ current_cost=_current_spend,
+ max_budget=_current_model_budget_info.max_budget,
+ )
+
+ return True
+
+ async def _get_virtual_key_spend_for_model(
+ self,
+ user_api_key_hash: Optional[str],
+ model: str,
+ key_budget_config: BudgetConfig,
+ ) -> Optional[float]:
+ """
+ Get the current spend for a virtual key for a model
+
+ Lookup model in this order:
+ 1. model: directly look up `model`
+ 2. If 1 does not exist, check if the model was passed as {custom_llm_provider}/model
+ """
+
+ # 1. model: directly look up `model`
+ virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.budget_duration}"
+ _current_spend = await self.dual_cache.async_get_cache(
+ key=virtual_key_model_spend_cache_key,
+ )
+
+ if _current_spend is None:
+ # 2. If 1 does not exist, check if the model was passed as {custom_llm_provider}/model
+ # if "/" in model, remove first part before "/" - eg. openai/o1-preview -> o1-preview
+ virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.budget_duration}"
+ _current_spend = await self.dual_cache.async_get_cache(
+ key=virtual_key_model_spend_cache_key,
+ )
+ return _current_spend
+
+ def _get_request_model_budget_config(
+ self, model: str, internal_model_max_budget: GenericBudgetConfigType
+ ) -> Optional[BudgetConfig]:
+ """
+ Get the budget config for the request model
+
+ 1. Check if `model` is in `internal_model_max_budget`
+ 2. If not, check if `model` without custom llm provider is in `internal_model_max_budget`
+ """
+ return internal_model_max_budget.get(
+ model, None
+ ) or internal_model_max_budget.get(
+ self._get_model_without_custom_llm_provider(model), None
+ )
+
+ def _get_model_without_custom_llm_provider(self, model: str) -> str:
+ if "/" in model:
+ return model.split("/")[-1]
+ return model
+
+ async def async_filter_deployments(
+ self,
+ model: str,
+ healthy_deployments: List,
+ messages: Optional[List[AllMessageValues]],
+ request_kwargs: Optional[dict] = None,
+ parent_otel_span: Optional[Span] = None, # type: ignore
+ ) -> List[dict]:
+ return healthy_deployments
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ """
+ Track spend for virtual key + model in DualCache
+
+ Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
+ """
+ verbose_proxy_logger.debug("in RouterBudgetLimiting.async_log_success_event")
+ standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+ if standard_logging_payload is None:
+ raise ValueError("standard_logging_payload is required")
+
+ _litellm_params: dict = kwargs.get("litellm_params", {}) or {}
+ _metadata: dict = _litellm_params.get("metadata", {}) or {}
+ user_api_key_model_max_budget: Optional[dict] = _metadata.get(
+ "user_api_key_model_max_budget", None
+ )
+ if (
+ user_api_key_model_max_budget is None
+ or len(user_api_key_model_max_budget) == 0
+ ):
+ verbose_proxy_logger.debug(
+ "Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget is None or empty. `user_api_key_model_max_budget`=%s",
+ user_api_key_model_max_budget,
+ )
+ return
+ response_cost: float = standard_logging_payload.get("response_cost", 0)
+ model = standard_logging_payload.get("model")
+
+ virtual_key = standard_logging_payload.get("metadata").get("user_api_key_hash")
+ model = standard_logging_payload.get("model")
+ if virtual_key is not None:
+ budget_config = BudgetConfig(time_period="1d", budget_limit=0.1)
+ virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_config.budget_duration}"
+ virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
+ await self._increment_spend_for_key(
+ budget_config=budget_config,
+ spend_key=virtual_spend_key,
+ start_time_key=virtual_start_time_key,
+ response_cost=response_cost,
+ )
+ verbose_proxy_logger.debug(
+ "current state of in memory cache %s",
+ json.dumps(
+ self.dual_cache.in_memory_cache.cache_dict, indent=4, default=str
+ ),
+ )
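Not part of the diff above: an illustration of the two spend cache keys tried by _get_virtual_key_spend_for_model(); the hash, model name, and budget duration are hypothetical.

VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"
user_api_key_hash = "88d3c7..."   # hypothetical hashed virtual key
budget_duration = "1d"

# 1. direct lookup on the requested model name
key_1 = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:openai/gpt-4o:{budget_duration}"
# 2. fallback: strip the custom_llm_provider prefix ("openai/gpt-4o" -> "gpt-4o")
key_2 = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:gpt-4o:{budget_duration}"
print(key_1)
print(key_2)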
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/parallel_request_limiter.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/parallel_request_limiter.py
new file mode 100644
index 00000000..06f3b6af
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/parallel_request_limiter.py
@@ -0,0 +1,866 @@
+import asyncio
+import sys
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union
+
+from fastapi import HTTPException
+from pydantic import BaseModel
+
+import litellm
+from litellm import DualCache, ModelResponse
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
+from litellm.proxy._types import CommonProxyErrors, CurrentItemRateLimit, UserAPIKeyAuth
+from litellm.proxy.auth.auth_utils import (
+ get_key_model_rpm_limit,
+ get_key_model_tpm_limit,
+)
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
+
+ Span = _Span
+ InternalUsageCache = _InternalUsageCache
+else:
+ Span = Any
+ InternalUsageCache = Any
+
+
+class CacheObject(TypedDict):
+ current_global_requests: Optional[dict]
+ request_count_api_key: Optional[dict]
+ request_count_api_key_model: Optional[dict]
+ request_count_user_id: Optional[dict]
+ request_count_team_id: Optional[dict]
+ request_count_end_user_id: Optional[dict]
+
+
+class _PROXY_MaxParallelRequestsHandler(CustomLogger):
+ # Class variables or attributes
+ def __init__(self, internal_usage_cache: InternalUsageCache):
+ self.internal_usage_cache = internal_usage_cache
+
+ def print_verbose(self, print_statement):
+ try:
+ verbose_proxy_logger.debug(print_statement)
+ if litellm.set_verbose:
+ print(print_statement) # noqa
+ except Exception:
+ pass
+
+ async def check_key_in_limits(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: str,
+ max_parallel_requests: int,
+ tpm_limit: int,
+ rpm_limit: int,
+ current: Optional[dict],
+ request_count_api_key: str,
+ rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
+ values_to_update_in_cache: List[Tuple[Any, Any]],
+ ) -> dict:
+ verbose_proxy_logger.info(
+ f"Current Usage of {rate_limit_type} in this minute: {current}"
+ )
+ if current is None:
+ if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
+ # base case
+ raise self.raise_rate_limit_error(
+ additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
+ )
+ new_val = {
+ "current_requests": 1,
+ "current_tpm": 0,
+ "current_rpm": 1,
+ }
+ values_to_update_in_cache.append((request_count_api_key, new_val))
+ elif (
+ int(current["current_requests"]) < max_parallel_requests
+ and current["current_tpm"] < tpm_limit
+ and current["current_rpm"] < rpm_limit
+ ):
+ # Increase count for this token
+ new_val = {
+ "current_requests": current["current_requests"] + 1,
+ "current_tpm": current["current_tpm"],
+ "current_rpm": current["current_rpm"] + 1,
+ }
+ values_to_update_in_cache.append((request_count_api_key, new_val))
+
+ else:
+ raise HTTPException(
+ status_code=429,
+ detail=f"LiteLLM Rate Limit Handler for rate limit type = {rate_limit_type}. {CommonProxyErrors.max_parallel_request_limit_reached.value}. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}, current max_parallel_requests: {current['current_requests']}, max_parallel_requests: {max_parallel_requests}",
+ headers={"retry-after": str(self.time_to_next_minute())},
+ )
+
+ await self.internal_usage_cache.async_batch_set_cache(
+ cache_list=values_to_update_in_cache,
+ ttl=60,
+ litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+ local_only=True,
+ )
+ return new_val
+
+ def time_to_next_minute(self) -> float:
+ # Get the current time
+ now = datetime.now()
+
+ # Calculate the next minute
+ next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
+
+ # Calculate the difference in seconds
+ seconds_to_next_minute = (next_minute - now).total_seconds()
+
+ return seconds_to_next_minute
+
+ def raise_rate_limit_error(
+ self, additional_details: Optional[str] = None
+ ) -> HTTPException:
+ """
+ Raise an HTTPException with a 429 status code and a retry-after header
+ """
+ error_message = "Max parallel request limit reached"
+ if additional_details is not None:
+ error_message = error_message + " " + additional_details
+ raise HTTPException(
+ status_code=429,
+ detail=f"Max parallel request limit reached {additional_details}",
+ headers={"retry-after": str(self.time_to_next_minute())},
+ )
+
+ async def get_all_cache_objects(
+ self,
+ current_global_requests: Optional[str],
+ request_count_api_key: Optional[str],
+ request_count_api_key_model: Optional[str],
+ request_count_user_id: Optional[str],
+ request_count_team_id: Optional[str],
+ request_count_end_user_id: Optional[str],
+ parent_otel_span: Optional[Span] = None,
+ ) -> CacheObject:
+ keys = [
+ current_global_requests,
+ request_count_api_key,
+ request_count_api_key_model,
+ request_count_user_id,
+ request_count_team_id,
+ request_count_end_user_id,
+ ]
+ results = await self.internal_usage_cache.async_batch_get_cache(
+ keys=keys,
+ parent_otel_span=parent_otel_span,
+ )
+
+ if results is None:
+ return CacheObject(
+ current_global_requests=None,
+ request_count_api_key=None,
+ request_count_api_key_model=None,
+ request_count_user_id=None,
+ request_count_team_id=None,
+ request_count_end_user_id=None,
+ )
+
+ return CacheObject(
+ current_global_requests=results[0],
+ request_count_api_key=results[1],
+ request_count_api_key_model=results[2],
+ request_count_user_id=results[3],
+ request_count_team_id=results[4],
+ request_count_end_user_id=results[5],
+ )
+
+ async def async_pre_call_hook( # noqa: PLR0915
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: str,
+ ):
+ self.print_verbose("Inside Max Parallel Request Pre-Call Hook")
+ api_key = user_api_key_dict.api_key
+ max_parallel_requests = user_api_key_dict.max_parallel_requests
+ if max_parallel_requests is None:
+ max_parallel_requests = sys.maxsize
+ if data is None:
+ data = {}
+ global_max_parallel_requests = data.get("metadata", {}).get(
+ "global_max_parallel_requests", None
+ )
+ tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
+ if tpm_limit is None:
+ tpm_limit = sys.maxsize
+ rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
+ if rpm_limit is None:
+ rpm_limit = sys.maxsize
+
+ values_to_update_in_cache: List[Tuple[Any, Any]] = (
+ []
+ ) # values that need to get updated in cache, will run a batch_set_cache after this function
+
+ # ------------
+ # Setup values
+ # ------------
+ new_val: Optional[dict] = None
+
+ if global_max_parallel_requests is not None:
+ # get value from cache
+ _key = "global_max_parallel_requests"
+ current_global_requests = await self.internal_usage_cache.async_get_cache(
+ key=_key,
+ local_only=True,
+ litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+ )
+ # check if below limit
+ if current_global_requests is None:
+ current_global_requests = 1
+ # if above -> raise error
+ if current_global_requests >= global_max_parallel_requests:
+ return self.raise_rate_limit_error(
+ additional_details=f"Hit Global Limit: Limit={global_max_parallel_requests}, current: {current_global_requests}"
+ )
+ # if below -> increment
+ else:
+ await self.internal_usage_cache.async_increment_cache(
+ key=_key,
+ value=1,
+ local_only=True,
+ litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+ )
+ _model = data.get("model", None)
+
+ current_date = datetime.now().strftime("%Y-%m-%d")
+ current_hour = datetime.now().strftime("%H")
+ current_minute = datetime.now().strftime("%M")
+ precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+ cache_objects: CacheObject = await self.get_all_cache_objects(
+ current_global_requests=(
+ "global_max_parallel_requests"
+ if global_max_parallel_requests is not None
+ else None
+ ),
+ request_count_api_key=(
+ f"{api_key}::{precise_minute}::request_count"
+ if api_key is not None
+ else None
+ ),
+ request_count_api_key_model=(
+ f"{api_key}::{_model}::{precise_minute}::request_count"
+ if api_key is not None and _model is not None
+ else None
+ ),
+ request_count_user_id=(
+ f"{user_api_key_dict.user_id}::{precise_minute}::request_count"
+ if user_api_key_dict.user_id is not None
+ else None
+ ),
+ request_count_team_id=(
+ f"{user_api_key_dict.team_id}::{precise_minute}::request_count"
+ if user_api_key_dict.team_id is not None
+ else None
+ ),
+ request_count_end_user_id=(
+ f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
+ if user_api_key_dict.end_user_id is not None
+ else None
+ ),
+ parent_otel_span=user_api_key_dict.parent_otel_span,
+ )
+ if api_key is not None:
+ request_count_api_key = f"{api_key}::{precise_minute}::request_count"
+ # CHECK IF REQUEST ALLOWED for key
+ await self.check_key_in_limits(
+ user_api_key_dict=user_api_key_dict,
+ cache=cache,
+ data=data,
+ call_type=call_type,
+ max_parallel_requests=max_parallel_requests,
+ current=cache_objects["request_count_api_key"],
+ request_count_api_key=request_count_api_key,
+ tpm_limit=tpm_limit,
+ rpm_limit=rpm_limit,
+ rate_limit_type="key",
+ values_to_update_in_cache=values_to_update_in_cache,
+ )
+
+ # Check if request under RPM/TPM per model for a given API Key
+ if (
+ get_key_model_tpm_limit(user_api_key_dict) is not None
+ or get_key_model_rpm_limit(user_api_key_dict) is not None
+ ):
+ _model = data.get("model", None)
+ request_count_api_key = (
+ f"{api_key}::{_model}::{precise_minute}::request_count"
+ )
+ _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
+ _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
+ tpm_limit_for_model = None
+ rpm_limit_for_model = None
+
+ if _model is not None:
+ if _tpm_limit_for_key_model:
+ tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
+
+ if _rpm_limit_for_key_model:
+ rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
+
+ new_val = await self.check_key_in_limits(
+ user_api_key_dict=user_api_key_dict,
+ cache=cache,
+ data=data,
+ call_type=call_type,
+ max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a model
+ current=cache_objects["request_count_api_key_model"],
+ request_count_api_key=request_count_api_key,
+ tpm_limit=tpm_limit_for_model or sys.maxsize,
+ rpm_limit=rpm_limit_for_model or sys.maxsize,
+ rate_limit_type="model_per_key",
+ values_to_update_in_cache=values_to_update_in_cache,
+ )
+ _remaining_tokens = None
+ _remaining_requests = None
+ # Add remaining tokens, requests to metadata
+ if new_val:
+ if tpm_limit_for_model is not None:
+ _remaining_tokens = tpm_limit_for_model - new_val["current_tpm"]
+ if rpm_limit_for_model is not None:
+ _remaining_requests = rpm_limit_for_model - new_val["current_rpm"]
+
+ _remaining_limits_data = {
+ f"litellm-key-remaining-tokens-{_model}": _remaining_tokens,
+ f"litellm-key-remaining-requests-{_model}": _remaining_requests,
+ }
+
+ if "metadata" not in data:
+ data["metadata"] = {}
+ data["metadata"].update(_remaining_limits_data)
+
+ # check if REQUEST ALLOWED for user_id
+ user_id = user_api_key_dict.user_id
+ if user_id is not None:
+ user_tpm_limit = user_api_key_dict.user_tpm_limit
+ user_rpm_limit = user_api_key_dict.user_rpm_limit
+ if user_tpm_limit is None:
+ user_tpm_limit = sys.maxsize
+ if user_rpm_limit is None:
+ user_rpm_limit = sys.maxsize
+
+ request_count_api_key = f"{user_id}::{precise_minute}::request_count"
+ # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+ await self.check_key_in_limits(
+ user_api_key_dict=user_api_key_dict,
+ cache=cache,
+ data=data,
+ call_type=call_type,
+ max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
+ current=cache_objects["request_count_user_id"],
+ request_count_api_key=request_count_api_key,
+ tpm_limit=user_tpm_limit,
+ rpm_limit=user_rpm_limit,
+ rate_limit_type="user",
+ values_to_update_in_cache=values_to_update_in_cache,
+ )
+
+ # TEAM RATE LIMITS
+ ## get team tpm/rpm limits
+ team_id = user_api_key_dict.team_id
+ if team_id is not None:
+ team_tpm_limit = user_api_key_dict.team_tpm_limit
+ team_rpm_limit = user_api_key_dict.team_rpm_limit
+
+ if team_tpm_limit is None:
+ team_tpm_limit = sys.maxsize
+ if team_rpm_limit is None:
+ team_rpm_limit = sys.maxsize
+
+ request_count_api_key = f"{team_id}::{precise_minute}::request_count"
+ # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+ await self.check_key_in_limits(
+ user_api_key_dict=user_api_key_dict,
+ cache=cache,
+ data=data,
+ call_type=call_type,
+ max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a team
+ current=cache_objects["request_count_team_id"],
+ request_count_api_key=request_count_api_key,
+ tpm_limit=team_tpm_limit,
+ rpm_limit=team_rpm_limit,
+ rate_limit_type="team",
+ values_to_update_in_cache=values_to_update_in_cache,
+ )
+
+ # End-User Rate Limits
+ # Only enforce if user passed `user` to /chat, /completions, /embeddings
+ if user_api_key_dict.end_user_id:
+ end_user_tpm_limit = getattr(
+ user_api_key_dict, "end_user_tpm_limit", sys.maxsize
+ )
+ end_user_rpm_limit = getattr(
+ user_api_key_dict, "end_user_rpm_limit", sys.maxsize
+ )
+
+ if end_user_tpm_limit is None:
+ end_user_tpm_limit = sys.maxsize
+ if end_user_rpm_limit is None:
+ end_user_rpm_limit = sys.maxsize
+
+ # now do the same tpm/rpm checks
+ request_count_api_key = (
+ f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
+ )
+
+ # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+ await self.check_key_in_limits(
+ user_api_key_dict=user_api_key_dict,
+ cache=cache,
+ data=data,
+ call_type=call_type,
+ max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User
+ request_count_api_key=request_count_api_key,
+ current=cache_objects["request_count_end_user_id"],
+ tpm_limit=end_user_tpm_limit,
+ rpm_limit=end_user_rpm_limit,
+ rate_limit_type="customer",
+ values_to_update_in_cache=values_to_update_in_cache,
+ )
+
+ asyncio.create_task(
+ self.internal_usage_cache.async_batch_set_cache(
+ cache_list=values_to_update_in_cache,
+ ttl=60,
+ litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+ ) # don't block execution for cache updates
+ )
+
+ return
+
+ async def async_log_success_event( # noqa: PLR0915
+ self, kwargs, response_obj, start_time, end_time
+ ):
+ from litellm.proxy.common_utils.callback_utils import (
+ get_model_group_from_litellm_kwargs,
+ )
+
+ litellm_parent_otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs(
+ kwargs=kwargs
+ )
+ try:
+ self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
+
+ global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
+ "global_max_parallel_requests", None
+ )
+ user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
+ user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
+ "user_api_key_user_id", None
+ )
+ user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
+ "user_api_key_team_id", None
+ )
+ user_api_key_model_max_budget = kwargs["litellm_params"]["metadata"].get(
+ "user_api_key_model_max_budget", None
+ )
+ user_api_key_end_user_id = kwargs.get("user")
+
+ user_api_key_metadata = (
+ kwargs["litellm_params"]["metadata"].get("user_api_key_metadata", {})
+ or {}
+ )
+
+ # ------------
+ # Setup values
+ # ------------
+
+ if global_max_parallel_requests is not None:
+ # get value from cache
+ _key = "global_max_parallel_requests"
+ # decrement
+ await self.internal_usage_cache.async_increment_cache(
+ key=_key,
+ value=-1,
+ local_only=True,
+ litellm_parent_otel_span=litellm_parent_otel_span,
+ )
+
+ current_date = datetime.now().strftime("%Y-%m-%d")
+ current_hour = datetime.now().strftime("%H")
+ current_minute = datetime.now().strftime("%M")
+ precise_minute = f"{current_date}-{current_hour}-{current_minute}"
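+            # usage counters are keyed per minute as "<entity>::<YYYY-MM-DD-HH-MM>::request_count", e.g. "<hashed_key>::2024-01-01-13-05::request_count"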
+
+ total_tokens = 0
+
+ if isinstance(response_obj, ModelResponse):
+ total_tokens = response_obj.usage.total_tokens # type: ignore
+
+ # ------------
+ # Update usage - API Key
+ # ------------
+
+ values_to_update_in_cache = []
+
+ if user_api_key is not None:
+ request_count_api_key = (
+ f"{user_api_key}::{precise_minute}::request_count"
+ )
+
+ current = await self.internal_usage_cache.async_get_cache(
+ key=request_count_api_key,
+ litellm_parent_otel_span=litellm_parent_otel_span,
+ ) or {
+ "current_requests": 1,
+ "current_tpm": 0,
+ "current_rpm": 0,
+ }
+
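+                # request finished: release one in-flight request slot and add this call's tokens to the minute's TPM; the RPM counter is left unchanged here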
+ new_val = {
+ "current_requests": max(current["current_requests"] - 1, 0),
+ "current_tpm": current["current_tpm"] + total_tokens,
+ "current_rpm": current["current_rpm"],
+ }
+
+ self.print_verbose(
+ f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+ )
+ values_to_update_in_cache.append((request_count_api_key, new_val))
+
+ # ------------
+ # Update usage - model group + API Key
+ # ------------
+ model_group = get_model_group_from_litellm_kwargs(kwargs)
+ if (
+ user_api_key is not None
+ and model_group is not None
+ and (
+ "model_rpm_limit" in user_api_key_metadata
+ or "model_tpm_limit" in user_api_key_metadata
+ or user_api_key_model_max_budget is not None
+ )
+ ):
+ request_count_api_key = (
+ f"{user_api_key}::{model_group}::{precise_minute}::request_count"
+ )
+
+ current = await self.internal_usage_cache.async_get_cache(
+ key=request_count_api_key,
+ litellm_parent_otel_span=litellm_parent_otel_span,
+ ) or {
+ "current_requests": 1,
+ "current_tpm": 0,
+ "current_rpm": 0,
+ }
+
+ new_val = {
+ "current_requests": max(current["current_requests"] - 1, 0),
+ "current_tpm": current["current_tpm"] + total_tokens,
+ "current_rpm": current["current_rpm"],
+ }
+
+ self.print_verbose(
+ f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+ )
+ values_to_update_in_cache.append((request_count_api_key, new_val))
+
+ # ------------
+ # Update usage - User
+ # ------------
+ if user_api_key_user_id is not None:
+ total_tokens = 0
+
+ if isinstance(response_obj, ModelResponse):
+ total_tokens = response_obj.usage.total_tokens # type: ignore
+
+ request_count_api_key = (
+ f"{user_api_key_user_id}::{precise_minute}::request_count"
+ )
+
+ current = await self.internal_usage_cache.async_get_cache(
+ key=request_count_api_key,
+ litellm_parent_otel_span=litellm_parent_otel_span,
+ ) or {
+ "current_requests": 1,
+ "current_tpm": total_tokens,
+ "current_rpm": 1,
+ }
+
+ new_val = {
+ "current_requests": max(current["current_requests"] - 1, 0),
+ "current_tpm": current["current_tpm"] + total_tokens,
+ "current_rpm": current["current_rpm"],
+ }
+
+ self.print_verbose(
+ f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+ )
+ values_to_update_in_cache.append((request_count_api_key, new_val))
+
+ # ------------
+ # Update usage - Team
+ # ------------
+ if user_api_key_team_id is not None:
+ total_tokens = 0
+
+ if isinstance(response_obj, ModelResponse):
+ total_tokens = response_obj.usage.total_tokens # type: ignore
+
+ request_count_api_key = (
+ f"{user_api_key_team_id}::{precise_minute}::request_count"
+ )
+
+ current = await self.internal_usage_cache.async_get_cache(
+ key=request_count_api_key,
+ litellm_parent_otel_span=litellm_parent_otel_span,
+ ) or {
+ "current_requests": 1,
+ "current_tpm": total_tokens,
+ "current_rpm": 1,
+ }
+
+ new_val = {
+ "current_requests": max(current["current_requests"] - 1, 0),
+ "current_tpm": current["current_tpm"] + total_tokens,
+ "current_rpm": current["current_rpm"],
+ }
+
+ self.print_verbose(
+ f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+ )
+ values_to_update_in_cache.append((request_count_api_key, new_val))
+
+ # ------------
+ # Update usage - End User
+ # ------------
+ if user_api_key_end_user_id is not None:
+ total_tokens = 0
+
+ if isinstance(response_obj, ModelResponse):
+ total_tokens = response_obj.usage.total_tokens # type: ignore
+
+ request_count_api_key = (
+ f"{user_api_key_end_user_id}::{precise_minute}::request_count"
+ )
+
+ current = await self.internal_usage_cache.async_get_cache(
+ key=request_count_api_key,
+ litellm_parent_otel_span=litellm_parent_otel_span,
+ ) or {
+ "current_requests": 1,
+ "current_tpm": total_tokens,
+ "current_rpm": 1,
+ }
+
+ new_val = {
+ "current_requests": max(current["current_requests"] - 1, 0),
+ "current_tpm": current["current_tpm"] + total_tokens,
+ "current_rpm": current["current_rpm"],
+ }
+
+ self.print_verbose(
+ f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+ )
+ values_to_update_in_cache.append((request_count_api_key, new_val))
+
+ await self.internal_usage_cache.async_batch_set_cache(
+ cache_list=values_to_update_in_cache,
+ ttl=60,
+ litellm_parent_otel_span=litellm_parent_otel_span,
+ )
+ except Exception as e:
+ self.print_verbose(e) # noqa
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ self.print_verbose("Inside Max Parallel Request Failure Hook")
+ litellm_parent_otel_span: Union[Span, None] = (
+ _get_parent_otel_span_from_kwargs(kwargs=kwargs)
+ )
+ _metadata = kwargs["litellm_params"].get("metadata", {}) or {}
+ global_max_parallel_requests = _metadata.get(
+ "global_max_parallel_requests", None
+ )
+ user_api_key = _metadata.get("user_api_key", None)
+ self.print_verbose(f"user_api_key: {user_api_key}")
+ if user_api_key is None:
+ return
+
+ ## decrement call count if call failed
+ if CommonProxyErrors.max_parallel_request_limit_reached.value in str(
+ kwargs["exception"]
+ ):
+ pass # ignore failed calls due to max limit being reached
+ else:
+ # ------------
+ # Setup values
+ # ------------
+
+ if global_max_parallel_requests is not None:
+ # get value from cache
+ _key = "global_max_parallel_requests"
+ (
+ await self.internal_usage_cache.async_get_cache(
+ key=_key,
+ local_only=True,
+ litellm_parent_otel_span=litellm_parent_otel_span,
+ )
+ )
+ # decrement
+ await self.internal_usage_cache.async_increment_cache(
+ key=_key,
+ value=-1,
+ local_only=True,
+ litellm_parent_otel_span=litellm_parent_otel_span,
+ )
+
+ current_date = datetime.now().strftime("%Y-%m-%d")
+ current_hour = datetime.now().strftime("%H")
+ current_minute = datetime.now().strftime("%M")
+ precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+ request_count_api_key = (
+ f"{user_api_key}::{precise_minute}::request_count"
+ )
+
+ # ------------
+ # Update usage
+ # ------------
+ current = await self.internal_usage_cache.async_get_cache(
+ key=request_count_api_key,
+ litellm_parent_otel_span=litellm_parent_otel_span,
+ ) or {
+ "current_requests": 1,
+ "current_tpm": 0,
+ "current_rpm": 0,
+ }
+
+ new_val = {
+ "current_requests": max(current["current_requests"] - 1, 0),
+ "current_tpm": current["current_tpm"],
+ "current_rpm": current["current_rpm"],
+ }
+
+ self.print_verbose(f"updated_value in failure call: {new_val}")
+ await self.internal_usage_cache.async_set_cache(
+ request_count_api_key,
+ new_val,
+ ttl=60,
+ litellm_parent_otel_span=litellm_parent_otel_span,
+ ) # save in cache for up to 1 min.
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "Inside Parallel Request Limiter: An exception occurred - {}".format(
+ str(e)
+ )
+ )
+
+ async def get_internal_user_object(
+ self,
+ user_id: str,
+ user_api_key_dict: UserAPIKeyAuth,
+ ) -> Optional[dict]:
+ """
+ Helper to get the 'Internal User Object'
+
+ It uses the `get_user_object` function from `litellm.proxy.auth.auth_checks`
+
+        We need this because the UserAPIKeyAuth object does not contain the rpm/tpm limits for a user,
+        and reading the UserTable on every request could have a performance impact.
+ """
+ from litellm._logging import verbose_proxy_logger
+ from litellm.proxy.auth.auth_checks import get_user_object
+ from litellm.proxy.proxy_server import prisma_client
+
+ try:
+ _user_id_rate_limits = await get_user_object(
+ user_id=user_id,
+ prisma_client=prisma_client,
+ user_api_key_cache=self.internal_usage_cache.dual_cache,
+ user_id_upsert=False,
+ parent_otel_span=user_api_key_dict.parent_otel_span,
+ proxy_logging_obj=None,
+ )
+
+ if _user_id_rate_limits is None:
+ return None
+
+ return _user_id_rate_limits.model_dump()
+ except Exception as e:
+            verbose_proxy_logger.debug(
+                "Parallel Request Limiter: Error getting user object - %s", str(e)
+            )
+ return None
+
+ async def async_post_call_success_hook(
+ self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
+ ):
+ """
+ Retrieve the key's remaining rate limits.
+ """
+ api_key = user_api_key_dict.api_key
+ current_date = datetime.now().strftime("%Y-%m-%d")
+ current_hour = datetime.now().strftime("%H")
+ current_minute = datetime.now().strftime("%M")
+ precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+ request_count_api_key = f"{api_key}::{precise_minute}::request_count"
+ current: Optional[CurrentItemRateLimit] = (
+ await self.internal_usage_cache.async_get_cache(
+ key=request_count_api_key,
+ litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+ )
+ )
+
+ key_remaining_rpm_limit: Optional[int] = None
+ key_rpm_limit: Optional[int] = None
+ key_remaining_tpm_limit: Optional[int] = None
+ key_tpm_limit: Optional[int] = None
+ if current is not None:
+ if user_api_key_dict.rpm_limit is not None:
+ key_remaining_rpm_limit = (
+ user_api_key_dict.rpm_limit - current["current_rpm"]
+ )
+ key_rpm_limit = user_api_key_dict.rpm_limit
+ if user_api_key_dict.tpm_limit is not None:
+ key_remaining_tpm_limit = (
+ user_api_key_dict.tpm_limit - current["current_tpm"]
+ )
+ key_tpm_limit = user_api_key_dict.tpm_limit
+
+ if hasattr(response, "_hidden_params"):
+ _hidden_params = getattr(response, "_hidden_params")
+ else:
+ _hidden_params = None
+ if _hidden_params is not None and (
+ isinstance(_hidden_params, BaseModel) or isinstance(_hidden_params, dict)
+ ):
+ if isinstance(_hidden_params, BaseModel):
+ _hidden_params = _hidden_params.model_dump()
+
+ _additional_headers = _hidden_params.get("additional_headers", {}) or {}
+
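+            # surface the key's remaining per-minute limits as x-ratelimit-* response headers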
+ if key_remaining_rpm_limit is not None:
+ _additional_headers["x-ratelimit-remaining-requests"] = (
+ key_remaining_rpm_limit
+ )
+ if key_rpm_limit is not None:
+ _additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit
+ if key_remaining_tpm_limit is not None:
+ _additional_headers["x-ratelimit-remaining-tokens"] = (
+ key_remaining_tpm_limit
+ )
+ if key_tpm_limit is not None:
+ _additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit
+
+ setattr(
+ response,
+ "_hidden_params",
+ {**_hidden_params, "additional_headers": _additional_headers},
+ )
+
+ return await super().async_post_call_success_hook(
+ data, user_api_key_dict, response
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py
new file mode 100644
index 00000000..b1b2bbee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py
@@ -0,0 +1,280 @@
+# +------------------------------------+
+#
+# Prompt Injection Detection
+#
+# +------------------------------------+
+# Thank you users! We ❤️ you! - Krrish & Ishaan
+## Reject a call if it contains a prompt injection attack.
+
+
+from difflib import SequenceMatcher
+from typing import List, Literal, Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.prompt_templates.factory import (
+ prompt_injection_detection_default_pt,
+)
+from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth
+from litellm.router import Router
+from litellm.utils import get_formatted_prompt
+
+
+class _OPTIONAL_PromptInjectionDetection(CustomLogger):
+ # Class variables or attributes
+ def __init__(
+ self,
+ prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
+ ):
+ self.prompt_injection_params = prompt_injection_params
+ self.llm_router: Optional[Router] = None
+
+ self.verbs = [
+ "Ignore",
+ "Disregard",
+ "Skip",
+ "Forget",
+ "Neglect",
+ "Overlook",
+ "Omit",
+ "Bypass",
+ "Pay no attention to",
+ "Do not follow",
+ "Do not obey",
+ ]
+ self.adjectives = [
+ "",
+ "prior",
+ "previous",
+ "preceding",
+ "above",
+ "foregoing",
+ "earlier",
+ "initial",
+ ]
+ self.prepositions = [
+ "",
+ "and start over",
+ "and start anew",
+ "and begin afresh",
+ "and start from scratch",
+ ]
+
+ def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
+ if level == "INFO":
+ verbose_proxy_logger.info(print_statement)
+ elif level == "DEBUG":
+ verbose_proxy_logger.debug(print_statement)
+
+ if litellm.set_verbose is True:
+ print(print_statement) # noqa
+
+ def update_environment(self, router: Optional[Router] = None):
+ self.llm_router = router
+
+ if (
+ self.prompt_injection_params is not None
+ and self.prompt_injection_params.llm_api_check is True
+ ):
+ if self.llm_router is None:
+ raise Exception(
+ "PromptInjectionDetection: Model List not set. Required for Prompt Injection detection."
+ )
+
+ self.print_verbose(
+ f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}"
+ )
+ if (
+ self.prompt_injection_params.llm_api_name is None
+ or self.prompt_injection_params.llm_api_name
+ not in self.llm_router.model_names
+ ):
+ raise Exception(
+ "PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'."
+ )
+
+ def generate_injection_keywords(self) -> List[str]:
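+        # build candidate injection phrases from the cartesian product of verbs x adjectives x prepositions, keeping only phrases longer than two words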
+ combinations = []
+ for verb in self.verbs:
+ for adj in self.adjectives:
+ for prep in self.prepositions:
+ phrase = " ".join(filter(None, [verb, adj, prep])).strip()
+ if (
+ len(phrase.split()) > 2
+ ): # additional check to ensure more than 2 words
+ combinations.append(phrase.lower())
+ return combinations
+
+ def check_user_input_similarity(
+ self, user_input: str, similarity_threshold: float = 0.7
+ ) -> bool:
+ user_input_lower = user_input.lower()
+ keywords = self.generate_injection_keywords()
+
+ for keyword in keywords:
+ # Calculate the length of the keyword to extract substrings of the same length from user input
+ keyword_length = len(keyword)
+
+ for i in range(len(user_input_lower) - keyword_length + 1):
+ # Extract a substring of the same length as the keyword
+ substring = user_input_lower[i : i + keyword_length]
+
+ # Calculate similarity
+ match_ratio = SequenceMatcher(None, substring, keyword).ratio()
+ if match_ratio > similarity_threshold:
+ self.print_verbose(
+ print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}",
+ level="INFO",
+ )
+ return True # Found a highly similar substring
+ return False # No substring crossed the threshold
+
+ async def async_pre_call_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: str, # "completion", "embeddings", "image_generation", "moderation"
+ ):
+ try:
+ """
+            - extract the formatted prompt from the request
+            - run the configured checks and reject the request if it looks like a prompt injection attack
+ """
+ self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook")
+ try:
+ assert call_type in [
+ "completion",
+ "text_completion",
+ "embeddings",
+ "image_generation",
+ "moderation",
+ "audio_transcription",
+ ]
+ except Exception:
+ self.print_verbose(
+ f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
+ )
+ return data
+ formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
+
+ is_prompt_attack = False
+
+ if self.prompt_injection_params is not None:
+ # 1. check if heuristics check turned on
+ if self.prompt_injection_params.heuristics_check is True:
+ is_prompt_attack = self.check_user_input_similarity(
+ user_input=formatted_prompt
+ )
+ if is_prompt_attack is True:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Rejected message. This is a prompt injection attack."
+ },
+ )
+ # 2. check if vector db similarity check turned on [TODO] Not Implemented yet
+ if self.prompt_injection_params.vector_db_check is True:
+ pass
+ else:
+ is_prompt_attack = self.check_user_input_similarity(
+ user_input=formatted_prompt
+ )
+
+ if is_prompt_attack is True:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Rejected message. This is a prompt injection attack."
+ },
+ )
+
+ return data
+
+ except HTTPException as e:
+
+ if (
+ e.status_code == 400
+ and isinstance(e.detail, dict)
+ and "error" in e.detail # type: ignore
+ and self.prompt_injection_params is not None
+ and self.prompt_injection_params.reject_as_response
+ ):
+ return e.detail.get("error")
+ raise e
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+
+ async def async_moderation_hook( # type: ignore
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ call_type: Literal[
+ "completion",
+ "embeddings",
+ "image_generation",
+ "moderation",
+ "audio_transcription",
+ ],
+ ) -> Optional[bool]:
+ self.print_verbose(
+ f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}"
+ )
+
+ if self.prompt_injection_params is None:
+ return None
+
+ formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
+ is_prompt_attack = False
+
+ prompt_injection_system_prompt = getattr(
+ self.prompt_injection_params,
+ "llm_api_system_prompt",
+ prompt_injection_detection_default_pt(),
+ )
+
+ # 3. check if llm api check turned on
+ if (
+ self.prompt_injection_params.llm_api_check is True
+ and self.prompt_injection_params.llm_api_name is not None
+ and self.llm_router is not None
+ ):
+ # make a call to the llm api
+ response = await self.llm_router.acompletion(
+ model=self.prompt_injection_params.llm_api_name,
+ messages=[
+ {
+ "role": "system",
+ "content": prompt_injection_system_prompt,
+ },
+ {"role": "user", "content": formatted_prompt},
+ ],
+ )
+
+ self.print_verbose(f"Received LLM Moderation response: {response}")
+ self.print_verbose(
+ f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}"
+ )
+ if isinstance(response, litellm.ModelResponse) and isinstance(
+ response.choices[0], litellm.Choices
+ ):
+ if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content: # type: ignore
+ is_prompt_attack = True
+
+ if is_prompt_attack is True:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Rejected message. This is a prompt injection attack."
+ },
+ )
+
+ return is_prompt_attack
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py
new file mode 100644
index 00000000..e8a94732
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py
@@ -0,0 +1,246 @@
+import asyncio
+import traceback
+from datetime import datetime
+from typing import Any, Optional, Union, cast
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.core_helpers import (
+ _get_parent_otel_span_from_kwargs,
+ get_litellm_metadata_from_kwargs,
+)
+from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.auth.auth_checks import log_db_metrics
+from litellm.proxy.utils import ProxyUpdateSpend
+from litellm.types.utils import (
+ StandardLoggingPayload,
+ StandardLoggingUserAPIKeyMetadata,
+)
+from litellm.utils import get_end_user_id_for_cost_tracking
+
+
+class _ProxyDBLogger(CustomLogger):
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ await self._PROXY_track_cost_callback(
+ kwargs, response_obj, start_time, end_time
+ )
+
+ async def async_post_call_failure_hook(
+ self,
+ request_data: dict,
+ original_exception: Exception,
+ user_api_key_dict: UserAPIKeyAuth,
+ ):
+ from litellm.proxy.proxy_server import update_database
+
+ if _ProxyDBLogger._should_track_errors_in_db() is False:
+ return
+
+ _metadata = dict(
+ StandardLoggingUserAPIKeyMetadata(
+ user_api_key_hash=user_api_key_dict.api_key,
+ user_api_key_alias=user_api_key_dict.key_alias,
+ user_api_key_user_email=user_api_key_dict.user_email,
+ user_api_key_user_id=user_api_key_dict.user_id,
+ user_api_key_team_id=user_api_key_dict.team_id,
+ user_api_key_org_id=user_api_key_dict.org_id,
+ user_api_key_team_alias=user_api_key_dict.team_alias,
+ user_api_key_end_user_id=user_api_key_dict.end_user_id,
+ )
+ )
+ _metadata["user_api_key"] = user_api_key_dict.api_key
+ _metadata["status"] = "failure"
+ _metadata["error_information"] = (
+ StandardLoggingPayloadSetup.get_error_information(
+ original_exception=original_exception,
+ )
+ )
+
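+        # log the failed request to the DB as a zero-cost entry, tagged with the caller's key/user/team identity and the error details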
+ existing_metadata: dict = request_data.get("metadata", None) or {}
+ existing_metadata.update(_metadata)
+
+ if "litellm_params" not in request_data:
+ request_data["litellm_params"] = {}
+ request_data["litellm_params"]["proxy_server_request"] = (
+ request_data.get("proxy_server_request") or {}
+ )
+ request_data["litellm_params"]["metadata"] = existing_metadata
+ await update_database(
+ token=user_api_key_dict.api_key,
+ response_cost=0.0,
+ user_id=user_api_key_dict.user_id,
+ end_user_id=user_api_key_dict.end_user_id,
+ team_id=user_api_key_dict.team_id,
+ kwargs=request_data,
+ completion_response=original_exception,
+ start_time=datetime.now(),
+ end_time=datetime.now(),
+ org_id=user_api_key_dict.org_id,
+ )
+
+ @log_db_metrics
+ async def _PROXY_track_cost_callback(
+ self,
+ kwargs, # kwargs to completion
+ completion_response: Optional[
+ Union[litellm.ModelResponse, Any]
+ ], # response from completion
+ start_time=None,
+ end_time=None, # start/end time for completion
+ ):
+ from litellm.proxy.proxy_server import (
+ prisma_client,
+ proxy_logging_obj,
+ update_cache,
+ update_database,
+ )
+
+ verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
+ try:
+ verbose_proxy_logger.debug(
+ f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
+ )
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
+ litellm_params = kwargs.get("litellm_params", {}) or {}
+ end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
+ metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
+ user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
+ team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
+ org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
+ key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
+ end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
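+            # prefer the response cost computed in the standard logging payload; fall back to the value set in kwargs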
+ sl_object: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+ response_cost = (
+ sl_object.get("response_cost", None)
+ if sl_object is not None
+ else kwargs.get("response_cost", None)
+ )
+
+ if response_cost is not None:
+ user_api_key = metadata.get("user_api_key", None)
+ if kwargs.get("cache_hit", False) is True:
+ response_cost = 0.0
+ verbose_proxy_logger.info(
+ f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
+ )
+
+ verbose_proxy_logger.debug(
+ f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
+ )
+ if _should_track_cost_callback(
+ user_api_key=user_api_key,
+ user_id=user_id,
+ team_id=team_id,
+ end_user_id=end_user_id,
+ ):
+ ## UPDATE DATABASE
+ await update_database(
+ token=user_api_key,
+ response_cost=response_cost,
+ user_id=user_id,
+ end_user_id=end_user_id,
+ team_id=team_id,
+ kwargs=kwargs,
+ completion_response=completion_response,
+ start_time=start_time,
+ end_time=end_time,
+ org_id=org_id,
+ )
+
+ # update cache
+ asyncio.create_task(
+ update_cache(
+ token=user_api_key,
+ user_id=user_id,
+ end_user_id=end_user_id,
+ response_cost=response_cost,
+ team_id=team_id,
+ parent_otel_span=parent_otel_span,
+ )
+ )
+
+ await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
+ token=user_api_key,
+ key_alias=key_alias,
+ end_user_id=end_user_id,
+ response_cost=response_cost,
+ max_budget=end_user_max_budget,
+ )
+ else:
+ raise Exception(
+ "User API key and team id and user id missing from custom callback."
+ )
+ else:
+ if kwargs["stream"] is not True or (
+ kwargs["stream"] is True and "complete_streaming_response" in kwargs
+ ):
+ if sl_object is not None:
+ cost_tracking_failure_debug_info: Union[dict, str] = (
+ sl_object["response_cost_failure_debug_info"] # type: ignore
+ or "response_cost_failure_debug_info is None in standard_logging_object"
+ )
+ else:
+ cost_tracking_failure_debug_info = (
+ "standard_logging_object not found"
+ )
+ model = kwargs.get("model")
+ raise Exception(
+ f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+ )
+ except Exception as e:
+ error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
+ model = kwargs.get("model", "")
+ metadata = kwargs.get("litellm_params", {}).get("metadata", {})
+ error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
+ asyncio.create_task(
+ proxy_logging_obj.failed_tracking_alert(
+ error_message=error_msg,
+ failing_model=model,
+ )
+ )
+ verbose_proxy_logger.exception(
+ "Error in tracking cost callback - %s", str(e)
+ )
+
+ @staticmethod
+ def _should_track_errors_in_db():
+ """
+ Returns True if errors should be tracked in the database
+
+ By default, errors are tracked in the database
+
+ If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
+ """
+ from litellm.proxy.proxy_server import general_settings
+
+ if general_settings.get("disable_error_logs") is True:
+ return False
+        return True
+
+
+def _should_track_cost_callback(
+ user_api_key: Optional[str],
+ user_id: Optional[str],
+ team_id: Optional[str],
+ end_user_id: Optional[str],
+) -> bool:
+ """
+ Determine if the cost callback should be tracked based on the kwargs
+ """
+
+ # don't run track cost callback if user opted into disabling spend
+ if ProxyUpdateSpend.disable_spend_updates() is True:
+ return False
+
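+    # only track spend when the request can be attributed to a key, user, team, or end-user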
+ if (
+ user_api_key is not None
+ or user_id is not None
+ or team_id is not None
+ or end_user_id is not None
+ ):
+ return True
+ return False