"""
Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously.

Has 4 primary methods:
    - set_cache
    - get_cache
    - async_set_cache
    - async_get_cache
"""

import asyncio
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, List, Optional

import litellm
from litellm._logging import print_verbose, verbose_logger

from .base_cache import BaseCache
from .in_memory_cache import InMemoryCache
from .redis_cache import RedisCache

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = _Span
else:
    Span = Any

from collections import OrderedDict


class LimitedSizeOrderedDict(OrderedDict):
    """
    OrderedDict that holds a bounded number of keys.

    When a *new* key is inserted while the dict already holds `max_size`
    entries, the oldest inserted entry is evicted first. Overwriting an
    existing key never evicts anything and keeps the key's original position.
    A `max_size` below 1 behaves like 1 (the dict keeps only the newest key).
    """

    def __init__(self, *args, max_size=100, **kwargs):
        # max_size must exist before OrderedDict.__init__ runs: seeding the
        # dict from *args / **kwargs goes through __setitem__, which reads it.
        self.max_size = max_size
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        # Only a brand-new key can grow the dict, so only then is eviction
        # needed. The len(self) > 0 guard avoids popitem() on an empty dict
        # when max_size is 0 or negative.
        if key not in self and len(self) >= self.max_size and len(self) > 0:
            self.popitem(last=False)
        super().__setitem__(key, value)


