Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router_utils/prompt_caching_cache.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/router_utils/prompt_caching_cache.py  171
1 file changed, 171 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/prompt_caching_cache.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/prompt_caching_cache.py
new file mode 100644
index 00000000..1bf686d6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/prompt_caching_cache.py
@@ -0,0 +1,171 @@
+"""
+Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called.
+"""
+
+import hashlib
+import json
+from typing import TYPE_CHECKING, Any, List, Optional, TypedDict
+
+from litellm.caching.caching import DualCache
+from litellm.caching.in_memory_cache import InMemoryCache
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    from litellm.router import Router
+
+    litellm_router = Router
+    Span = _Span
+else:
+    Span = Any
+    litellm_router = Any
+
+
+class PromptCachingCacheValue(TypedDict):
+    model_id: str
+
+
+class PromptCachingCache:
+    def __init__(self, cache: DualCache):
+        self.cache = cache
+        self.in_memory_cache = InMemoryCache()
+
+    @staticmethod
+    def serialize_object(obj: Any) -> Any:
+        """Helper function to serialize Pydantic objects, dictionaries, or fallback to string."""
+        if hasattr(obj, "dict"):
+            # If the object is a Pydantic model, use its `dict()` method
+            return obj.dict()
+        elif isinstance(obj, dict):
+            # If the object is a dictionary, serialize it with sorted keys
+            return json.dumps(
+                obj, sort_keys=True, separators=(",", ":")
+            )  # Standardize serialization
+
+        elif isinstance(obj, list):
+            # Serialize lists by ensuring each element is handled properly
+            return [PromptCachingCache.serialize_object(item) for item in obj]
+        elif isinstance(obj, (int, float, bool)):
+            return obj  # Keep primitive types as-is
+        return str(obj)
+
+    @staticmethod
+    def get_prompt_caching_cache_key(
+        messages: Optional[List[AllMessageValues]],
+        tools: Optional[List[ChatCompletionToolParam]],
+    ) -> Optional[str]:
+        if messages is None and tools is None:
+            return None
+        # Use serialize_object for consistent and stable serialization
+        data_to_hash = {}
+        if messages is not None:
+            serialized_messages = PromptCachingCache.serialize_object(messages)
+            data_to_hash["messages"] = serialized_messages
+        if tools is not None:
+            serialized_tools = PromptCachingCache.serialize_object(tools)
+            data_to_hash["tools"] = serialized_tools
+
+        # Combine serialized data into a single string
+        data_to_hash_str = json.dumps(
+            data_to_hash,
+            sort_keys=True,
+            separators=(",", ":"),
+        )
+
+        # Create a hash of the serialized data for a stable cache key
+        hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
+        return f"deployment:{hashed_data}:prompt_caching"
+
+    def add_model_id(
+        self,
+        model_id: str,
+        messages: Optional[List[AllMessageValues]],
+        tools: Optional[List[ChatCompletionToolParam]],
+    ) -> None:
+        if messages is None and tools is None:
+            return None
+
+        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
+        self.cache.set_cache(
+            cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
+        )
+        return None
+
+    async def async_add_model_id(
+        self,
+        model_id: str,
+        messages: Optional[List[AllMessageValues]],
+        tools: Optional[List[ChatCompletionToolParam]],
+    ) -> None:
+        if messages is None and tools is None:
+            return None
+
+        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
+        await self.cache.async_set_cache(
+            cache_key,
+            PromptCachingCacheValue(model_id=model_id),
+            ttl=300,  # store for 5 minutes
+        )
+        return None
+
+    async def async_get_model_id(
+        self,
+        messages: Optional[List[AllMessageValues]],
+        tools: Optional[List[ChatCompletionToolParam]],
+    ) -> Optional[PromptCachingCacheValue]:
+        """
+        if messages is not none
+            - check full messages
+            - check messages[:-1]
+            - check messages[:-2]
+            - check messages[:-3]
+
+        use self.cache.async_batch_get_cache(keys=potential_cache_keys)
+        """
+        if messages is None and tools is None:
+            return None
+
+        # Generate potential cache keys by slicing messages
+
+        potential_cache_keys = []
+
+        if messages is not None:
+            full_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
+                messages, tools
+            )
+            potential_cache_keys.append(full_cache_key)
+
+            # Check progressively shorter message slices
+            for i in range(1, min(4, len(messages))):
+                partial_messages = messages[:-i]
+                partial_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
+                    partial_messages, tools
+                )
+                potential_cache_keys.append(partial_cache_key)
+
+        # Perform batch cache lookup
+        cache_results = await self.cache.async_batch_get_cache(
+            keys=potential_cache_keys
+        )
+
+        if cache_results is None:
+            return None
+
+        # Return the first non-None cache result
+        for result in cache_results:
+            if result is not None:
+                return result
+
+        return None
+
+    def get_model_id(
+        self,
+        messages: Optional[List[AllMessageValues]],
+        tools: Optional[List[ChatCompletionToolParam]],
+    ) -> Optional[PromptCachingCacheValue]:
+        if messages is None and tools is None:
+            return None
+
+        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
+        return self.cache.get_cache(cache_key)
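
For context, below is a minimal usage sketch of the class added in this diff. It is not part of the diff itself: the deployment id and message contents are placeholders, and it assumes the module is importable as litellm.router_utils.prompt_caching_cache with a default DualCache() (in-memory only) as the backing store.

import asyncio

from litellm.caching.caching import DualCache
from litellm.router_utils.prompt_caching_cache import PromptCachingCache

# Placeholder conversation in OpenAI chat format.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the attached report."},
]

cache = PromptCachingCache(cache=DualCache())

# Keys look like "deployment:<sha256 of canonically serialized messages/tools>:prompt_caching".
print(PromptCachingCache.get_prompt_caching_cache_key(messages, tools=None))


async def main() -> None:
    # Remember which deployment served this prompt (stored with ttl=300, i.e. 5 minutes).
    await cache.async_add_model_id(
        model_id="my-deployment-1",  # placeholder deployment id
        messages=messages,
        tools=None,
    )

    # A follow-up turn appends one message. The lookup also hashes the shorter
    # prefixes (messages[:-1], [:-2], [:-3]), so the stored entry is still found
    # and the router can pin the request to the same deployment.
    follow_up = messages + [{"role": "user", "content": "Now translate it to French."}]
    hit = await cache.async_get_model_id(messages=follow_up, tools=None)
    print(hit)  # {'model_id': 'my-deployment-1'} on a hit, otherwise None


asyncio.run(main())

The 5-minute TTL and the four candidate keys (full messages plus up to three shorter prefixes) in this sketch mirror the constants used by the code in the diff above.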