aboutsummaryrefslogtreecommitdiff
"""
Wrapper around router cache. Meant to handle model cooldown logic
"""

import time
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict

from litellm import verbose_logger
from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = _Span
else:
    Span = Any


class CooldownCacheValue(TypedDict):
    exception_received: str
    status_code: str
    timestamp: float
    cooldown_time: float


class CooldownCache:
    def __init__(self, cache: DualCache, default_cooldown_time: float):
        self.cache = cache
        self.default_cooldown_time = default_cooldown_time
        self.in_memory_cache = InMemoryCache()

    def _common_add_cooldown_logic(
        self, model_id: str, original_exception, exception_status, cooldown_time: float
    ) -> Tuple[str, CooldownCacheValue]:
        try:
            current_time = time.time()
            cooldown_key = f"deployment:{model_id}:cooldown"

            # Store the cooldown information for the deployment separately
            cooldown_data = CooldownCacheValue(
                exception_received=str(original_exception),
                status_code=str(exception_status),
                timestamp=current_time,
                cooldown_time=cooldown_time,
            )

            return cooldown_key, cooldown_data
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    def add_deployment_to_cooldown(
        self,
        model_id: str,
        original_exception: Exception,
        exception_status: int,
        cooldown_time: Optional[float],
    ):
        try:
            _cooldown_time = cooldown_time or self.default_cooldown_time
            cooldown_key, cooldown_data = self._common_add_cooldown_logic(
                model_id=model_id,
                original_exception=original_exception,
                exception_status=exception_status,
                cooldown_time=_cooldown_time,
            )

            # Set the cache with a TTL equal to the cooldown time
            self.cache.set_cache(
                value=cooldown_data,
                key=cooldown_key,
                ttl=_cooldown_time,
            )
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    @staticmethod
    def get_cooldown_cache_key(model_id: str) -> str:
        return f"deployment:{model_id}:cooldown"

    async def async_get_active_cooldowns(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        # Generate the keys for the deployments
        keys = [
            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
        ]

        # Retrieve the values for the keys using mget
        ## more likely to be none if no models ratelimited. So just check redis every 1s
        ## each redis call adds ~100ms latency.

        ## check in memory cache first
        results = await self.cache.async_batch_get_cache(
            keys=keys, parent_otel_span=parent_otel_span
        )
        active_cooldowns: List[Tuple[str, CooldownCacheValue]] = []

        if results is None:
            return active_cooldowns

        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                active_cooldowns.append((model_id, cooldown_cache_value))

        return active_cooldowns

    def get_active_cooldowns(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        # Generate the keys for the deployments
        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
        # Retrieve the values for the keys using mget
        results = (
            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
            or []
        )

        active_cooldowns = []
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                active_cooldowns.append((model_id, cooldown_cache_value))

        return active_cooldowns

    def get_min_cooldown(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> float:
        """Return min cooldown time required for a group of model id's."""

        # Generate the keys for the deployments
        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]

        # Retrieve the values for the keys using mget
        results = (
            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
            or []
        )

        min_cooldown_time: Optional[float] = None
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                if min_cooldown_time is None:
                    min_cooldown_time = cooldown_cache_value["cooldown_time"]
                elif cooldown_cache_value["cooldown_time"] < min_cooldown_time:
                    min_cooldown_time = cooldown_cache_value["cooldown_time"]

        return min_cooldown_time or self.default_cooldown_time


# Usage example:
# cooldown_cache = CooldownCache(cache=your_cache_instance, cooldown_time=your_cooldown_time)
# cooldown_cache.add_deployment_to_cooldown(deployment, original_exception, exception_status)
# active_cooldowns = cooldown_cache.get_active_cooldowns()