author     S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/hooks
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/__init__.py                                1
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py                  156
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/batch_redis_get.py                       149
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/cache_control_check.py                    58
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/dynamic_rate_limiter.py                  298
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/example_presidio_ad_hoc_recognizer.json   28
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py            324
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/max_budget_limiter.py                     49
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/model_max_budget_limiter.py              192
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/parallel_request_limiter.py              866
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py            280
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py             246
12 files changed, 2647 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/__init__.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/__init__.py
new file mode 100644
index 00000000..b6e690fd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/__init__.py
@@ -0,0 +1 @@
+from . import *
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py
new file mode 100644
index 00000000..b35d6711
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py
@@ -0,0 +1,156 @@
+import traceback
+from typing import Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+class _PROXY_AzureContentSafety(
+    CustomLogger
+):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
+    # Class variables or attributes
+
+    def __init__(self, endpoint, api_key, thresholds=None):
+        try:
+            from azure.ai.contentsafety.aio import ContentSafetyClient
+            from azure.ai.contentsafety.models import (
+                AnalyzeTextOptions,
+                AnalyzeTextOutputType,
+                TextCategory,
+            )
+            from azure.core.credentials import AzureKeyCredential
+            from azure.core.exceptions import HttpResponseError
+        except Exception as e:
+            raise Exception(
+                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
+            )
+        self.endpoint = endpoint
+        self.api_key = api_key
+        self.text_category = TextCategory
+        self.analyze_text_options = AnalyzeTextOptions
+        self.analyze_text_output_type = AnalyzeTextOutputType
+        self.azure_http_error = HttpResponseError
+
+        self.thresholds = self._configure_thresholds(thresholds)
+
+        self.client = ContentSafetyClient(
+            self.endpoint, AzureKeyCredential(self.api_key)
+        )
+
+    def _configure_thresholds(self, thresholds=None):
+        default_thresholds = {
+            self.text_category.HATE: 4,
+            self.text_category.SELF_HARM: 4,
+            self.text_category.SEXUAL: 4,
+            self.text_category.VIOLENCE: 4,
+        }
+
+        if thresholds is None:
+            return default_thresholds
+
+        for key, default in default_thresholds.items():
+            if key not in thresholds:
+                thresholds[key] = default
+
+        return thresholds
+
+    def _compute_result(self, response):
+        result = {}
+
+        category_severity = {
+            item.category: item.severity for item in response.categories_analysis
+        }
+        for category in self.text_category:
+            severity = category_severity.get(category)
+            if severity is not None:
+                result[category] = {
+                    "filtered": severity >= self.thresholds[category],
+                    "severity": severity,
+                }
+
+        return result
+
+    async def test_violation(self, content: str, source: Optional[str] = None):
+        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
+
+        # Construct a request
+        request = self.analyze_text_options(
+            text=content,
+            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
+        )
+
+        # Analyze text
+        try:
+            response = await self.client.analyze_text(request)
+        except self.azure_http_error:
+            verbose_proxy_logger.debug(
+                "Error in Azure Content-Safety: %s", traceback.format_exc()
+            )
+            verbose_proxy_logger.debug(traceback.format_exc())
+            raise
+
+        result = self._compute_result(response)
+        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)
+
+        for key, value in result.items():
+            if value["filtered"]:
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": "Violated content safety policy",
+                        "source": source,
+                        "category": key,
+                        "severity": value["severity"],
+                    },
+                )
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
+    ):
+        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
+        try:
+            if call_type == "completion" and "messages" in data:
+                for m in data["messages"]:
+                    if "content" in m and isinstance(m["content"], str):
+                        await self.test_violation(content=m["content"], source="input")
+
+        except HTTPException as e:
+            raise e
+        except Exception as e:
+            verbose_proxy_logger.error(
+                "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occurred - {}".format(
+                    str(e)
+                )
+            )
+            verbose_proxy_logger.debug(traceback.format_exc())
+
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
+        if isinstance(response, litellm.ModelResponse) and isinstance(
+            response.choices[0], litellm.utils.Choices
+        ):
+            await self.test_violation(
+                content=response.choices[0].message.content or "", source="output"
+            )
+
+    # async def async_post_call_streaming_hook(
+    #    self,
+    #    user_api_key_dict: UserAPIKeyAuth,
+    #    response: str,
+    # ):
+    #    verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
+    #    await self.test_violation(content=response, source="output")
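A minimal usage sketch for `_PROXY_AzureContentSafety`, assuming `azure-ai-contentsafety` is installed and that the endpoint and key come from the two environment variables named below (the variable names and the standalone driver are hypothetical, for illustration only):

import asyncio
import os

from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety

# thresholds=None keeps the defaults: severity >= 4 is filtered for every category.
content_safety = _PROXY_AzureContentSafety(
    endpoint=os.environ["AZURE_CONTENT_SAFETY_ENDPOINT"],  # hypothetical env var name
    api_key=os.environ["AZURE_CONTENT_SAFETY_KEY"],        # hypothetical env var name
)

async def main() -> None:
    # Raises fastapi.HTTPException(status_code=400) if any category crosses its threshold.
    await content_safety.test_violation(content="hello world", source="input")

asyncio.run(main())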
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/batch_redis_get.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/batch_redis_get.py
new file mode 100644
index 00000000..c608317f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/batch_redis_get.py
@@ -0,0 +1,149 @@
+# What does this do?
+## Gets a key's Redis cache and stores it in memory for 1 minute.
+## This reduces the number of Redis GET requests the proxy makes during high traffic.
+### [BETA] This is in beta and might change.
+
+import traceback
+from typing import Literal, Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+class _PROXY_BatchRedisRequests(CustomLogger):
+    # Class variables or attributes
+    in_memory_cache: Optional[InMemoryCache] = None
+
+    def __init__(self):
+        if litellm.cache is not None:
+            litellm.cache.async_get_cache = (
+                self.async_get_cache
+            )  # map the litellm 'get_cache' function to our custom function
+
+    def print_verbose(
+        self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
+    ):
+        if debug_level == "DEBUG":
+            verbose_proxy_logger.debug(print_statement)
+        elif debug_level == "INFO":
+            verbose_proxy_logger.info(print_statement)
+        if litellm.set_verbose is True:
+            print(print_statement)  # noqa
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+    ):
+        try:
+            """
+            Get the user key
+
+            Check if a key starting with `litellm:<api_key>:<call_type>:` exists in-memory
+
+            If not, get the relevant cache from Redis
+            """
+            api_key = user_api_key_dict.api_key
+
+            cache_key_name = f"litellm:{api_key}:{call_type}"
+            self.in_memory_cache = cache.in_memory_cache
+
+            key_value_dict = {}
+            in_memory_cache_exists = False
+            for key in cache.in_memory_cache.cache_dict.keys():
+                if isinstance(key, str) and key.startswith(cache_key_name):
+                    in_memory_cache_exists = True
+
+            if in_memory_cache_exists is False and litellm.cache is not None:
+                """
+                - Check if `litellm.Cache` is redis
+                - Get the relevant values
+                """
+                if litellm.cache.type is not None and isinstance(
+                    litellm.cache.cache, RedisCache
+                ):
+                    # Initialize an empty list to store the keys
+                    keys = []
+                    self.print_verbose(f"cache_key_name: {cache_key_name}")
+                    # Use the SCAN iterator to fetch keys matching the pattern
+                    keys = await litellm.cache.cache.async_scan_iter(
+                        pattern=cache_key_name, count=100
+                    )
+                    # If you need the truly "last" entry based on time or another criterion,
+                    # ensure your key naming or storage strategy allows this determination
+                    # Here you would sort or filter the keys as needed based on your strategy
+                    self.print_verbose(f"redis keys: {keys}")
+                    if len(keys) > 0:
+                        key_value_dict = (
+                            await litellm.cache.cache.async_batch_get_cache(
+                                key_list=keys
+                            )
+                        )
+
+            ## Add to cache
+            if len(key_value_dict.items()) > 0:
+                await cache.in_memory_cache.async_set_cache_pipeline(
+                    cache_list=list(key_value_dict.items()), ttl=60
+                )
+            ## Set cache namespace if it's a miss
+            data["metadata"]["redis_namespace"] = cache_key_name
+        except HTTPException as e:
+            raise e
+        except Exception as e:
+            verbose_proxy_logger.error(
+                "litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occurred - {}".format(
+                    str(e)
+                )
+            )
+            verbose_proxy_logger.debug(traceback.format_exc())
+
+    async def async_get_cache(self, *args, **kwargs):
+        """
+        - Check if the cache key is in-memory
+
+        - Else:
+            - add missing cache key from REDIS
+            - update in-memory cache
+            - return redis cache request
+        """
+        try:  # never block execution
+            cache_key: Optional[str] = None
+            if "cache_key" in kwargs:
+                cache_key = kwargs["cache_key"]
+            elif litellm.cache is not None:
+                cache_key = litellm.cache.get_cache_key(
+                    *args, **kwargs
+                )  # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic
+
+            if (
+                cache_key is not None
+                and self.in_memory_cache is not None
+                and litellm.cache is not None
+            ):
+                cache_control_args = kwargs.get("cache", {})
+                max_age = cache_control_args.get(
+                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
+                )
+                cached_result = self.in_memory_cache.get_cache(
+                    cache_key, *args, **kwargs
+                )
+                if cached_result is None:
+                    cached_result = await litellm.cache.cache.async_get_cache(
+                        cache_key, *args, **kwargs
+                    )
+                    if cached_result is not None:
+                        await self.in_memory_cache.async_set_cache(
+                            cache_key, cached_result, ttl=60
+                        )
+                return litellm.cache._get_cache_logic(
+                    cached_result=cached_result, max_age=max_age
+                )
+        except Exception:
+            return None
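In short, the hook above namespaces each request's cache key as `litellm:<api_key>:<call_type>`, batch-fetches any matching Redis keys once, and pins the results in the in-memory cache for 60 seconds so the patched `async_get_cache` can serve repeat lookups locally. A simplified, self-contained sketch of that TTL-pinning idea (names and structure are illustrative, not the litellm API):

import time
from typing import Any, Dict, Optional, Tuple

class TTLMemoryCache:
    """Tiny in-memory cache with per-entry expiry, mirroring the 60-second pin used above."""

    def __init__(self) -> None:
        self._store: Dict[str, Tuple[float, Any]] = {}

    def set(self, key: str, value: Any, ttl: float = 60.0) -> None:
        self._store[key] = (time.monotonic() + ttl, value)

    def get(self, key: str) -> Optional[Any]:
        entry = self._store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if time.monotonic() > expires_at:
            del self._store[key]  # expired: drop and treat as a miss
            return None
        return value

cache = TTLMemoryCache()
cache.set("litellm:sk-hashed:completion:abc123", {"response": "..."})  # hypothetical entry
assert cache.get("litellm:sk-hashed:completion:abc123") is not None

# Effective lookup order in the hook for keys under "litellm:<api_key>:<call_type>":
#   1. in-memory TTL cache (fast path, refreshed every 60 seconds)
#   2. Redis batch get (slow path; results are written back with ttl=60)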
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/cache_control_check.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/cache_control_check.py
new file mode 100644
index 00000000..6e3fbf84
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/cache_control_check.py
@@ -0,0 +1,58 @@
+# What does this do?
+## Checks if the key is allowed to use the cache controls passed in to the completion() call
+
+
+from fastapi import HTTPException
+
+from litellm import verbose_logger
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+class _PROXY_CacheControlCheck(CustomLogger):
+    # Class variables or attributes
+    def __init__(self):
+        pass
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+    ):
+        try:
+            verbose_proxy_logger.debug("Inside Cache Control Check Pre-Call Hook")
+            allowed_cache_controls = user_api_key_dict.allowed_cache_controls
+
+            if data.get("cache", None) is None:
+                return
+
+            cache_args = data.get("cache", None)
+            if isinstance(cache_args, dict):
+                for k, v in cache_args.items():
+                    if (
+                        (allowed_cache_controls is not None)
+                        and (isinstance(allowed_cache_controls, list))
+                        and (
+                            len(allowed_cache_controls) > 0
+                        )  # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
+                        and k not in allowed_cache_controls
+                    ):
+                        raise HTTPException(
+                            status_code=403,
+                            detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.",
+                        )
+            else:  # invalid cache
+                return
+
+        except HTTPException as e:
+            raise e
+        except Exception as e:
+            verbose_logger.exception(
+                "litellm.proxy.hooks.cache_control_check.py::async_pre_call_hook(): Exception occurred - {}".format(
+                    str(e)
+                )
+            )
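To make the check above concrete: if a virtual key's `allowed_cache_controls` is `["no-cache"]`, a request passing `cache={"no-cache": True}` is allowed, while any other cache-control key is rejected with a 403. A hedged, standalone restatement of that gate (values are illustrative):

allowed_cache_controls = ["no-cache"]   # configured on the virtual key
request_cache_args = {"ttl": 300}       # `cache` argument sent with the completion() call

for k in request_cache_args:
    if len(allowed_cache_controls) > 0 and k not in allowed_cache_controls:
        # The real hook raises fastapi.HTTPException(status_code=403, detail=...)
        print(f"403: Not allowed to set {k} as a cache control. Contact admin to change permissions.")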
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/dynamic_rate_limiter.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/dynamic_rate_limiter.py
new file mode 100644
index 00000000..15a9bc1b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/dynamic_rate_limiter.py
@@ -0,0 +1,298 @@
+# What is this?
+## Allocates dynamic tpm/rpm quota for a project based on current traffic
+## Tracks the number of active projects per minute
+
+import asyncio
+import os
+from typing import List, Literal, Optional, Tuple, Union
+
+from fastapi import HTTPException
+
+import litellm
+from litellm import ModelResponse, Router
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.types.router import ModelGroupInfo
+from litellm.utils import get_utc_datetime
+
+
+class DynamicRateLimiterCache:
+    """
+    Thin wrapper on DualCache for this file.
+
+    Track number of active projects calling a model.
+    """
+
+    def __init__(self, cache: DualCache) -> None:
+        self.cache = cache
+        self.ttl = 60  # 1 min ttl
+
+    async def async_get_cache(self, model: str) -> Optional[int]:
+        dt = get_utc_datetime()
+        current_minute = dt.strftime("%H-%M")
+        key_name = "{}:{}".format(current_minute, model)
+        _response = await self.cache.async_get_cache(key=key_name)
+        response: Optional[int] = None
+        if _response is not None:
+            response = len(_response)
+        return response
+
+    async def async_set_cache_sadd(self, model: str, value: List):
+        """
+        Add value to set.
+
+        Parameters:
+        - model: str, the name of the model group
+        - value: List, the project ids (e.g. key tokens) to add to the set
+
+        Returns:
+        - None
+
+        Raises:
+        - Exception, if unable to connect to cache client (if redis caching enabled)
+        """
+        try:
+            dt = get_utc_datetime()
+            current_minute = dt.strftime("%H-%M")
+
+            key_name = "{}:{}".format(current_minute, model)
+            await self.cache.async_set_cache_sadd(
+                key=key_name, value=value, ttl=self.ttl
+            )
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occurred - {}".format(
+                    str(e)
+                )
+            )
+            raise e
+
+
+class _PROXY_DynamicRateLimitHandler(CustomLogger):
+
+    # Class variables or attributes
+    def __init__(self, internal_usage_cache: DualCache):
+        self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache)
+
+    def update_variables(self, llm_router: Router):
+        self.llm_router = llm_router
+
+    async def check_available_usage(
+        self, model: str, priority: Optional[str] = None
+    ) -> Tuple[
+        Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]
+    ]:
+        """
+        For a given model, get its available tpm
+
+        Params:
+        - model: str, the name of the model in the router model_list
+        - priority: Optional[str], the priority for the request.
+
+        Returns
+        - Tuple[available_tpm, available_rpm, remaining_model_tpm, remaining_model_rpm, active_projects]
+            - available_tpm: int or null - always 0 or positive.
+            - available_rpm: int or null - always 0 or positive.
+            - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
+            - remaining_model_rpm: int or null. If available rpm is int, then this will be too.
+            - active_projects: int or null
+        """
+        try:
+            weight: float = 1
+            if (
+                litellm.priority_reservation is None
+                or priority not in litellm.priority_reservation
+            ):
+                verbose_proxy_logger.error(
+                    "Priority Reservation not set. priority={}, but litellm.priority_reservation is {}.".format(
+                        priority, litellm.priority_reservation
+                    )
+                )
+            elif priority is not None and litellm.priority_reservation is not None:
+                if os.getenv("LITELLM_LICENSE", None) is None:
+                    verbose_proxy_logger.error(
+                        "PREMIUM FEATURE: Reserving tpm/rpm by priority is a premium feature. Please add a 'LITELLM_LICENSE' to your .env to enable this.\nGet a license: https://docs.litellm.ai/docs/proxy/enterprise."
+                    )
+                else:
+                    weight = litellm.priority_reservation[priority]
+
+            active_projects = await self.internal_usage_cache.async_get_cache(
+                model=model
+            )
+            current_model_tpm, current_model_rpm = (
+                await self.llm_router.get_model_group_usage(model_group=model)
+            )
+            model_group_info: Optional[ModelGroupInfo] = (
+                self.llm_router.get_model_group_info(model_group=model)
+            )
+            total_model_tpm: Optional[int] = None
+            total_model_rpm: Optional[int] = None
+            if model_group_info is not None:
+                if model_group_info.tpm is not None:
+                    total_model_tpm = model_group_info.tpm
+                if model_group_info.rpm is not None:
+                    total_model_rpm = model_group_info.rpm
+
+            remaining_model_tpm: Optional[int] = None
+            if total_model_tpm is not None and current_model_tpm is not None:
+                remaining_model_tpm = total_model_tpm - current_model_tpm
+            elif total_model_tpm is not None:
+                remaining_model_tpm = total_model_tpm
+
+            remaining_model_rpm: Optional[int] = None
+            if total_model_rpm is not None and current_model_rpm is not None:
+                remaining_model_rpm = total_model_rpm - current_model_rpm
+            elif total_model_rpm is not None:
+                remaining_model_rpm = total_model_rpm
+
+            available_tpm: Optional[int] = None
+
+            if remaining_model_tpm is not None:
+                if active_projects is not None:
+                    available_tpm = int(remaining_model_tpm * weight / active_projects)
+                else:
+                    available_tpm = int(remaining_model_tpm * weight)
+
+            if available_tpm is not None and available_tpm < 0:
+                available_tpm = 0
+
+            available_rpm: Optional[int] = None
+
+            if remaining_model_rpm is not None:
+                if active_projects is not None:
+                    available_rpm = int(remaining_model_rpm * weight / active_projects)
+                else:
+                    available_rpm = int(remaining_model_rpm * weight)
+
+            if available_rpm is not None and available_rpm < 0:
+                available_rpm = 0
+            return (
+                available_tpm,
+                available_rpm,
+                remaining_model_tpm,
+                remaining_model_rpm,
+                active_projects,
+            )
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                "litellm.proxy.hooks.dynamic_rate_limiter.py::check_available_usage: Exception occurred - {}".format(
+                    str(e)
+                )
+            )
+            return None, None, None, None, None
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+            "rerank",
+        ],
+    ) -> Optional[
+        Union[Exception, str, dict]
+    ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
+        """
+        - For a model group
+        - Check if tpm/rpm available
+        - Raise RateLimitError if no tpm/rpm available
+        """
+        if "model" in data:
+            key_priority: Optional[str] = user_api_key_dict.metadata.get(
+                "priority", None
+            )
+            available_tpm, available_rpm, model_tpm, model_rpm, active_projects = (
+                await self.check_available_usage(
+                    model=data["model"], priority=key_priority
+                )
+            )
+            ### CHECK TPM ###
+            if available_tpm is not None and available_tpm == 0:
+                raise HTTPException(
+                    status_code=429,
+                    detail={
+                        "error": "Key={} over available TPM={}. Model TPM={}, Active keys={}".format(
+                            user_api_key_dict.api_key,
+                            available_tpm,
+                            model_tpm,
+                            active_projects,
+                        )
+                    },
+                )
+            ### CHECK RPM ###
+            elif available_rpm is not None and available_rpm == 0:
+                raise HTTPException(
+                    status_code=429,
+                    detail={
+                        "error": "Key={} over available RPM={}. Model RPM={}, Active keys={}".format(
+                            user_api_key_dict.api_key,
+                            available_rpm,
+                            model_rpm,
+                            active_projects,
+                        )
+                    },
+                )
+            elif available_rpm is not None or available_tpm is not None:
+                ## UPDATE CACHE WITH ACTIVE PROJECT
+                asyncio.create_task(
+                    self.internal_usage_cache.async_set_cache_sadd(  # this is a set
+                        model=data["model"],  # type: ignore
+                        value=[user_api_key_dict.token or "default_key"],
+                    )
+                )
+        return None
+
+    async def async_post_call_success_hook(
+        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
+    ):
+        try:
+            if isinstance(response, ModelResponse):
+                model_info = self.llm_router.get_model_info(
+                    id=response._hidden_params["model_id"]
+                )
+                assert (
+                    model_info is not None
+                ), "Model info for model with id={} is None".format(
+                    response._hidden_params["model_id"]
+                )
+                key_priority: Optional[str] = user_api_key_dict.metadata.get(
+                    "priority", None
+                )
+                available_tpm, available_rpm, model_tpm, model_rpm, active_projects = (
+                    await self.check_available_usage(
+                        model=model_info["model_name"], priority=key_priority
+                    )
+                )
+                response._hidden_params["additional_headers"] = (
+                    {  # Add additional response headers - easier debugging
+                        "x-litellm-model_group": model_info["model_name"],
+                        "x-ratelimit-remaining-litellm-project-tokens": available_tpm,
+                        "x-ratelimit-remaining-litellm-project-requests": available_rpm,
+                        "x-ratelimit-remaining-model-tokens": model_tpm,
+                        "x-ratelimit-remaining-model-requests": model_rpm,
+                        "x-ratelimit-current-active-projects": active_projects,
+                    }
+                )
+
+                return response
+            return await super().async_post_call_success_hook(
+                data=data,
+                user_api_key_dict=user_api_key_dict,
+                response=response,
+            )
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occurred - {}".format(
+                    str(e)
+                )
+            )
+            return response
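A worked example of the quota split performed by `check_available_usage` above, with illustrative numbers: a model group capped at 1,000,000 TPM, 400,000 TPM already used this minute, a licensed priority weight of 0.75, and 3 active projects.

# Illustrative numbers only - mirrors the arithmetic in check_available_usage().
total_model_tpm = 1_000_000    # ModelGroupInfo.tpm for the model group
current_model_tpm = 400_000    # usage recorded so far this minute
weight = 0.75                  # litellm.priority_reservation[priority] (premium feature)
active_projects = 3            # distinct keys that called this model this minute

remaining_model_tpm = total_model_tpm - current_model_tpm            # 600_000
available_tpm = int(remaining_model_tpm * weight / active_projects)  # per-project share
assert available_tpm == 150_000

# available_tpm == 0 (or available_rpm == 0) makes async_pre_call_hook raise HTTP 429.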
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/example_presidio_ad_hoc_recognizer.json b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/example_presidio_ad_hoc_recognizer.json
new file mode 100644
index 00000000..6a94d8de
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/example_presidio_ad_hoc_recognizer.json
@@ -0,0 +1,28 @@
+[
+    {
+        "name": "Zip code Recognizer",
+        "supported_language": "en",
+        "patterns": [
+            {
+                "name": "zip code (weak)",
+                "regex": "(\\b\\d{5}(?:\\-\\d{4})?\\b)",
+                "score": 0.01
+            }
+        ],
+        "context": ["zip", "code"],
+        "supported_entity": "ZIP"
+    },
+    {
+        "name": "Swiss AHV Number Recognizer",
+        "supported_language": "en",
+        "patterns": [
+            {
+                "name": "AHV number (strong)",
+                "regex": "(756\\.\\d{4}\\.\\d{4}\\.\\d{2})|(756\\d{10})",
+                "score": 0.95
+            }
+        ],
+        "context": ["AHV", "social security", "Swiss"],
+        "supported_entity": "AHV_NUMBER"
+    }
+]
\ No newline at end of file
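The recognizer patterns above can be sanity-checked outside Presidio by running the regexes directly; the sample values below are made up.

import re

ZIP_PATTERN = r"(\b\d{5}(?:\-\d{4})?\b)"
AHV_PATTERN = r"(756\.\d{4}\.\d{4}\.\d{2})|(756\d{10})"

assert re.search(ZIP_PATTERN, "Ship to 90210-1234, please")   # weak match (score 0.01)
assert re.search(AHV_PATTERN, "AHV: 756.1234.5678.97")        # strong match (score 0.95)
assert re.search(AHV_PATTERN, "7561234567897")                # unformatted variant
assert not re.search(AHV_PATTERN, "755.1234.5678.97")         # wrong prefix, no match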
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py
new file mode 100644
index 00000000..2030cb2a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/key_management_event_hooks.py
@@ -0,0 +1,324 @@
+import asyncio
+import json
+import uuid
+from datetime import datetime, timezone
+from typing import Any, List, Optional
+
+from fastapi import status
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import (
+    GenerateKeyRequest,
+    GenerateKeyResponse,
+    KeyRequest,
+    LiteLLM_AuditLogs,
+    LiteLLM_VerificationToken,
+    LitellmTableNames,
+    ProxyErrorTypes,
+    ProxyException,
+    RegenerateKeyRequest,
+    UpdateKeyRequest,
+    UserAPIKeyAuth,
+    WebhookEvent,
+)
+
+# NOTE: This is the prefix for all virtual keys stored in AWS Secrets Manager
+LITELLM_PREFIX_STORED_VIRTUAL_KEYS = "litellm/"
+
+
+class KeyManagementEventHooks:
+
+    @staticmethod
+    async def async_key_generated_hook(
+        data: GenerateKeyRequest,
+        response: GenerateKeyResponse,
+        user_api_key_dict: UserAPIKeyAuth,
+        litellm_changed_by: Optional[str] = None,
+    ):
+        """
+        Hook that runs after a successful /key/generate request
+
+        Handles the following:
+        - Sending Email with Key Details
+        - Storing Audit Logs for key generation
+        - Storing Generated Key in DB
+        """
+        from litellm.proxy.management_helpers.audit_logs import (
+            create_audit_log_for_update,
+        )
+        from litellm.proxy.proxy_server import litellm_proxy_admin_name
+
+        if data.send_invite_email is True:
+            await KeyManagementEventHooks._send_key_created_email(
+                response.model_dump(exclude_none=True)
+            )
+
+        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+        if litellm.store_audit_logs is True:
+            _updated_values = response.model_dump_json(exclude_none=True)
+            asyncio.create_task(
+                create_audit_log_for_update(
+                    request_data=LiteLLM_AuditLogs(
+                        id=str(uuid.uuid4()),
+                        updated_at=datetime.now(timezone.utc),
+                        changed_by=litellm_changed_by
+                        or user_api_key_dict.user_id
+                        or litellm_proxy_admin_name,
+                        changed_by_api_key=user_api_key_dict.api_key,
+                        table_name=LitellmTableNames.KEY_TABLE_NAME,
+                        object_id=response.token_id or "",
+                        action="created",
+                        updated_values=_updated_values,
+                        before_value=None,
+                    )
+                )
+            )
+        # store the generated key in the secret manager
+        await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
+            secret_name=data.key_alias or f"virtual-key-{response.token_id}",
+            secret_token=response.key,
+        )
+
+    @staticmethod
+    async def async_key_updated_hook(
+        data: UpdateKeyRequest,
+        existing_key_row: Any,
+        response: Any,
+        user_api_key_dict: UserAPIKeyAuth,
+        litellm_changed_by: Optional[str] = None,
+    ):
+        """
+        Post /key/update processing hook
+
+        Handles the following:
+        - Storing Audit Logs for key update
+        """
+        from litellm.proxy.management_helpers.audit_logs import (
+            create_audit_log_for_update,
+        )
+        from litellm.proxy.proxy_server import litellm_proxy_admin_name
+
+        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+        if litellm.store_audit_logs is True:
+            _updated_values = json.dumps(data.json(exclude_none=True), default=str)
+
+            _before_value = existing_key_row.json(exclude_none=True)
+            _before_value = json.dumps(_before_value, default=str)
+
+            asyncio.create_task(
+                create_audit_log_for_update(
+                    request_data=LiteLLM_AuditLogs(
+                        id=str(uuid.uuid4()),
+                        updated_at=datetime.now(timezone.utc),
+                        changed_by=litellm_changed_by
+                        or user_api_key_dict.user_id
+                        or litellm_proxy_admin_name,
+                        changed_by_api_key=user_api_key_dict.api_key,
+                        table_name=LitellmTableNames.KEY_TABLE_NAME,
+                        object_id=data.key,
+                        action="updated",
+                        updated_values=_updated_values,
+                        before_value=_before_value,
+                    )
+                )
+            )
+
+    @staticmethod
+    async def async_key_rotated_hook(
+        data: Optional[RegenerateKeyRequest],
+        existing_key_row: Any,
+        response: GenerateKeyResponse,
+        user_api_key_dict: UserAPIKeyAuth,
+        litellm_changed_by: Optional[str] = None,
+    ):
+        # store the generated key in the secret manager
+        if data is not None and response.token_id is not None:
+            initial_secret_name = (
+                existing_key_row.key_alias or f"virtual-key-{existing_key_row.token}"
+            )
+            await KeyManagementEventHooks._rotate_virtual_key_in_secret_manager(
+                current_secret_name=initial_secret_name,
+                new_secret_name=data.key_alias or f"virtual-key-{response.token_id}",
+                new_secret_value=response.key,
+            )
+
+    @staticmethod
+    async def async_key_deleted_hook(
+        data: KeyRequest,
+        keys_being_deleted: List[LiteLLM_VerificationToken],
+        response: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        litellm_changed_by: Optional[str] = None,
+    ):
+        """
+        Post /key/delete processing hook
+
+        Handles the following:
+        - Storing Audit Logs for key deletion
+        """
+        from litellm.proxy.management_helpers.audit_logs import (
+            create_audit_log_for_update,
+        )
+        from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
+
+        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+        # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
+        if litellm.store_audit_logs is True and data.keys is not None:
+            # make an audit log for each team deleted
+            for key in data.keys:
+                key_row = await prisma_client.get_data(  # type: ignore
+                    token=key, table_name="key", query_type="find_unique"
+                )
+
+                if key_row is None:
+                    raise ProxyException(
+                        message=f"Key {key} not found",
+                        type=ProxyErrorTypes.bad_request_error,
+                        param="key",
+                        code=status.HTTP_404_NOT_FOUND,
+                    )
+
+                key_row = key_row.json(exclude_none=True)
+                _key_row = json.dumps(key_row, default=str)
+
+                asyncio.create_task(
+                    create_audit_log_for_update(
+                        request_data=LiteLLM_AuditLogs(
+                            id=str(uuid.uuid4()),
+                            updated_at=datetime.now(timezone.utc),
+                            changed_by=litellm_changed_by
+                            or user_api_key_dict.user_id
+                            or litellm_proxy_admin_name,
+                            changed_by_api_key=user_api_key_dict.api_key,
+                            table_name=LitellmTableNames.KEY_TABLE_NAME,
+                            object_id=key,
+                            action="deleted",
+                            updated_values="{}",
+                            before_value=_key_row,
+                        )
+                    )
+                )
+        # delete the keys from the secret manager
+        await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
+            keys_being_deleted=keys_being_deleted
+        )
+        pass
+
+    @staticmethod
+    async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
+        """
+        Store a virtual key in the secret manager
+
+        Args:
+            secret_name: Name of the virtual key
+            secret_token: Value of the virtual key (example: sk-1234)
+        """
+        if litellm._key_management_settings is not None:
+            if litellm._key_management_settings.store_virtual_keys is True:
+                from litellm.secret_managers.base_secret_manager import (
+                    BaseSecretManager,
+                )
+
+                # store the key in the secret manager
+                if isinstance(litellm.secret_manager_client, BaseSecretManager):
+                    await litellm.secret_manager_client.async_write_secret(
+                        secret_name=KeyManagementEventHooks._get_secret_name(
+                            secret_name
+                        ),
+                        secret_value=secret_token,
+                    )
+
+    @staticmethod
+    async def _rotate_virtual_key_in_secret_manager(
+        current_secret_name: str, new_secret_name: str, new_secret_value: str
+    ):
+        """
+        Update a virtual key in the secret manager
+
+            current_secret_name: Current name of the virtual key in the secret manager
+            new_secret_name / new_secret_value: New name and new value (example: sk-1234) of the virtual key
+            secret_token: Value of the virtual key (example: sk-1234)
+        """
+        if litellm._key_management_settings is not None:
+            if litellm._key_management_settings.store_virtual_keys is True:
+                from litellm.secret_managers.base_secret_manager import (
+                    BaseSecretManager,
+                )
+
+                # store the key in the secret manager
+                if isinstance(litellm.secret_manager_client, BaseSecretManager):
+                    await litellm.secret_manager_client.async_rotate_secret(
+                        current_secret_name=KeyManagementEventHooks._get_secret_name(
+                            current_secret_name
+                        ),
+                        new_secret_name=KeyManagementEventHooks._get_secret_name(
+                            new_secret_name
+                        ),
+                        new_secret_value=new_secret_value,
+                    )
+
+    @staticmethod
+    def _get_secret_name(secret_name: str) -> str:
+        if litellm._key_management_settings.prefix_for_stored_virtual_keys.endswith(
+            "/"
+        ):
+            return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{secret_name}"
+        else:
+            return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}/{secret_name}"
+
+    @staticmethod
+    async def _delete_virtual_keys_from_secret_manager(
+        keys_being_deleted: List[LiteLLM_VerificationToken],
+    ):
+        """
+        Deletes virtual keys from the secret manager
+
+        Args:
+            keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
+        """
+        if litellm._key_management_settings is not None:
+            if litellm._key_management_settings.store_virtual_keys is True:
+                from litellm.secret_managers.base_secret_manager import (
+                    BaseSecretManager,
+                )
+
+                if isinstance(litellm.secret_manager_client, BaseSecretManager):
+                    for key in keys_being_deleted:
+                        if key.key_alias is not None:
+                            await litellm.secret_manager_client.async_delete_secret(
+                                secret_name=KeyManagementEventHooks._get_secret_name(
+                                    key.key_alias
+                                )
+                            )
+                        else:
+                            verbose_proxy_logger.warning(
+                                f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
+                            )
+
+    @staticmethod
+    async def _send_key_created_email(response: dict):
+        from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
+
+        if "email" not in general_settings.get("alerting", []):
+            raise ValueError(
+                "Email alerting not set up on config.yaml. Please set `alerting=['email']`.\nDocs: https://docs.litellm.ai/docs/proxy/email"
+            )
+        event = WebhookEvent(
+            event="key_created",
+            event_group="key",
+            event_message="API Key Created",
+            token=response.get("token", ""),
+            spend=response.get("spend", 0.0),
+            max_budget=response.get("max_budget", 0.0),
+            user_id=response.get("user_id", None),
+            team_id=response.get("team_id", "Default Team"),
+            key_alias=response.get("key_alias", None),
+        )
+
+        # If user configured email alerting - send an Email letting their end-user know the key was created
+        asyncio.create_task(
+            proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
+                webhook_event=event,
+            )
+        )
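For reference, the secret name written by these hooks is the configured prefix joined with the key alias (or a `virtual-key-<token_id>` fallback). The helper below restates that rule for illustration; the prefix value is taken from the `LITELLM_PREFIX_STORED_VIRTUAL_KEYS` constant above, and everything else is hypothetical.

from typing import Optional

LITELLM_PREFIX_STORED_VIRTUAL_KEYS = "litellm/"  # same default as the module constant

def secret_name_for(
    key_alias: Optional[str],
    token_id: str,
    prefix: str = LITELLM_PREFIX_STORED_VIRTUAL_KEYS,
) -> str:
    # Mirrors _get_secret_name() plus the alias fallback in async_key_generated_hook().
    base = key_alias or f"virtual-key-{token_id}"
    return f"{prefix}{base}" if prefix.endswith("/") else f"{prefix}/{base}"

assert secret_name_for("prod-search-key", "abc123") == "litellm/prod-search-key"
assert secret_name_for(None, "abc123") == "litellm/virtual-key-abc123"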
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/max_budget_limiter.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/max_budget_limiter.py
new file mode 100644
index 00000000..4b59f603
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/max_budget_limiter.py
@@ -0,0 +1,49 @@
+from fastapi import HTTPException
+
+from litellm import verbose_logger
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+class _PROXY_MaxBudgetLimiter(CustomLogger):
+    # Class variables or attributes
+    def __init__(self):
+        pass
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+    ):
+        try:
+            verbose_proxy_logger.debug("Inside Max Budget Limiter Pre-Call Hook")
+            cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
+            user_row = await cache.async_get_cache(
+                cache_key, parent_otel_span=user_api_key_dict.parent_otel_span
+            )
+            if user_row is None:  # value not yet cached
+                return
+            max_budget = user_row["max_budget"]
+            curr_spend = user_row["spend"]
+
+            if max_budget is None:
+                return
+
+            if curr_spend is None:
+                return
+
+            # CHECK IF REQUEST ALLOWED
+            if curr_spend >= max_budget:
+                raise HTTPException(status_code=429, detail="Max budget limit reached.")
+        except HTTPException as e:
+            raise e
+        except Exception as e:
+            verbose_logger.exception(
+                "litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occurred - {}".format(
+                    str(e)
+                )
+            )
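The cached user row read above is expected to carry at least `spend` and `max_budget`; a small illustration of the gate, with made-up values:

user_row = {"spend": 101.25, "max_budget": 100.0}  # hypothetical cached row

max_budget = user_row["max_budget"]
curr_spend = user_row["spend"]

if max_budget is not None and curr_spend is not None and curr_spend >= max_budget:
    # The real hook raises fastapi.HTTPException(status_code=429,
    # detail="Max budget limit reached.")
    print("request rejected: max budget limit reached")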
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/model_max_budget_limiter.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/model_max_budget_limiter.py
new file mode 100644
index 00000000..ac02c915
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/model_max_budget_limiter.py
@@ -0,0 +1,192 @@
+import json
+from typing import List, Optional
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import Span
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import (
+    BudgetConfig,
+    GenericBudgetConfigType,
+    StandardLoggingPayload,
+)
+
+VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"
+
+
+class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
+    """
+    Handles budgets for model + virtual key
+
+    Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
+    """
+
+    def __init__(self, dual_cache: DualCache):
+        self.dual_cache = dual_cache
+        self.redis_increment_operation_queue = []
+
+    async def is_key_within_model_budget(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        model: str,
+    ) -> bool:
+        """
+        Check if the user_api_key_dict is within the model budget
+
+        Raises:
+            BudgetExceededError: If the user_api_key_dict has exceeded the model budget
+        """
+        _model_max_budget = user_api_key_dict.model_max_budget
+        internal_model_max_budget: GenericBudgetConfigType = {}
+
+        for _model, _budget_info in _model_max_budget.items():
+            internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
+
+        verbose_proxy_logger.debug(
+            "internal_model_max_budget %s",
+            json.dumps(internal_model_max_budget, indent=4, default=str),
+        )
+
+        # check if current model is in internal_model_max_budget
+        _current_model_budget_info = self._get_request_model_budget_config(
+            model=model, internal_model_max_budget=internal_model_max_budget
+        )
+        if _current_model_budget_info is None:
+            verbose_proxy_logger.debug(
+                f"Model {model} not found in internal_model_max_budget"
+            )
+            return True
+
+        # check if current model is within budget
+        if (
+            _current_model_budget_info.max_budget
+            and _current_model_budget_info.max_budget > 0
+        ):
+            _current_spend = await self._get_virtual_key_spend_for_model(
+                user_api_key_hash=user_api_key_dict.token,
+                model=model,
+                key_budget_config=_current_model_budget_info,
+            )
+            if (
+                _current_spend is not None
+                and _current_model_budget_info.max_budget is not None
+                and _current_spend > _current_model_budget_info.max_budget
+            ):
+                raise litellm.BudgetExceededError(
+                    message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}",
+                    current_cost=_current_spend,
+                    max_budget=_current_model_budget_info.max_budget,
+                )
+
+        return True
+
+    async def _get_virtual_key_spend_for_model(
+        self,
+        user_api_key_hash: Optional[str],
+        model: str,
+        key_budget_config: BudgetConfig,
+    ) -> Optional[float]:
+        """
+        Get the current spend for a virtual key for a model
+
+        Lookup model in this order:
+            1. model: directly look up `model`
+            2. If 1 does not exist, check if the model was passed as {custom_llm_provider}/model
+        """
+
+        # 1. model: directly look up `model`
+        virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.budget_duration}"
+        _current_spend = await self.dual_cache.async_get_cache(
+            key=virtual_key_model_spend_cache_key,
+        )
+
+        if _current_spend is None:
+            # 2. If 1 does not exist, check if the model was passed as {custom_llm_provider}/model
+            # if "/" in model, remove the part before "/" - e.g. openai/o1-preview -> o1-preview
+            virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.budget_duration}"
+            _current_spend = await self.dual_cache.async_get_cache(
+                key=virtual_key_model_spend_cache_key,
+            )
+        return _current_spend
+
+    def _get_request_model_budget_config(
+        self, model: str, internal_model_max_budget: GenericBudgetConfigType
+    ) -> Optional[BudgetConfig]:
+        """
+        Get the budget config for the request model
+
+        1. Check if `model` is in `internal_model_max_budget`
+        2. If not, check if `model` without custom llm provider is in `internal_model_max_budget`
+        """
+        return internal_model_max_budget.get(
+            model, None
+        ) or internal_model_max_budget.get(
+            self._get_model_without_custom_llm_provider(model), None
+        )
+
+    def _get_model_without_custom_llm_provider(self, model: str) -> str:
+        if "/" in model:
+            return model.split("/")[-1]
+        return model
+
+    async def async_filter_deployments(
+        self,
+        model: str,
+        healthy_deployments: List,
+        messages: Optional[List[AllMessageValues]],
+        request_kwargs: Optional[dict] = None,
+        parent_otel_span: Optional[Span] = None,  # type: ignore
+    ) -> List[dict]:
+        return healthy_deployments
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        """
+        Track spend for virtual key + model in DualCache
+
+        Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
+        """
+        verbose_proxy_logger.debug("in RouterBudgetLimiting.async_log_success_event")
+        standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+            "standard_logging_object", None
+        )
+        if standard_logging_payload is None:
+            raise ValueError("standard_logging_payload is required")
+
+        _litellm_params: dict = kwargs.get("litellm_params", {}) or {}
+        _metadata: dict = _litellm_params.get("metadata", {}) or {}
+        user_api_key_model_max_budget: Optional[dict] = _metadata.get(
+            "user_api_key_model_max_budget", None
+        )
+        if (
+            user_api_key_model_max_budget is None
+            or len(user_api_key_model_max_budget) == 0
+        ):
+            verbose_proxy_logger.debug(
+                "Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget is None or empty. `user_api_key_model_max_budget`=%s",
+                user_api_key_model_max_budget,
+            )
+            return
+        response_cost: float = standard_logging_payload.get("response_cost", 0)
+        model = standard_logging_payload.get("model")
+
+        virtual_key = standard_logging_payload.get("metadata").get("user_api_key_hash")
+        model = standard_logging_payload.get("model")
+        if virtual_key is not None:
+            budget_config = BudgetConfig(time_period="1d", budget_limit=0.1)
+            virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_config.budget_duration}"
+            virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
+            await self._increment_spend_for_key(
+                budget_config=budget_config,
+                spend_key=virtual_spend_key,
+                start_time_key=virtual_start_time_key,
+                response_cost=response_cost,
+            )
+        verbose_proxy_logger.debug(
+            "current state of in memory cache %s",
+            json.dumps(
+                self.dual_cache.in_memory_cache.cache_dict, indent=4, default=str
+            ),
+        )
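For context, `user_api_key_dict.model_max_budget` is expected to map model names to per-model budget settings, and spend is tracked under `virtual_key_spend:<key_hash>:<model>:<duration>` cache keys as above. The sketch below is illustrative; the field names are an assumption about `BudgetConfig`'s schema (the code above reads `max_budget` and `budget_duration`), and the key hash is made up.

# Hypothetical per-model budgets attached to one virtual key.
model_max_budget = {
    "gpt-4o": {"max_budget": 100.0, "budget_duration": "1d"},
    "openai/o1-preview": {"max_budget": 25.0, "budget_duration": "7d"},
}

# Spend for (key, model, window) lives under a cache key shaped like this:
key_hash = "88a3e1"  # hypothetical hashed virtual key
spend_cache_key = f"virtual_key_spend:{key_hash}:gpt-4o:1d"

# is_key_within_model_budget() raises litellm.BudgetExceededError once the cached
# spend under that key exceeds the model's max_budget.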
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/parallel_request_limiter.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/parallel_request_limiter.py
new file mode 100644
index 00000000..06f3b6af
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/parallel_request_limiter.py
@@ -0,0 +1,866 @@
+import asyncio
+import sys
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union
+
+from fastapi import HTTPException
+from pydantic import BaseModel
+
+import litellm
+from litellm import DualCache, ModelResponse
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
+from litellm.proxy._types import CommonProxyErrors, CurrentItemRateLimit, UserAPIKeyAuth
+from litellm.proxy.auth.auth_utils import (
+    get_key_model_rpm_limit,
+    get_key_model_tpm_limit,
+)
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
+
+    Span = _Span
+    InternalUsageCache = _InternalUsageCache
+else:
+    Span = Any
+    InternalUsageCache = Any
+
+
+class CacheObject(TypedDict):
+    current_global_requests: Optional[dict]
+    request_count_api_key: Optional[dict]
+    request_count_api_key_model: Optional[dict]
+    request_count_user_id: Optional[dict]
+    request_count_team_id: Optional[dict]
+    request_count_end_user_id: Optional[dict]
+
+
+class _PROXY_MaxParallelRequestsHandler(CustomLogger):
+    # Class variables or attributes
+    def __init__(self, internal_usage_cache: InternalUsageCache):
+        self.internal_usage_cache = internal_usage_cache
+
+    def print_verbose(self, print_statement):
+        try:
+            verbose_proxy_logger.debug(print_statement)
+            if litellm.set_verbose:
+                print(print_statement)  # noqa
+        except Exception:
+            pass
+
+    async def check_key_in_limits(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+        max_parallel_requests: int,
+        tpm_limit: int,
+        rpm_limit: int,
+        current: Optional[dict],
+        request_count_api_key: str,
+        rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
+        values_to_update_in_cache: List[Tuple[Any, Any]],
+    ) -> dict:
+        verbose_proxy_logger.info(
+            f"Current Usage of {rate_limit_type} in this minute: {current}"
+        )
+        if current is None:
+            if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
+                # base case
+                raise self.raise_rate_limit_error(
+                    additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
+                )
+            new_val = {
+                "current_requests": 1,
+                "current_tpm": 0,
+                "current_rpm": 1,
+            }
+            values_to_update_in_cache.append((request_count_api_key, new_val))
+        elif (
+            int(current["current_requests"]) < max_parallel_requests
+            and current["current_tpm"] < tpm_limit
+            and current["current_rpm"] < rpm_limit
+        ):
+            # Increase count for this token
+            new_val = {
+                "current_requests": current["current_requests"] + 1,
+                "current_tpm": current["current_tpm"],
+                "current_rpm": current["current_rpm"] + 1,
+            }
+            values_to_update_in_cache.append((request_count_api_key, new_val))
+
+        else:
+            raise HTTPException(
+                status_code=429,
+                detail=f"LiteLLM Rate Limit Handler for rate limit type = {rate_limit_type}. {CommonProxyErrors.max_parallel_request_limit_reached.value}. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}, current max_parallel_requests: {current['current_requests']}, max_parallel_requests: {max_parallel_requests}",
+                headers={"retry-after": str(self.time_to_next_minute())},
+            )
+
+        await self.internal_usage_cache.async_batch_set_cache(
+            cache_list=values_to_update_in_cache,
+            ttl=60,
+            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+            local_only=True,
+        )
+        return new_val
+
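+    # Illustrative evolution of the counter kept by check_key_in_limits above, for a
+    # key with max_parallel_requests=2, tpm_limit=1000, rpm_limit=10 (assumed values)
+    # and three requests arriving in the same minute before any of them completes:
+    #   1st request -> {"current_requests": 1, "current_tpm": 0, "current_rpm": 1}
+    #   2nd request -> {"current_requests": 2, "current_tpm": 0, "current_rpm": 2}
+    #   3rd request -> HTTP 429 with a `retry-after` header
+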
+    def time_to_next_minute(self) -> float:
+        # Get the current time
+        now = datetime.now()
+
+        # Calculate the next minute
+        next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
+
+        # Calculate the difference in seconds
+        seconds_to_next_minute = (next_minute - now).total_seconds()
+
+        return seconds_to_next_minute
+
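+    # A quick sketch of what time_to_next_minute above returns (values illustrative):
+    #   >>> now = datetime(2024, 1, 1, 12, 30, 45)
+    #   >>> next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
+    #   >>> (next_minute - now).total_seconds()
+    #   15.0
+    # This is the value sent back in the `retry-after` header of the 429 responses below.
+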
+    def raise_rate_limit_error(
+        self, additional_details: Optional[str] = None
+    ) -> HTTPException:
+        """
+        Raise an HTTPException with a 429 status code and a retry-after header
+        """
+        error_message = "Max parallel request limit reached"
+        if additional_details is not None:
+            error_message = error_message + " " + additional_details
+        raise HTTPException(
+            status_code=429,
+            detail=f"Max parallel request limit reached {additional_details}",
+            headers={"retry-after": str(self.time_to_next_minute())},
+        )
+
+    async def get_all_cache_objects(
+        self,
+        current_global_requests: Optional[str],
+        request_count_api_key: Optional[str],
+        request_count_api_key_model: Optional[str],
+        request_count_user_id: Optional[str],
+        request_count_team_id: Optional[str],
+        request_count_end_user_id: Optional[str],
+        parent_otel_span: Optional[Span] = None,
+    ) -> CacheObject:
+        keys = [
+            current_global_requests,
+            request_count_api_key,
+            request_count_api_key_model,
+            request_count_user_id,
+            request_count_team_id,
+            request_count_end_user_id,
+        ]
+        results = await self.internal_usage_cache.async_batch_get_cache(
+            keys=keys,
+            parent_otel_span=parent_otel_span,
+        )
+
+        if results is None:
+            return CacheObject(
+                current_global_requests=None,
+                request_count_api_key=None,
+                request_count_api_key_model=None,
+                request_count_user_id=None,
+                request_count_team_id=None,
+                request_count_end_user_id=None,
+            )
+
+        return CacheObject(
+            current_global_requests=results[0],
+            request_count_api_key=results[1],
+            request_count_api_key_model=results[2],
+            request_count_user_id=results[3],
+            request_count_team_id=results[4],
+            request_count_end_user_id=results[5],
+        )
+
+    async def async_pre_call_hook(  # noqa: PLR0915
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+    ):
+        self.print_verbose("Inside Max Parallel Request Pre-Call Hook")
+        api_key = user_api_key_dict.api_key
+        max_parallel_requests = user_api_key_dict.max_parallel_requests
+        if max_parallel_requests is None:
+            max_parallel_requests = sys.maxsize
+        if data is None:
+            data = {}
+        global_max_parallel_requests = data.get("metadata", {}).get(
+            "global_max_parallel_requests", None
+        )
+        tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
+        if tpm_limit is None:
+            tpm_limit = sys.maxsize
+        rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
+        if rpm_limit is None:
+            rpm_limit = sys.maxsize
+
+        values_to_update_in_cache: List[Tuple[Any, Any]] = (
+            []
+        )  # values that need to be updated in the cache; a single batch_set_cache call flushes them at the end of this hook
+
+        # ------------
+        # Setup values
+        # ------------
+        new_val: Optional[dict] = None
+
+        if global_max_parallel_requests is not None:
+            # get value from cache
+            _key = "global_max_parallel_requests"
+            current_global_requests = await self.internal_usage_cache.async_get_cache(
+                key=_key,
+                local_only=True,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+            )
+            # check if below limit
+            if current_global_requests is None:
+                current_global_requests = 1
+            # if above -> raise error
+            if current_global_requests >= global_max_parallel_requests:
+                return self.raise_rate_limit_error(
+                    additional_details=f"Hit Global Limit: Limit={global_max_parallel_requests}, current: {current_global_requests}"
+                )
+            # if below -> increment
+            else:
+                await self.internal_usage_cache.async_increment_cache(
+                    key=_key,
+                    value=1,
+                    local_only=True,
+                    litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+                )
+        _model = data.get("model", None)
+
+        current_date = datetime.now().strftime("%Y-%m-%d")
+        current_hour = datetime.now().strftime("%H")
+        current_minute = datetime.now().strftime("%M")
+        precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+        cache_objects: CacheObject = await self.get_all_cache_objects(
+            current_global_requests=(
+                "global_max_parallel_requests"
+                if global_max_parallel_requests is not None
+                else None
+            ),
+            request_count_api_key=(
+                f"{api_key}::{precise_minute}::request_count"
+                if api_key is not None
+                else None
+            ),
+            request_count_api_key_model=(
+                f"{api_key}::{_model}::{precise_minute}::request_count"
+                if api_key is not None and _model is not None
+                else None
+            ),
+            request_count_user_id=(
+                f"{user_api_key_dict.user_id}::{precise_minute}::request_count"
+                if user_api_key_dict.user_id is not None
+                else None
+            ),
+            request_count_team_id=(
+                f"{user_api_key_dict.team_id}::{precise_minute}::request_count"
+                if user_api_key_dict.team_id is not None
+                else None
+            ),
+            request_count_end_user_id=(
+                f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
+                if user_api_key_dict.end_user_id is not None
+                else None
+            ),
+            parent_otel_span=user_api_key_dict.parent_otel_span,
+        )
+        if api_key is not None:
+            request_count_api_key = f"{api_key}::{precise_minute}::request_count"
+            # CHECK IF REQUEST ALLOWED for key
+            await self.check_key_in_limits(
+                user_api_key_dict=user_api_key_dict,
+                cache=cache,
+                data=data,
+                call_type=call_type,
+                max_parallel_requests=max_parallel_requests,
+                current=cache_objects["request_count_api_key"],
+                request_count_api_key=request_count_api_key,
+                tpm_limit=tpm_limit,
+                rpm_limit=rpm_limit,
+                rate_limit_type="key",
+                values_to_update_in_cache=values_to_update_in_cache,
+            )
+
+        # Check if request under RPM/TPM per model for a given API Key
+        if (
+            get_key_model_tpm_limit(user_api_key_dict) is not None
+            or get_key_model_rpm_limit(user_api_key_dict) is not None
+        ):
+            _model = data.get("model", None)
+            request_count_api_key = (
+                f"{api_key}::{_model}::{precise_minute}::request_count"
+            )
+            _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
+            _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
+            tpm_limit_for_model = None
+            rpm_limit_for_model = None
+
+            if _model is not None:
+                if _tpm_limit_for_key_model:
+                    tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
+
+                if _rpm_limit_for_key_model:
+                    rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
+
+            new_val = await self.check_key_in_limits(
+                user_api_key_dict=user_api_key_dict,
+                cache=cache,
+                data=data,
+                call_type=call_type,
+                max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a model
+                current=cache_objects["request_count_api_key_model"],
+                request_count_api_key=request_count_api_key,
+                tpm_limit=tpm_limit_for_model or sys.maxsize,
+                rpm_limit=rpm_limit_for_model or sys.maxsize,
+                rate_limit_type="model_per_key",
+                values_to_update_in_cache=values_to_update_in_cache,
+            )
+            _remaining_tokens = None
+            _remaining_requests = None
+            # Add remaining tokens, requests to metadata
+            if new_val:
+                if tpm_limit_for_model is not None:
+                    _remaining_tokens = tpm_limit_for_model - new_val["current_tpm"]
+                if rpm_limit_for_model is not None:
+                    _remaining_requests = rpm_limit_for_model - new_val["current_rpm"]
+
+            _remaining_limits_data = {
+                f"litellm-key-remaining-tokens-{_model}": _remaining_tokens,
+                f"litellm-key-remaining-requests-{_model}": _remaining_requests,
+            }
+
+            if "metadata" not in data:
+                data["metadata"] = {}
+            data["metadata"].update(_remaining_limits_data)
+
+        # check if REQUEST ALLOWED for user_id
+        user_id = user_api_key_dict.user_id
+        if user_id is not None:
+            user_tpm_limit = user_api_key_dict.user_tpm_limit
+            user_rpm_limit = user_api_key_dict.user_rpm_limit
+            if user_tpm_limit is None:
+                user_tpm_limit = sys.maxsize
+            if user_rpm_limit is None:
+                user_rpm_limit = sys.maxsize
+
+            request_count_api_key = f"{user_id}::{precise_minute}::request_count"
+            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+            await self.check_key_in_limits(
+                user_api_key_dict=user_api_key_dict,
+                cache=cache,
+                data=data,
+                call_type=call_type,
+                max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a user
+                current=cache_objects["request_count_user_id"],
+                request_count_api_key=request_count_api_key,
+                tpm_limit=user_tpm_limit,
+                rpm_limit=user_rpm_limit,
+                rate_limit_type="user",
+                values_to_update_in_cache=values_to_update_in_cache,
+            )
+
+        # TEAM RATE LIMITS
+        ## get team tpm/rpm limits
+        team_id = user_api_key_dict.team_id
+        if team_id is not None:
+            team_tpm_limit = user_api_key_dict.team_tpm_limit
+            team_rpm_limit = user_api_key_dict.team_rpm_limit
+
+            if team_tpm_limit is None:
+                team_tpm_limit = sys.maxsize
+            if team_rpm_limit is None:
+                team_rpm_limit = sys.maxsize
+
+            request_count_api_key = f"{team_id}::{precise_minute}::request_count"
+            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+            await self.check_key_in_limits(
+                user_api_key_dict=user_api_key_dict,
+                cache=cache,
+                data=data,
+                call_type=call_type,
+                max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a team
+                current=cache_objects["request_count_team_id"],
+                request_count_api_key=request_count_api_key,
+                tpm_limit=team_tpm_limit,
+                rpm_limit=team_rpm_limit,
+                rate_limit_type="team",
+                values_to_update_in_cache=values_to_update_in_cache,
+            )
+
+        # End-User Rate Limits
+        # Only enforce if user passed `user` to /chat, /completions, /embeddings
+        if user_api_key_dict.end_user_id:
+            end_user_tpm_limit = getattr(
+                user_api_key_dict, "end_user_tpm_limit", sys.maxsize
+            )
+            end_user_rpm_limit = getattr(
+                user_api_key_dict, "end_user_rpm_limit", sys.maxsize
+            )
+
+            if end_user_tpm_limit is None:
+                end_user_tpm_limit = sys.maxsize
+            if end_user_rpm_limit is None:
+                end_user_rpm_limit = sys.maxsize
+
+            # now do the same tpm/rpm checks
+            request_count_api_key = (
+                f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
+            )
+
+            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+            await self.check_key_in_limits(
+                user_api_key_dict=user_api_key_dict,
+                cache=cache,
+                data=data,
+                call_type=call_type,
+                max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for an End-User
+                request_count_api_key=request_count_api_key,
+                current=cache_objects["request_count_end_user_id"],
+                tpm_limit=end_user_tpm_limit,
+                rpm_limit=end_user_rpm_limit,
+                rate_limit_type="customer",
+                values_to_update_in_cache=values_to_update_in_cache,
+            )
+
+        asyncio.create_task(
+            self.internal_usage_cache.async_batch_set_cache(
+                cache_list=values_to_update_in_cache,
+                ttl=60,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+            )  # don't block execution for cache updates
+        )
+
+        return
+
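+    # Minimal usage sketch for the pre-call hook above (a hedged example, not taken
+    # from the source; it assumes UserAPIKeyAuth accepts these optional fields):
+    #   limiter = _PROXY_MaxParallelRequestsHandler(internal_usage_cache=internal_usage_cache)
+    #   await limiter.async_pre_call_hook(
+    #       user_api_key_dict=UserAPIKeyAuth(api_key="hashed-key", rpm_limit=10, tpm_limit=1000),
+    #       cache=DualCache(),
+    #       data={"model": "gpt-4o", "metadata": {}},
+    #       call_type="completion",
+    #   )
+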
+    async def async_log_success_event(  # noqa: PLR0915
+        self, kwargs, response_obj, start_time, end_time
+    ):
+        from litellm.proxy.common_utils.callback_utils import (
+            get_model_group_from_litellm_kwargs,
+        )
+
+        litellm_parent_otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs(
+            kwargs=kwargs
+        )
+        try:
+            self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
+
+            global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
+                "global_max_parallel_requests", None
+            )
+            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
+            user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
+                "user_api_key_user_id", None
+            )
+            user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
+                "user_api_key_team_id", None
+            )
+            user_api_key_model_max_budget = kwargs["litellm_params"]["metadata"].get(
+                "user_api_key_model_max_budget", None
+            )
+            user_api_key_end_user_id = kwargs.get("user")
+
+            user_api_key_metadata = (
+                kwargs["litellm_params"]["metadata"].get("user_api_key_metadata", {})
+                or {}
+            )
+
+            # ------------
+            # Setup values
+            # ------------
+
+            if global_max_parallel_requests is not None:
+                # get value from cache
+                _key = "global_max_parallel_requests"
+                # decrement
+                await self.internal_usage_cache.async_increment_cache(
+                    key=_key,
+                    value=-1,
+                    local_only=True,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
+                )
+
+            current_date = datetime.now().strftime("%Y-%m-%d")
+            current_hour = datetime.now().strftime("%H")
+            current_minute = datetime.now().strftime("%M")
+            precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+            total_tokens = 0
+
+            if isinstance(response_obj, ModelResponse):
+                total_tokens = response_obj.usage.total_tokens  # type: ignore
+
+            # ------------
+            # Update usage - API Key
+            # ------------
+
+            values_to_update_in_cache = []
+
+            if user_api_key is not None:
+                request_count_api_key = (
+                    f"{user_api_key}::{precise_minute}::request_count"
+                )
+
+                current = await self.internal_usage_cache.async_get_cache(
+                    key=request_count_api_key,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
+                ) or {
+                    "current_requests": 1,
+                    "current_tpm": 0,
+                    "current_rpm": 0,
+                }
+
+                new_val = {
+                    "current_requests": max(current["current_requests"] - 1, 0),
+                    "current_tpm": current["current_tpm"] + total_tokens,
+                    "current_rpm": current["current_rpm"],
+                }
+
+                self.print_verbose(
+                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+                )
+                values_to_update_in_cache.append((request_count_api_key, new_val))
+
+            # ------------
+            # Update usage - model group + API Key
+            # ------------
+            model_group = get_model_group_from_litellm_kwargs(kwargs)
+            if (
+                user_api_key is not None
+                and model_group is not None
+                and (
+                    "model_rpm_limit" in user_api_key_metadata
+                    or "model_tpm_limit" in user_api_key_metadata
+                    or user_api_key_model_max_budget is not None
+                )
+            ):
+                request_count_api_key = (
+                    f"{user_api_key}::{model_group}::{precise_minute}::request_count"
+                )
+
+                current = await self.internal_usage_cache.async_get_cache(
+                    key=request_count_api_key,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
+                ) or {
+                    "current_requests": 1,
+                    "current_tpm": 0,
+                    "current_rpm": 0,
+                }
+
+                new_val = {
+                    "current_requests": max(current["current_requests"] - 1, 0),
+                    "current_tpm": current["current_tpm"] + total_tokens,
+                    "current_rpm": current["current_rpm"],
+                }
+
+                self.print_verbose(
+                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+                )
+                values_to_update_in_cache.append((request_count_api_key, new_val))
+
+            # ------------
+            # Update usage - User
+            # ------------
+            if user_api_key_user_id is not None:
+                total_tokens = 0
+
+                if isinstance(response_obj, ModelResponse):
+                    total_tokens = response_obj.usage.total_tokens  # type: ignore
+
+                request_count_api_key = (
+                    f"{user_api_key_user_id}::{precise_minute}::request_count"
+                )
+
+                current = await self.internal_usage_cache.async_get_cache(
+                    key=request_count_api_key,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
+                ) or {
+                    "current_requests": 1,
+                    "current_tpm": total_tokens,
+                    "current_rpm": 1,
+                }
+
+                new_val = {
+                    "current_requests": max(current["current_requests"] - 1, 0),
+                    "current_tpm": current["current_tpm"] + total_tokens,
+                    "current_rpm": current["current_rpm"],
+                }
+
+                self.print_verbose(
+                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+                )
+                values_to_update_in_cache.append((request_count_api_key, new_val))
+
+            # ------------
+            # Update usage - Team
+            # ------------
+            if user_api_key_team_id is not None:
+                total_tokens = 0
+
+                if isinstance(response_obj, ModelResponse):
+                    total_tokens = response_obj.usage.total_tokens  # type: ignore
+
+                request_count_api_key = (
+                    f"{user_api_key_team_id}::{precise_minute}::request_count"
+                )
+
+                current = await self.internal_usage_cache.async_get_cache(
+                    key=request_count_api_key,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
+                ) or {
+                    "current_requests": 1,
+                    "current_tpm": total_tokens,
+                    "current_rpm": 1,
+                }
+
+                new_val = {
+                    "current_requests": max(current["current_requests"] - 1, 0),
+                    "current_tpm": current["current_tpm"] + total_tokens,
+                    "current_rpm": current["current_rpm"],
+                }
+
+                self.print_verbose(
+                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+                )
+                values_to_update_in_cache.append((request_count_api_key, new_val))
+
+            # ------------
+            # Update usage - End User
+            # ------------
+            if user_api_key_end_user_id is not None:
+                total_tokens = 0
+
+                if isinstance(response_obj, ModelResponse):
+                    total_tokens = response_obj.usage.total_tokens  # type: ignore
+
+                request_count_api_key = (
+                    f"{user_api_key_end_user_id}::{precise_minute}::request_count"
+                )
+
+                current = await self.internal_usage_cache.async_get_cache(
+                    key=request_count_api_key,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
+                ) or {
+                    "current_requests": 1,
+                    "current_tpm": total_tokens,
+                    "current_rpm": 1,
+                }
+
+                new_val = {
+                    "current_requests": max(current["current_requests"] - 1, 0),
+                    "current_tpm": current["current_tpm"] + total_tokens,
+                    "current_rpm": current["current_rpm"],
+                }
+
+                self.print_verbose(
+                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+                )
+                values_to_update_in_cache.append((request_count_api_key, new_val))
+
+            await self.internal_usage_cache.async_batch_set_cache(
+                cache_list=values_to_update_in_cache,
+                ttl=60,
+                litellm_parent_otel_span=litellm_parent_otel_span,
+            )
+        except Exception as e:
+            self.print_verbose(e)  # noqa
+
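+    # Illustrative before/after for the API-key counter updated by the success hook
+    # above, assuming the response reported usage.total_tokens == 42:
+    #   before: {"current_requests": 1, "current_tpm": 0,  "current_rpm": 1}
+    #   after:  {"current_requests": 0, "current_tpm": 42, "current_rpm": 1}
+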
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            self.print_verbose("Inside Max Parallel Request Failure Hook")
+            litellm_parent_otel_span: Union[Span, None] = (
+                _get_parent_otel_span_from_kwargs(kwargs=kwargs)
+            )
+            _metadata = kwargs["litellm_params"].get("metadata", {}) or {}
+            global_max_parallel_requests = _metadata.get(
+                "global_max_parallel_requests", None
+            )
+            user_api_key = _metadata.get("user_api_key", None)
+            self.print_verbose(f"user_api_key: {user_api_key}")
+            if user_api_key is None:
+                return
+
+            ## decrement call count if call failed
+            if CommonProxyErrors.max_parallel_request_limit_reached.value in str(
+                kwargs["exception"]
+            ):
+                pass  # ignore failed calls due to max limit being reached
+            else:
+                # ------------
+                # Setup values
+                # ------------
+
+                if global_max_parallel_requests is not None:
+                    # get value from cache
+                    _key = "global_max_parallel_requests"
+                    (
+                        await self.internal_usage_cache.async_get_cache(
+                            key=_key,
+                            local_only=True,
+                            litellm_parent_otel_span=litellm_parent_otel_span,
+                        )
+                    )
+                    # decrement
+                    await self.internal_usage_cache.async_increment_cache(
+                        key=_key,
+                        value=-1,
+                        local_only=True,
+                        litellm_parent_otel_span=litellm_parent_otel_span,
+                    )
+
+                current_date = datetime.now().strftime("%Y-%m-%d")
+                current_hour = datetime.now().strftime("%H")
+                current_minute = datetime.now().strftime("%M")
+                precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+                request_count_api_key = (
+                    f"{user_api_key}::{precise_minute}::request_count"
+                )
+
+                # ------------
+                # Update usage
+                # ------------
+                current = await self.internal_usage_cache.async_get_cache(
+                    key=request_count_api_key,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
+                ) or {
+                    "current_requests": 1,
+                    "current_tpm": 0,
+                    "current_rpm": 0,
+                }
+
+                new_val = {
+                    "current_requests": max(current["current_requests"] - 1, 0),
+                    "current_tpm": current["current_tpm"],
+                    "current_rpm": current["current_rpm"],
+                }
+
+                self.print_verbose(f"updated_value in failure call: {new_val}")
+                await self.internal_usage_cache.async_set_cache(
+                    request_count_api_key,
+                    new_val,
+                    ttl=60,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
+                )  # save in cache for up to 1 min.
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                "Inside Parallel Request Limiter: An exception occurred - {}".format(
+                    str(e)
+                )
+            )
+
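+    # Note on the failure hook above: only the in-flight `current_requests` counter is
+    # decremented; `current_rpm`/`current_tpm` are left as-is, so a failed request still
+    # counts toward the per-minute rate limits for that key.
+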
+    async def get_internal_user_object(
+        self,
+        user_id: str,
+        user_api_key_dict: UserAPIKeyAuth,
+    ) -> Optional[dict]:
+        """
+        Helper to get the 'Internal User Object'
+
+        It uses the `get_user_object` function from `litellm.proxy.auth.auth_checks`
+
+        We need this because the UserAPIKeyAuth object does not contain the rpm/tpm limits for a user, and reading the UserTable on every request could hurt performance.
+        """
+        from litellm._logging import verbose_proxy_logger
+        from litellm.proxy.auth.auth_checks import get_user_object
+        from litellm.proxy.proxy_server import prisma_client
+
+        try:
+            _user_id_rate_limits = await get_user_object(
+                user_id=user_id,
+                prisma_client=prisma_client,
+                user_api_key_cache=self.internal_usage_cache.dual_cache,
+                user_id_upsert=False,
+                parent_otel_span=user_api_key_dict.parent_otel_span,
+                proxy_logging_obj=None,
+            )
+
+            if _user_id_rate_limits is None:
+                return None
+
+            return _user_id_rate_limits.model_dump()
+        except Exception as e:
+            verbose_proxy_logger.debug(
+                "Parallel Request Limiter: Error getting user object", str(e)
+            )
+            return None
+
+    async def async_post_call_success_hook(
+        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
+    ):
+        """
+        Retrieve the key's remaining rate limits.
+        """
+        api_key = user_api_key_dict.api_key
+        current_date = datetime.now().strftime("%Y-%m-%d")
+        current_hour = datetime.now().strftime("%H")
+        current_minute = datetime.now().strftime("%M")
+        precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+        request_count_api_key = f"{api_key}::{precise_minute}::request_count"
+        current: Optional[CurrentItemRateLimit] = (
+            await self.internal_usage_cache.async_get_cache(
+                key=request_count_api_key,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+            )
+        )
+
+        key_remaining_rpm_limit: Optional[int] = None
+        key_rpm_limit: Optional[int] = None
+        key_remaining_tpm_limit: Optional[int] = None
+        key_tpm_limit: Optional[int] = None
+        if current is not None:
+            if user_api_key_dict.rpm_limit is not None:
+                key_remaining_rpm_limit = (
+                    user_api_key_dict.rpm_limit - current["current_rpm"]
+                )
+                key_rpm_limit = user_api_key_dict.rpm_limit
+            if user_api_key_dict.tpm_limit is not None:
+                key_remaining_tpm_limit = (
+                    user_api_key_dict.tpm_limit - current["current_tpm"]
+                )
+                key_tpm_limit = user_api_key_dict.tpm_limit
+
+        if hasattr(response, "_hidden_params"):
+            _hidden_params = getattr(response, "_hidden_params")
+        else:
+            _hidden_params = None
+        if _hidden_params is not None and (
+            isinstance(_hidden_params, BaseModel) or isinstance(_hidden_params, dict)
+        ):
+            if isinstance(_hidden_params, BaseModel):
+                _hidden_params = _hidden_params.model_dump()
+
+            _additional_headers = _hidden_params.get("additional_headers", {}) or {}
+
+            if key_remaining_rpm_limit is not None:
+                _additional_headers["x-ratelimit-remaining-requests"] = (
+                    key_remaining_rpm_limit
+                )
+            if key_rpm_limit is not None:
+                _additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit
+            if key_remaining_tpm_limit is not None:
+                _additional_headers["x-ratelimit-remaining-tokens"] = (
+                    key_remaining_tpm_limit
+                )
+            if key_tpm_limit is not None:
+                _additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit
+
+            setattr(
+                response,
+                "_hidden_params",
+                {**_hidden_params, "additional_headers": _additional_headers},
+            )
+
+            return await super().async_post_call_success_hook(
+                data, user_api_key_dict, response
+            )
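+
+    # Illustrative headers attached by the hook above (assuming rpm_limit=10,
+    # tpm_limit=1000 and a cached counter of {"current_rpm": 3, "current_tpm": 250}):
+    #   x-ratelimit-remaining-requests: 7
+    #   x-ratelimit-limit-requests: 10
+    #   x-ratelimit-remaining-tokens: 750
+    #   x-ratelimit-limit-tokens: 1000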
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py
new file mode 100644
index 00000000..b1b2bbee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py
@@ -0,0 +1,280 @@
+# +------------------------------------+
+#
+#        Prompt Injection Detection
+#
+# +------------------------------------+
+#  Thank you users! We ❤️ you! - Krrish & Ishaan
+## Reject a call if it contains a prompt injection attack.
+
+
+from difflib import SequenceMatcher
+from typing import List, Literal, Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    prompt_injection_detection_default_pt,
+)
+from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth
+from litellm.router import Router
+from litellm.utils import get_formatted_prompt
+
+
+class _OPTIONAL_PromptInjectionDetection(CustomLogger):
+    # Class variables or attributes
+    def __init__(
+        self,
+        prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
+    ):
+        self.prompt_injection_params = prompt_injection_params
+        self.llm_router: Optional[Router] = None
+
+        self.verbs = [
+            "Ignore",
+            "Disregard",
+            "Skip",
+            "Forget",
+            "Neglect",
+            "Overlook",
+            "Omit",
+            "Bypass",
+            "Pay no attention to",
+            "Do not follow",
+            "Do not obey",
+        ]
+        self.adjectives = [
+            "",
+            "prior",
+            "previous",
+            "preceding",
+            "above",
+            "foregoing",
+            "earlier",
+            "initial",
+        ]
+        self.prepositions = [
+            "",
+            "and start over",
+            "and start anew",
+            "and begin afresh",
+            "and start from scratch",
+        ]
+
+    def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
+        if level == "INFO":
+            verbose_proxy_logger.info(print_statement)
+        elif level == "DEBUG":
+            verbose_proxy_logger.debug(print_statement)
+
+        if litellm.set_verbose is True:
+            print(print_statement)  # noqa
+
+    def update_environment(self, router: Optional[Router] = None):
+        self.llm_router = router
+
+        if (
+            self.prompt_injection_params is not None
+            and self.prompt_injection_params.llm_api_check is True
+        ):
+            if self.llm_router is None:
+                raise Exception(
+                    "PromptInjectionDetection: Model List not set. Required for Prompt Injection detection."
+                )
+
+            self.print_verbose(
+                f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}"
+            )
+            if (
+                self.prompt_injection_params.llm_api_name is None
+                or self.prompt_injection_params.llm_api_name
+                not in self.llm_router.model_names
+            ):
+                raise Exception(
+                    "PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'."
+                )
+
+    def generate_injection_keywords(self) -> List[str]:
+        combinations = []
+        for verb in self.verbs:
+            for adj in self.adjectives:
+                for prep in self.prepositions:
+                    phrase = " ".join(filter(None, [verb, adj, prep])).strip()
+                    if (
+                        len(phrase.split()) > 2
+                    ):  # additional check to ensure more than 2 words
+                        combinations.append(phrase.lower())
+        return combinations
+
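+    # Illustrative output of generate_injection_keywords above: phrases such as
+    # "ignore previous and start over" or "disregard prior and begin afresh" are kept,
+    # while two-word combinations like "ignore prior" are dropped by the word-count check.
+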
+    def check_user_input_similarity(
+        self, user_input: str, similarity_threshold: float = 0.7
+    ) -> bool:
+        user_input_lower = user_input.lower()
+        keywords = self.generate_injection_keywords()
+
+        for keyword in keywords:
+            # Calculate the length of the keyword to extract substrings of the same length from user input
+            keyword_length = len(keyword)
+
+            for i in range(len(user_input_lower) - keyword_length + 1):
+                # Extract a substring of the same length as the keyword
+                substring = user_input_lower[i : i + keyword_length]
+
+                # Calculate similarity
+                match_ratio = SequenceMatcher(None, substring, keyword).ratio()
+                if match_ratio > similarity_threshold:
+                    self.print_verbose(
+                        print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}",
+                        level="INFO",
+                    )
+                    return True  # Found a highly similar substring
+        return False  # No substring crossed the threshold
+
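+    # Hedged sketch of the sliding-window check above (standard library only; the
+    # inputs are examples, not from the source):
+    #   >>> from difflib import SequenceMatcher
+    #   >>> SequenceMatcher(None, "ignore previous instructions", "ignore previous and start over").ratio()
+    # The ratio falls in [0, 1]; any window scoring above `similarity_threshold`
+    # (default 0.7) causes the request to be rejected.
+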
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
+    ):
+        try:
+            """
+            - check that the call type is supported
+            - run the heuristic similarity check on the formatted prompt
+            - reject the request if a prompt injection attack is detected
+            """
+            self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook")
+            try:
+                assert call_type in [
+                    "completion",
+                    "text_completion",
+                    "embeddings",
+                    "image_generation",
+                    "moderation",
+                    "audio_transcription",
+                ]
+            except Exception:
+                self.print_verbose(
+                    f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
+                )
+                return data
+            formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore
+
+            is_prompt_attack = False
+
+            if self.prompt_injection_params is not None:
+                # 1. check if heuristics check turned on
+                if self.prompt_injection_params.heuristics_check is True:
+                    is_prompt_attack = self.check_user_input_similarity(
+                        user_input=formatted_prompt
+                    )
+                    if is_prompt_attack is True:
+                        raise HTTPException(
+                            status_code=400,
+                            detail={
+                                "error": "Rejected message. This is a prompt injection attack."
+                            },
+                        )
+                # 2. check if vector db similarity check turned on [TODO] Not Implemented yet
+                if self.prompt_injection_params.vector_db_check is True:
+                    pass
+            else:
+                is_prompt_attack = self.check_user_input_similarity(
+                    user_input=formatted_prompt
+                )
+
+            if is_prompt_attack is True:
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": "Rejected message. This is a prompt injection attack."
+                    },
+                )
+
+            return data
+
+        except HTTPException as e:
+
+            if (
+                e.status_code == 400
+                and isinstance(e.detail, dict)
+                and "error" in e.detail  # type: ignore
+                and self.prompt_injection_params is not None
+                and self.prompt_injection_params.reject_as_response
+            ):
+                return e.detail.get("error")
+            raise e
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
+                    str(e)
+                )
+            )
+
+    async def async_moderation_hook(  # type: ignore
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal[
+            "completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ) -> Optional[bool]:
+        self.print_verbose(
+            f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}"
+        )
+
+        if self.prompt_injection_params is None:
+            return None
+
+        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore
+        is_prompt_attack = False
+
+        prompt_injection_system_prompt = getattr(
+            self.prompt_injection_params,
+            "llm_api_system_prompt",
+            prompt_injection_detection_default_pt(),
+        )
+
+        # 3. check if llm api check turned on
+        if (
+            self.prompt_injection_params.llm_api_check is True
+            and self.prompt_injection_params.llm_api_name is not None
+            and self.llm_router is not None
+        ):
+            # make a call to the llm api
+            response = await self.llm_router.acompletion(
+                model=self.prompt_injection_params.llm_api_name,
+                messages=[
+                    {
+                        "role": "system",
+                        "content": prompt_injection_system_prompt,
+                    },
+                    {"role": "user", "content": formatted_prompt},
+                ],
+            )
+
+            self.print_verbose(f"Received LLM Moderation response: {response}")
+            self.print_verbose(
+                f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}"
+            )
+            if isinstance(response, litellm.ModelResponse) and isinstance(
+                response.choices[0], litellm.Choices
+            ):
+                if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content:  # type: ignore
+                    is_prompt_attack = True
+
+        if is_prompt_attack is True:
+            raise HTTPException(
+                status_code=400,
+                detail={
+                    "error": "Rejected message. This is a prompt injection attack."
+                },
+            )
+
+        return is_prompt_attack
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py
new file mode 100644
index 00000000..e8a94732
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py
@@ -0,0 +1,246 @@
+import asyncio
+import traceback
+from datetime import datetime
+from typing import Any, Optional, Union, cast
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.core_helpers import (
+    _get_parent_otel_span_from_kwargs,
+    get_litellm_metadata_from_kwargs,
+)
+from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.auth.auth_checks import log_db_metrics
+from litellm.proxy.utils import ProxyUpdateSpend
+from litellm.types.utils import (
+    StandardLoggingPayload,
+    StandardLoggingUserAPIKeyMetadata,
+)
+from litellm.utils import get_end_user_id_for_cost_tracking
+
+
+class _ProxyDBLogger(CustomLogger):
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        await self._PROXY_track_cost_callback(
+            kwargs, response_obj, start_time, end_time
+        )
+
+    async def async_post_call_failure_hook(
+        self,
+        request_data: dict,
+        original_exception: Exception,
+        user_api_key_dict: UserAPIKeyAuth,
+    ):
+        from litellm.proxy.proxy_server import update_database
+
+        if _ProxyDBLogger._should_track_errors_in_db() is False:
+            return
+
+        _metadata = dict(
+            StandardLoggingUserAPIKeyMetadata(
+                user_api_key_hash=user_api_key_dict.api_key,
+                user_api_key_alias=user_api_key_dict.key_alias,
+                user_api_key_user_email=user_api_key_dict.user_email,
+                user_api_key_user_id=user_api_key_dict.user_id,
+                user_api_key_team_id=user_api_key_dict.team_id,
+                user_api_key_org_id=user_api_key_dict.org_id,
+                user_api_key_team_alias=user_api_key_dict.team_alias,
+                user_api_key_end_user_id=user_api_key_dict.end_user_id,
+            )
+        )
+        _metadata["user_api_key"] = user_api_key_dict.api_key
+        _metadata["status"] = "failure"
+        _metadata["error_information"] = (
+            StandardLoggingPayloadSetup.get_error_information(
+                original_exception=original_exception,
+            )
+        )
+
+        existing_metadata: dict = request_data.get("metadata", None) or {}
+        existing_metadata.update(_metadata)
+
+        if "litellm_params" not in request_data:
+            request_data["litellm_params"] = {}
+        request_data["litellm_params"]["proxy_server_request"] = (
+            request_data.get("proxy_server_request") or {}
+        )
+        request_data["litellm_params"]["metadata"] = existing_metadata
+        await update_database(
+            token=user_api_key_dict.api_key,
+            response_cost=0.0,
+            user_id=user_api_key_dict.user_id,
+            end_user_id=user_api_key_dict.end_user_id,
+            team_id=user_api_key_dict.team_id,
+            kwargs=request_data,
+            completion_response=original_exception,
+            start_time=datetime.now(),
+            end_time=datetime.now(),
+            org_id=user_api_key_dict.org_id,
+        )
+
+    @log_db_metrics
+    async def _PROXY_track_cost_callback(
+        self,
+        kwargs,  # kwargs to completion
+        completion_response: Optional[
+            Union[litellm.ModelResponse, Any]
+        ],  # response from completion
+        start_time=None,
+        end_time=None,  # start/end time for completion
+    ):
+        from litellm.proxy.proxy_server import (
+            prisma_client,
+            proxy_logging_obj,
+            update_cache,
+            update_database,
+        )
+
+        verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
+        try:
+            verbose_proxy_logger.debug(
+                f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
+            )
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
+            litellm_params = kwargs.get("litellm_params", {}) or {}
+            end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
+            metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
+            user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
+            team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
+            org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
+            key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
+            end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
+            sl_object: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object", None
+            )
+            response_cost = (
+                sl_object.get("response_cost", None)
+                if sl_object is not None
+                else kwargs.get("response_cost", None)
+            )
+
+            if response_cost is not None:
+                user_api_key = metadata.get("user_api_key", None)
+                if kwargs.get("cache_hit", False) is True:
+                    response_cost = 0.0
+                    verbose_proxy_logger.info(
+                        f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
+                    )
+
+                verbose_proxy_logger.debug(
+                    f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
+                )
+                if _should_track_cost_callback(
+                    user_api_key=user_api_key,
+                    user_id=user_id,
+                    team_id=team_id,
+                    end_user_id=end_user_id,
+                ):
+                    ## UPDATE DATABASE
+                    await update_database(
+                        token=user_api_key,
+                        response_cost=response_cost,
+                        user_id=user_id,
+                        end_user_id=end_user_id,
+                        team_id=team_id,
+                        kwargs=kwargs,
+                        completion_response=completion_response,
+                        start_time=start_time,
+                        end_time=end_time,
+                        org_id=org_id,
+                    )
+
+                    # update cache
+                    asyncio.create_task(
+                        update_cache(
+                            token=user_api_key,
+                            user_id=user_id,
+                            end_user_id=end_user_id,
+                            response_cost=response_cost,
+                            team_id=team_id,
+                            parent_otel_span=parent_otel_span,
+                        )
+                    )
+
+                    await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
+                        token=user_api_key,
+                        key_alias=key_alias,
+                        end_user_id=end_user_id,
+                        response_cost=response_cost,
+                        max_budget=end_user_max_budget,
+                    )
+                else:
+                    raise Exception(
+                        "User API key and team id and user id missing from custom callback."
+                    )
+            else:
+                if kwargs["stream"] is not True or (
+                    kwargs["stream"] is True and "complete_streaming_response" in kwargs
+                ):
+                    if sl_object is not None:
+                        cost_tracking_failure_debug_info: Union[dict, str] = (
+                            sl_object["response_cost_failure_debug_info"]  # type: ignore
+                            or "response_cost_failure_debug_info is None in standard_logging_object"
+                        )
+                    else:
+                        cost_tracking_failure_debug_info = (
+                            "standard_logging_object not found"
+                        )
+                    model = kwargs.get("model")
+                    raise Exception(
+                        f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+                    )
+        except Exception as e:
+            error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
+            model = kwargs.get("model", "")
+            metadata = kwargs.get("litellm_params", {}).get("metadata", {})
+            error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
+            asyncio.create_task(
+                proxy_logging_obj.failed_tracking_alert(
+                    error_message=error_msg,
+                    failing_model=model,
+                )
+            )
+            verbose_proxy_logger.exception(
+                "Error in tracking cost callback - %s", str(e)
+            )
+
+    @staticmethod
+    def _should_track_errors_in_db():
+        """
+        Returns True if errors should be tracked in the database
+
+        By default, errors are tracked in the database
+
+        If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
+        """
+        from litellm.proxy.proxy_server import general_settings
+
+        if general_settings.get("disable_error_logs") is True:
+            return False
+        return True
+
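+# Hedged example: error logging to the DB can be turned off in the proxy config's
+# general_settings block (the key name is taken from the check above):
+#
+#   general_settings:
+#     disable_error_logs: true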
+
+def _should_track_cost_callback(
+    user_api_key: Optional[str],
+    user_id: Optional[str],
+    team_id: Optional[str],
+    end_user_id: Optional[str],
+) -> bool:
+    """
+    Determine if the cost callback should be tracked based on the kwargs
+    """
+
+    # don't run track cost callback if user opted into disabling spend
+    if ProxyUpdateSpend.disable_spend_updates() is True:
+        return False
+
+    if (
+        user_api_key is not None
+        or user_id is not None
+        or team_id is not None
+        or end_user_id is not None
+    ):
+        return True
+    return False
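+
+
+# Illustrative behaviour of _should_track_cost_callback (assuming spend updates are
+# not globally disabled): it returns True as soon as any one identifier is present,
+# e.g. _should_track_cost_callback(user_api_key=None, user_id=None, team_id="team-1",
+# end_user_id=None) -> True, and it returns False only when all four are None.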