class DualCache(BaseCache):
    """
    DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
    When data is updated or inserted, it is written to both the in-memory cache + Redis.
    This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
    """

    def __init__(
        self,
        in_memory_cache: Optional[InMemoryCache] = None,
        redis_cache: Optional[RedisCache] = None,
        default_in_memory_ttl: Optional[float] = None,
        default_redis_ttl: Optional[float] = None,
        default_redis_batch_cache_expiry: Optional[float] = None,
        default_max_redis_batch_cache_size: int = 100,
    ) -> None:
        """
        Wire up the two cache layers and resolve the default TTL settings.

        Args:
            in_memory_cache: local cache layer; a new InMemoryCache() is
                created when this is falsy.
            redis_cache: optional Redis layer. Stored exactly as given; when
                None the instance works with the in-memory layer only.
            default_in_memory_ttl: falls back to litellm.default_in_memory_ttl
                when falsy.
            default_redis_ttl: falls back to litellm.default_redis_ttl when
                falsy.
            default_redis_batch_cache_expiry: value compared against the time
                since a key's last Redis batch lookup (see
                get_redis_batch_keys). Falls back to
                litellm.default_redis_batch_cache_expiry, then to 10.
            default_max_redis_batch_cache_size: max number of keys tracked in
                last_redis_batch_access_time.
        """
        super().__init__()

        # Redis is optional: no default client is created here.
        self.redis_cache = redis_cache
        self.in_memory_cache = (
            in_memory_cache if in_memory_cache else InMemoryCache()
        )

        # Per-key timestamp of the last Redis batch lookup (bounded in size).
        self.last_redis_batch_access_time = LimitedSizeOrderedDict(
            max_size=default_max_redis_batch_cache_size
        )

        batch_expiry = default_redis_batch_cache_expiry
        if not batch_expiry:
            batch_expiry = litellm.default_redis_batch_cache_expiry or 10
        self.redis_batch_cache_expiry = batch_expiry

        self.default_redis_ttl = (
            default_redis_ttl if default_redis_ttl else litellm.default_redis_ttl
        )
        self.default_in_memory_ttl = (
            default_in_memory_ttl
            if default_in_memory_ttl
            else litellm.default_in_memory_ttl
        )
    def update_cache_ttl(
        self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float]
    ):
        if default_in_memory_ttl is not None:
            self.default_in_memory_ttl = default_in_memory_ttl

        if default_redis_ttl is not None:
            self.default_redis_ttl = default_redis_ttl

    def set_cache(self, key, value, local_only: bool = False, **kwargs):
        # Update both Redis and in-memory cache
        try:
            if self.in_memory_cache is not None:
                if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
                    kwargs["ttl"] = self.default_in_memory_ttl

                self.in_memory_cache.set_cache(key, value, **kwargs)

            if self.redis_cache is not None and local_only is False:
                self.redis_cache.set_cache(key, value, **kwargs)
        except Exception as e:
            print_verbose(e)

    def increment_cache(
        self, key, value: int, local_only: bool = False, **kwargs
    ) -> int:
        """
        Key - the key in cache

        Value - int - the value you want to increment by

        Returns - int - the incremented value
        """
        try:
            result: int = value
            if self.in_memory_cache is not None:
                result = self.in_memory_cache.increment_cache(key, value, **kwargs)

            if self.redis_cache is not None and local_only is False:
                result = self.redis_cache.increment_cache(key, value, **kwargs)

            return result
        except Exception as e:
            verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
            raise e

    def get_cache(
        self,
        key,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        """
        Read `key`, preferring the in-memory cache over Redis.

        Redis is consulted only when the in-memory lookup yields None, a
        Redis cache is configured and `local_only` is False. A non-None Redis
        value is written back into the in-memory cache before being returned.

        Returns the cached value or None. If any step raises, the traceback
        is logged and None is returned.
        """
        try:
            cached_value = None
            if self.in_memory_cache is not None:
                cached_value = self.in_memory_cache.get_cache(key, **kwargs)

            should_query_redis = (
                cached_value is None
                and self.redis_cache is not None
                and local_only is False
            )
            if should_query_redis:
                cached_value = self.redis_cache.get_cache(
                    key, parent_otel_span=parent_otel_span
                )
                if cached_value is not None:
                    # Backfill the local layer with what Redis returned.
                    self.in_memory_cache.set_cache(key, cached_value, **kwargs)

            print_verbose(f"get cache: cache result: {cached_value}")
            return cached_value
        except Exception:
            verbose_logger.error(traceback.format_exc())
            return None

    def batch_get_cache(
        self,
        keys: list,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        received_args = locals()
        received_args.pop("self")

        def run_in_new_loop():
            """Run the coroutine in a new event loop within this thread."""
            new_loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(new_loop)
                return new_loop.run_until_complete(
                    self.async_batch_get_cache(**received_args)
                )
            finally:
                new_loop.close()
                asyncio.set_event_loop(None)

        try:
            # First, try to get the current event loop
            _ = asyncio.get_running_loop()
            # If we're already in an event loop, run in a separate thread
            # to avoid nested event loop issues
            with ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(run_in_new_loop)
                return future.result()

        except RuntimeError:
            # No running event loop, we can safely run in this thread
            return run_in_new_loop()

    async def async_get_cache(
        self,
        key,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        """
        Async read of `key`, preferring the in-memory cache over Redis.

        Redis is consulted only when the in-memory lookup yields None, a
        Redis cache is configured and `local_only` is False. A non-None Redis
        value is written back into the in-memory cache before being returned.

        Returns the cached value or None. If any step raises, the traceback
        is logged and None is returned.
        """
        try:
            print_verbose(
                f"async get cache: cache key: {key}; local_only: {local_only}"
            )

            cached_value = None
            if self.in_memory_cache is not None:
                cached_value = await self.in_memory_cache.async_get_cache(
                    key, **kwargs
                )
                print_verbose(f"in_memory_result: {cached_value}")

            should_query_redis = (
                cached_value is None
                and self.redis_cache is not None
                and local_only is False
            )
            if should_query_redis:
                cached_value = await self.redis_cache.async_get_cache(
                    key, parent_otel_span=parent_otel_span
                )
                if cached_value is not None:
                    # Backfill the local layer with what Redis returned.
                    await self.in_memory_cache.async_set_cache(
                        key, cached_value, **kwargs
                    )

            print_verbose(f"get cache: cache result: {cached_value}")
            return cached_value
        except Exception:
            verbose_logger.error(traceback.format_exc())
            return None

    def get_redis_batch_keys(
        self,
        current_time: float,
        keys: List[str],
        result: List[Any],
    ) -> List[str]:
        sublist_keys = []
        for key, value in zip(keys, result):
            if value is None:
                if (
                    key not in self.last_redis_batch_access_time
                    or current_time - self.last_redis_batch_access_time[key]
                    >= self.redis_batch_cache_expiry
                ):
                    sublist_keys.append(key)
        return sublist_keys

    async def async_batch_get_cache(
        self,
        keys: list,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        try:
            result = [None for _ in range(len(keys))]
            if self.in_memory_cache is not None:
                in_memory_result = await self.in_memory_cache.async_batch_get_cache(
                    keys, **kwargs
                )

                if in_memory_result is not None:
                    result = in_memory_result

            if None in result and self.redis_cache is not None and local_only is False:
                """
                - for the none values in the result
                - check the redis cache
                """
                current_time = time.time()
                sublist_keys = self.get_redis_batch_keys(current_time, keys, result)

                # Only hit Redis if the last access time was more than 5 seconds ago
                if len(sublist_keys) > 0:
                    # If not found in in-memory cache, try fetching from Redis
                    redis_result = await self.redis_cache.async_batch_get_cache(
                        sublist_keys, parent_otel_span=parent_otel_span
                    )

                    if redis_result is not None:
                        # Update in-memory cache with the value from Redis
                        for key, value in redis_result.items():
                            if value is not None:
                                await self.in_memory_cache.async_set_cache(
                                    key, redis_result[key], **kwargs
                                )
                            # Update the last access time for each key fetched from Redis
                            self.last_redis_batch_access_time[key] = current_time

                    for key, value in redis_result.items():
                        index = keys.index(key)
                        result[index] = value

            return result
        except Exception:
            verbose_logger.error(traceback.format_exc())

    async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
        """
        Async write of `value` under `key` to the in-memory cache and, unless
        `local_only` is True, to Redis.

        `kwargs` are forwarded unchanged to both layers. Best effort: any
        exception is logged and swallowed.
        """
        print_verbose(
            f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
        )
        try:
            # In-memory first, then Redis (skipped for local-only writes).
            write_targets = [self.in_memory_cache]
            if local_only is False:
                write_targets.append(self.redis_cache)

            for layer in write_targets:
                if layer is not None:
                    await layer.async_set_cache(key, value, **kwargs)
        except Exception as e:
            verbose_logger.exception(
                f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
            )

    # async_batch_set_cache
    async def async_set_cache_pipeline(
        self, cache_list: list, local_only: bool = False, **kwargs
    ):
        """
        Batch write `cache_list` to the in-memory cache and, unless
        `local_only` is True, to Redis.

        The in-memory layer receives all of `kwargs`; for Redis, `ttl` is
        taken out of `kwargs` and passed explicitly (None when absent).
        Best effort: any exception is logged and swallowed.
        """
        print_verbose(
            f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
        )
        try:
            memory_layer = self.in_memory_cache
            if memory_layer is not None:
                await memory_layer.async_set_cache_pipeline(
                    cache_list=cache_list, **kwargs
                )

            if local_only is not False or self.redis_cache is None:
                return

            redis_ttl = kwargs.pop("ttl", None)
            await self.redis_cache.async_set_cache_pipeline(
                cache_list=cache_list, ttl=redis_ttl, **kwargs
            )
        except Exception as e:
            verbose_logger.exception(
                f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
            )

    async def async_increment_cache(
        self,
        key,
        value: float,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ) -> float:
        """
        Key - the key in cache

        Value - float - the value you want to increment by

        Returns - float - the incremented value
        """
        try:
            result: float = value
            if self.in_memory_cache is not None:
                result = await self.in_memory_cache.async_increment(
                    key, value, **kwargs
                )

            if self.redis_cache is not None and local_only is False:
                result = await self.redis_cache.async_increment(
                    key,
                    value,
                    parent_otel_span=parent_otel_span,
                    ttl=kwargs.get("ttl", None),
                )

            return result
        except Exception as e:
            raise e  # don't log if exception is raised

    async def async_set_cache_sadd(
        self, key, value: List, local_only: bool = False, **kwargs
    ) -> None:
        """
        Add value to a set

        Key - the key in cache

        Value - str - the value you want to add to the set

        Returns - None
        """
        try:
            if self.in_memory_cache is not None:
                _ = await self.in_memory_cache.async_set_cache_sadd(
                    key, value, ttl=kwargs.get("ttl", None)
                )

            if self.redis_cache is not None and local_only is False:
                _ = await self.redis_cache.async_set_cache_sadd(
                    key, value, ttl=kwargs.get("ttl", None)
                )

            return None
        except Exception as e:
            raise e  # don't log, if exception is raised

    def flush_cache(self):
        if self.in_memory_cache is not None:
            self.in_memory_cache.flush_cache()
        if self.redis_cache is not None:
            self.redis_cache.flush_cache()

    def delete_cache(self, key):
        """
        Delete a key from the cache
        """
        if self.in_memory_cache is not None:
            self.in_memory_cache.delete_cache(key)
        if self.redis_cache is not None:
            self.redis_cache.delete_cache(key)

    async def async_delete_cache(self, key: str):
        """
        Delete a key from the cache
        """
        if self.in_memory_cache is not None:
            self.in_memory_cache.delete_cache(key)
        if self.redis_cache is not None:
            await self.redis_cache.async_delete_cache(key)

    async def async_get_ttl(self, key: str) -> Optional[int]:
        """
        Get the remaining TTL of a key in in-memory cache or redis
        """
        ttl = await self.in_memory_cache.async_get_ttl(key)
        if ttl is None and self.redis_cache is not None:
            ttl = await self.redis_cache.async_get_ttl(key)
        return ttl