about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_cache.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_cache.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_cache.py170
1 files changed, 170 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_cache.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_cache.py
new file mode 100644
index 00000000..f096b026
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_cache.py
@@ -0,0 +1,170 @@
+"""
+Wrapper around router cache. Meant to handle model cooldown logic
+"""
+
+import time
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict
+
+from litellm import verbose_logger
+from litellm.caching.caching import DualCache
+from litellm.caching.in_memory_cache import InMemoryCache
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
+
+
+class CooldownCacheValue(TypedDict):
+    exception_received: str
+    status_code: str
+    timestamp: float
+    cooldown_time: float
+
+
+class CooldownCache:
+    def __init__(self, cache: DualCache, default_cooldown_time: float):
+        self.cache = cache
+        self.default_cooldown_time = default_cooldown_time
+        self.in_memory_cache = InMemoryCache()
+
+    def _common_add_cooldown_logic(
+        self, model_id: str, original_exception, exception_status, cooldown_time: float
+    ) -> Tuple[str, CooldownCacheValue]:
+        try:
+            current_time = time.time()
+            cooldown_key = f"deployment:{model_id}:cooldown"
+
+            # Store the cooldown information for the deployment separately
+            cooldown_data = CooldownCacheValue(
+                exception_received=str(original_exception),
+                status_code=str(exception_status),
+                timestamp=current_time,
+                cooldown_time=cooldown_time,
+            )
+
+            return cooldown_key, cooldown_data
+        except Exception as e:
+            verbose_logger.error(
+                "CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
+                    str(e)
+                )
+            )
+            raise e
+
+    def add_deployment_to_cooldown(
+        self,
+        model_id: str,
+        original_exception: Exception,
+        exception_status: int,
+        cooldown_time: Optional[float],
+    ):
+        try:
+            _cooldown_time = cooldown_time or self.default_cooldown_time
+            cooldown_key, cooldown_data = self._common_add_cooldown_logic(
+                model_id=model_id,
+                original_exception=original_exception,
+                exception_status=exception_status,
+                cooldown_time=_cooldown_time,
+            )
+
+            # Set the cache with a TTL equal to the cooldown time
+            self.cache.set_cache(
+                value=cooldown_data,
+                key=cooldown_key,
+                ttl=_cooldown_time,
+            )
+        except Exception as e:
+            verbose_logger.error(
+                "CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
+                    str(e)
+                )
+            )
+            raise e
+
+    @staticmethod
+    def get_cooldown_cache_key(model_id: str) -> str:
+        return f"deployment:{model_id}:cooldown"
+
+    async def async_get_active_cooldowns(
+        self, model_ids: List[str], parent_otel_span: Optional[Span]
+    ) -> List[Tuple[str, CooldownCacheValue]]:
+        # Generate the keys for the deployments
+        keys = [
+            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
+        ]
+
+        # Retrieve the values for the keys using mget
+        ## more likely to be none if no models ratelimited. So just check redis every 1s
+        ## each redis call adds ~100ms latency.
+
+        ## check in memory cache first
+        results = await self.cache.async_batch_get_cache(
+            keys=keys, parent_otel_span=parent_otel_span
+        )
+        active_cooldowns: List[Tuple[str, CooldownCacheValue]] = []
+
+        if results is None:
+            return active_cooldowns
+
+        # Process the results
+        for model_id, result in zip(model_ids, results):
+            if result and isinstance(result, dict):
+                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
+                active_cooldowns.append((model_id, cooldown_cache_value))
+
+        return active_cooldowns
+
+    def get_active_cooldowns(
+        self, model_ids: List[str], parent_otel_span: Optional[Span]
+    ) -> List[Tuple[str, CooldownCacheValue]]:
+        # Generate the keys for the deployments
+        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
+        # Retrieve the values for the keys using mget
+        results = (
+            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
+            or []
+        )
+
+        active_cooldowns = []
+        # Process the results
+        for model_id, result in zip(model_ids, results):
+            if result and isinstance(result, dict):
+                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
+                active_cooldowns.append((model_id, cooldown_cache_value))
+
+        return active_cooldowns
+
+    def get_min_cooldown(
+        self, model_ids: List[str], parent_otel_span: Optional[Span]
+    ) -> float:
+        """Return min cooldown time required for a group of model id's."""
+
+        # Generate the keys for the deployments
+        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
+
+        # Retrieve the values for the keys using mget
+        results = (
+            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
+            or []
+        )
+
+        min_cooldown_time: Optional[float] = None
+        # Process the results
+        for model_id, result in zip(model_ids, results):
+            if result and isinstance(result, dict):
+                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
+                if min_cooldown_time is None:
+                    min_cooldown_time = cooldown_cache_value["cooldown_time"]
+                elif cooldown_cache_value["cooldown_time"] < min_cooldown_time:
+                    min_cooldown_time = cooldown_cache_value["cooldown_time"]
+
+        return min_cooldown_time or self.default_cooldown_time
+
+
+# Usage example:
+# cooldown_cache = CooldownCache(cache=your_cache_instance, cooldown_time=your_cooldown_time)
+# cooldown_cache.add_deployment_to_cooldown(deployment, original_exception, exception_status)
+# active_cooldowns = cooldown_cache.get_active_cooldowns()