aboutsummaryrefslogtreecommitdiff
"""
Wrapper around the router cache. Stores the model id that served a prompt which supports prompt caching, so the same prompt (or a prefix of it) can later be looked up to find that model id again.
"""

import hashlib
import json
from typing import TYPE_CHECKING, Any, List, Optional, TypedDict

from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.router import Router

    litellm_router = Router
    Span = _Span
else:
    Span = Any
    litellm_router = Any


# Value stored in the prompt-caching cache for a given prompt key.
# Functional TypedDict form; equivalent to a class with a single required
# ``model_id: str`` field (the id of the model that served the prompt).
PromptCachingCacheValue = TypedDict("PromptCachingCacheValue", {"model_id": str})

class PromptCachingCache:
    """
    Cache mapping a prompt (messages and/or tools) to the model id that served it.

    Keys are ``"deployment:<sha256-hex>:prompt_caching"`` where the hash covers a
    canonical JSON encoding of the prompt. Entries are written to the wrapped
    ``DualCache`` with a TTL of ``CACHE_TTL_SECONDS``.
    """

    # How long (in seconds) a prompt -> model id mapping is kept in the cache.
    CACHE_TTL_SECONDS: int = 300

    # async_get_model_id also tries the prompt with up to this many trailing
    # messages removed, so a follow-up request can match an earlier prefix.
    MAX_TRAILING_MESSAGES_TO_DROP: int = 3

    def __init__(self, cache: DualCache):
        """
        Args:
            cache: backing cache that stores the prompt -> model id entries.
        """
        self.cache = cache
        # Not used by the methods of this class.
        self.in_memory_cache = InMemoryCache()

    @staticmethod
    def serialize_object(obj: Any) -> Any:
        """
        Convert ``obj`` into a stable, JSON-friendly value for building cache keys.

        - objects exposing ``dict`` (e.g. Pydantic models) -> ``obj.dict()``
        - dicts -> compact JSON string with sorted keys
        - lists -> list with every item serialized recursively
        - int / float / bool -> returned unchanged
        - anything else -> ``str(obj)``
        """
        if hasattr(obj, "dict"):
            # Pydantic-style model: use its own dict() conversion.
            return obj.dict()
        if isinstance(obj, dict):
            # sort_keys makes the output independent of dict insertion order.
            # ``default`` handles values json cannot encode natively (e.g. a
            # Pydantic object nested inside a dict message) instead of raising
            # TypeError; it is never invoked for natively encodable values, so
            # keys that could be computed before are unchanged.
            return json.dumps(
                obj,
                sort_keys=True,
                separators=(",", ":"),
                default=PromptCachingCache.serialize_object,
            )
        if isinstance(obj, list):
            return [PromptCachingCache.serialize_object(item) for item in obj]
        if isinstance(obj, (int, float, bool)):
            return obj  # Keep primitive types as-is
        return str(obj)

    @staticmethod
    def get_prompt_caching_cache_key(
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[str]:
        """
        Build the cache key for a prompt.

        Returns:
            None when both ``messages`` and ``tools`` are None; otherwise
            ``"deployment:<sha256-hex>:prompt_caching"``, where the hash covers
            whichever of ``messages`` / ``tools`` is not None.
        """
        if messages is None and tools is None:
            return None

        # Only include the parts that are present, so e.g. a tools-only prompt
        # hashes differently from one with an empty message list.
        data_to_hash = {}
        if messages is not None:
            data_to_hash["messages"] = PromptCachingCache.serialize_object(messages)
        if tools is not None:
            data_to_hash["tools"] = PromptCachingCache.serialize_object(tools)

        # Canonical encoding (sorted keys, no whitespace) -> stable hash.
        data_to_hash_str = json.dumps(
            data_to_hash,
            sort_keys=True,
            separators=(",", ":"),
            default=PromptCachingCache.serialize_object,
        )

        hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
        return f"deployment:{hashed_data}:prompt_caching"

    def add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        """
        Record that ``model_id`` served the given prompt (synchronous).

        No-op when both ``messages`` and ``tools`` are None.
        """
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        self.cache.set_cache(
            cache_key,
            PromptCachingCacheValue(model_id=model_id),
            ttl=self.CACHE_TTL_SECONDS,
        )
        return None

    async def async_add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        """
        Record that ``model_id`` served the given prompt (async).

        No-op when both ``messages`` and ``tools`` are None.
        """
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        await self.cache.async_set_cache(
            cache_key,
            PromptCachingCacheValue(model_id=model_id),
            ttl=self.CACHE_TTL_SECONDS,
        )
        return None

    async def async_get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """
        Look up the model id cached for this prompt or a recent prefix of it.

        Candidate keys, in order (most specific first):
        - the full prompt (messages + tools), also when ``messages`` is None
        - if ``messages`` is not None: ``messages[:-1]``, ``messages[:-2]``, ...
          up to ``MAX_TRAILING_MESSAGES_TO_DROP`` messages removed, each with
          the same ``tools``; slices that would be empty are skipped

        All candidates are fetched with a single
        ``self.cache.async_batch_get_cache`` call.

        Returns:
            The first non-None cached value, or None if nothing matched or if
            both ``messages`` and ``tools`` are None.
        """
        if messages is None and tools is None:
            return None

        # The full-prompt key is always a candidate. This includes a tools-only
        # prompt (messages is None), which add_model_id / async_add_model_id
        # store under exactly this key.
        potential_cache_keys = [
            PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        ]

        if messages is not None:
            # Upper bound len(messages) keeps every slice non-empty.
            stop = min(self.MAX_TRAILING_MESSAGES_TO_DROP + 1, len(messages))
            for i in range(1, stop):
                partial_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
                    messages[:-i], tools
                )
                potential_cache_keys.append(partial_cache_key)

        cache_results = await self.cache.async_batch_get_cache(
            keys=potential_cache_keys
        )

        if cache_results is None:
            return None

        # Keys were appended longest-prompt-first, so the first hit is the
        # most specific match.
        for result in cache_results:
            if result is not None:
                return result

        return None

    def get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """
        Look up the model id cached for exactly this prompt (synchronous).

        Unlike ``async_get_model_id``, no shorter message prefixes are tried.
        Returns None when both ``messages`` and ``tools`` are None, otherwise
        whatever ``self.cache.get_cache`` returns for the prompt's key.
        """
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        return self.cache.get_cache(cache_key)