diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_cache.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_cache.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_cache.py | 170 |
1 file changed, 170 insertions, 0 deletions
"""
Wrapper around router cache. Meant to handle model cooldown logic
"""

import time
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict

from litellm import verbose_logger
from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = _Span
else:
    Span = Any


class CooldownCacheValue(TypedDict):
    """Payload stored in the cache for one cooled-down deployment."""

    exception_received: str  # str() of the exception that triggered the cooldown
    status_code: str  # exception status, stored as a string
    timestamp: float  # unix time (time.time()) when the cooldown started
    cooldown_time: float  # cooldown duration in seconds


class CooldownCache:
    """Tracks which deployments are in cooldown, backed by a ``DualCache``.

    Entries are written with a TTL equal to their cooldown duration, so
    they expire from the cache on their own.
    """

    def __init__(self, cache: DualCache, default_cooldown_time: float):
        # default_cooldown_time: seconds used when a caller passes no
        # (or a falsy) per-call cooldown duration.
        self.cache = cache
        self.default_cooldown_time = default_cooldown_time
        self.in_memory_cache = InMemoryCache()

    @staticmethod
    def get_cooldown_cache_key(model_id: str) -> str:
        """Return the cache key under which ``model_id``'s cooldown is stored.

        Single source of truth for the key format — writers and readers
        must agree on it.
        """
        return f"deployment:{model_id}:cooldown"

    def _common_add_cooldown_logic(
        self, model_id: str, original_exception, exception_status, cooldown_time: float
    ) -> Tuple[str, CooldownCacheValue]:
        """Build the ``(key, value)`` pair describing a new cooldown entry.

        Raises:
            Exception: re-raises anything that goes wrong while building
                the payload, after logging it.
        """
        try:
            current_time = time.time()
            # Use the shared key builder instead of re-hard-coding the format.
            cooldown_key = CooldownCache.get_cooldown_cache_key(model_id)

            # Store the cooldown information for the deployment separately
            cooldown_data = CooldownCacheValue(
                exception_received=str(original_exception),
                status_code=str(exception_status),
                timestamp=current_time,
                cooldown_time=cooldown_time,
            )

            return cooldown_key, cooldown_data
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    def add_deployment_to_cooldown(
        self,
        model_id: str,
        original_exception: Exception,
        exception_status: int,
        cooldown_time: Optional[float],
    ):
        """Put ``model_id`` into cooldown.

        A falsy ``cooldown_time`` (None or 0) falls back to
        ``self.default_cooldown_time``.

        Raises:
            Exception: re-raises any cache/serialization failure after
                logging it.
        """
        try:
            _cooldown_time = cooldown_time or self.default_cooldown_time

            cooldown_key, cooldown_data = self._common_add_cooldown_logic(
                model_id=model_id,
                original_exception=original_exception,
                exception_status=exception_status,
                cooldown_time=_cooldown_time,
            )

            # Set the cache with a TTL equal to the cooldown time, so the
            # entry expires on its own once the deployment is usable again.
            self.cache.set_cache(
                value=cooldown_data,
                key=cooldown_key,
                ttl=_cooldown_time,
            )
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    @staticmethod
    def _parse_cooldown_results(
        model_ids: List[str], results: Optional[List[Any]]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        """Zip model ids with raw cache results, keeping only dict hits.

        Shared by the sync and async lookup paths so the filtering logic
        exists in exactly one place.
        """
        active_cooldowns: List[Tuple[str, CooldownCacheValue]] = []
        if not results:
            return active_cooldowns
        for model_id, result in zip(model_ids, results):
            # Cache misses come back as None; only dict payloads are valid.
            if result and isinstance(result, dict):
                active_cooldowns.append(
                    (model_id, CooldownCacheValue(**result))  # type: ignore
                )
        return active_cooldowns

    async def async_get_active_cooldowns(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        """Async lookup of active cooldown entries for ``model_ids``."""
        # Generate the keys for the deployments
        keys = [
            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
        ]

        # Retrieve the values for the keys using mget
        ## more likely to be none if no models ratelimited. So just check redis every 1s
        ## each redis call adds ~100ms latency.

        ## check in memory cache first
        results = await self.cache.async_batch_get_cache(
            keys=keys, parent_otel_span=parent_otel_span
        )
        return self._parse_cooldown_results(model_ids, results)

    def get_active_cooldowns(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        """Sync lookup of active cooldown entries for ``model_ids``."""
        # Generate the keys for the deployments (shared key builder,
        # previously hard-coded here).
        keys = [
            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
        ]
        # Retrieve the values for the keys using mget
        results = (
            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
            or []
        )
        return self._parse_cooldown_results(model_ids, results)

    def get_min_cooldown(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> float:
        """Return min cooldown time required for a group of model id's."""
        # Delegate to get_active_cooldowns instead of duplicating the
        # whole fetch/parse path.
        active_cooldowns = self.get_active_cooldowns(
            model_ids=model_ids, parent_otel_span=parent_otel_span
        )
        cooldown_times = [value["cooldown_time"] for _, value in active_cooldowns]
        min_cooldown_time = min(cooldown_times, default=None)
        # Original semantics preserved: a falsy minimum (None or 0.0)
        # falls back to the configured default.
        return min_cooldown_time or self.default_cooldown_time


# Usage example:
# cooldown_cache = CooldownCache(cache=your_cache_instance, cooldown_time=your_cooldown_time)
# cooldown_cache.add_deployment_to_cooldown(deployment, original_exception, exception_status)
# active_cooldowns = cooldown_cache.get_active_cooldowns()