| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/caching | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/caching')
15 files changed, 4636 insertions, 0 deletions
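The diff below vendors LiteLLM's `litellm/caching` package into the virtualenv. As orientation for the file-by-file changes, here is a minimal, hedged sketch of how the `Cache` class added in this commit is typically wired up; the model name, prompt, and Redis credentials are illustrative placeholders.

```python
# Minimal sketch based on the Cache class and enable_cache() helper in this diff.
import litellm
from litellm.caching import Cache, LiteLLMCacheType

# Local in-memory cache (the default type); the cache is set as a litellm param.
litellm.cache = Cache(type=LiteLLMCacheType.LOCAL)

# Alternatively, a Redis-backed cache with a namespace and default TTL
# (host/port/password are placeholders):
# litellm.cache = Cache(
#     type=LiteLLMCacheType.REDIS,
#     host="localhost", port="6379", password="example",
#     namespace="litellm", ttl=600,
# )

# Identical calls after the first are served from the cache.
response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "What does this caching module do?"}],
)
```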
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/Readme.md b/.venv/lib/python3.12/site-packages/litellm/caching/Readme.md new file mode 100644 index 00000000..6b0210a6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/Readme.md @@ -0,0 +1,40 @@ +# Caching on LiteLLM + +LiteLLM supports multiple caching mechanisms. This allows users to choose the most suitable caching solution for their use case. + +The following caching mechanisms are supported: + +1. **RedisCache** +2. **RedisSemanticCache** +3. **QdrantSemanticCache** +4. **InMemoryCache** +5. **DiskCache** +6. **S3Cache** +7. **DualCache** (updates both Redis and an in-memory cache simultaneously) + +## Folder Structure + +``` +litellm/caching/ +├── base_cache.py +├── caching.py +├── caching_handler.py +├── disk_cache.py +├── dual_cache.py +├── in_memory_cache.py +├── qdrant_semantic_cache.py +├── redis_cache.py +├── redis_semantic_cache.py +├── s3_cache.py +``` + +## Documentation +- [Caching on LiteLLM Gateway](https://docs.litellm.ai/docs/proxy/caching) +- [Caching on LiteLLM Python](https://docs.litellm.ai/docs/caching/all_caches) + + + + + + + diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/__init__.py b/.venv/lib/python3.12/site-packages/litellm/caching/__init__.py new file mode 100644 index 00000000..e10d01ff --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/__init__.py @@ -0,0 +1,9 @@ +from .caching import Cache, LiteLLMCacheType +from .disk_cache import DiskCache +from .dual_cache import DualCache +from .in_memory_cache import InMemoryCache +from .qdrant_semantic_cache import QdrantSemanticCache +from .redis_cache import RedisCache +from .redis_cluster_cache import RedisClusterCache +from .redis_semantic_cache import RedisSemanticCache +from .s3_cache import S3Cache diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/_internal_lru_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/_internal_lru_cache.py new file mode 100644 index 00000000..54b0fe96 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/_internal_lru_cache.py @@ -0,0 +1,30 @@ +from functools import lru_cache +from typing import Callable, Optional, TypeVar + +T = TypeVar("T") + + +def lru_cache_wrapper( + maxsize: Optional[int] = None, +) -> Callable[[Callable[..., T]], Callable[..., T]]: + """ + Wrapper for lru_cache that caches success and exceptions + """ + + def decorator(f: Callable[..., T]) -> Callable[..., T]: + @lru_cache(maxsize=maxsize) + def wrapper(*args, **kwargs): + try: + return ("success", f(*args, **kwargs)) + except Exception as e: + return ("error", e) + + def wrapped(*args, **kwargs): + result = wrapper(*args, **kwargs) + if result[0] == "error": + raise result[1] + return result[1] + + return wrapped + + return decorator diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/base_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/base_cache.py new file mode 100644 index 00000000..7109951d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/base_cache.py @@ -0,0 +1,55 @@ +""" +Base Cache implementation. All cache implementations should inherit from this class. 
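As an aside on `_internal_lru_cache.py` above: `lru_cache_wrapper` memoizes raised exceptions as well as successful results, so repeated failing calls with the same arguments re-raise without re-executing. A small hedged sketch; the `parse_port` function is hypothetical.

```python
from litellm.caching._internal_lru_cache import lru_cache_wrapper

calls = 0

@lru_cache_wrapper(maxsize=128)
def parse_port(value: str) -> int:
    # Hypothetical function: raises ValueError for non-numeric input.
    global calls
    calls += 1
    return int(value)

assert parse_port("6379") == 6379
assert parse_port("6379") == 6379 and calls == 1  # second call served from the LRU cache

try:
    parse_port("not-a-port")
except ValueError:
    pass
try:
    parse_port("not-a-port")  # the cached ValueError is re-raised, parse_port is not re-run
except ValueError:
    pass
assert calls == 2
```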
+ +Has 4 methods: + - set_cache + - get_cache + - async_set_cache + - async_get_cache +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +class BaseCache(ABC): + def __init__(self, default_ttl: int = 60): + self.default_ttl = default_ttl + + def get_ttl(self, **kwargs) -> Optional[int]: + kwargs_ttl: Optional[int] = kwargs.get("ttl") + if kwargs_ttl is not None: + try: + return int(kwargs_ttl) + except ValueError: + return self.default_ttl + return self.default_ttl + + def set_cache(self, key, value, **kwargs): + raise NotImplementedError + + async def async_set_cache(self, key, value, **kwargs): + raise NotImplementedError + + @abstractmethod + async def async_set_cache_pipeline(self, cache_list, **kwargs): + pass + + def get_cache(self, key, **kwargs): + raise NotImplementedError + + async def async_get_cache(self, key, **kwargs): + raise NotImplementedError + + async def batch_cache_write(self, key, value, **kwargs): + raise NotImplementedError + + async def disconnect(self): + raise NotImplementedError diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/caching.py b/.venv/lib/python3.12/site-packages/litellm/caching/caching.py new file mode 100644 index 00000000..415c49ed --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/caching.py @@ -0,0 +1,797 @@ +# +-----------------------------------------------+ +# | | +# | Give Feedback / Get Help | +# | https://github.com/BerriAI/litellm/issues/new | +# | | +# +-----------------------------------------------+ +# +# Thank you users! We ❤️ you! - Krrish & Ishaan + +import ast +import hashlib +import json +import time +import traceback +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel + +import litellm +from litellm._logging import verbose_logger +from litellm.litellm_core_utils.model_param_helper import ModelParamHelper +from litellm.types.caching import * +from litellm.types.utils import all_litellm_params + +from .base_cache import BaseCache +from .disk_cache import DiskCache +from .dual_cache import DualCache # noqa +from .in_memory_cache import InMemoryCache +from .qdrant_semantic_cache import QdrantSemanticCache +from .redis_cache import RedisCache +from .redis_cluster_cache import RedisClusterCache +from .redis_semantic_cache import RedisSemanticCache +from .s3_cache import S3Cache + + +def print_verbose(print_statement): + try: + verbose_logger.debug(print_statement) + if litellm.set_verbose: + print(print_statement) # noqa + except Exception: + pass + + +class CacheMode(str, Enum): + default_on = "default_on" + default_off = "default_off" + + +#### LiteLLM.Completion / Embedding Cache #### +class Cache: + def __init__( + self, + type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL, + mode: Optional[ + CacheMode + ] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in + host: Optional[str] = None, + port: Optional[str] = None, + password: Optional[str] = None, + namespace: Optional[str] = None, + ttl: Optional[float] = None, + default_in_memory_ttl: Optional[float] = None, + default_in_redis_ttl: Optional[float] = None, + similarity_threshold: Optional[float] = None, + supported_call_types: Optional[List[CachingSupportedCallTypes]] = [ + "completion", + "acompletion", + "embedding", + "aembedding", + "atranscription", + "transcription", + 
"atext_completion", + "text_completion", + "arerank", + "rerank", + ], + # s3 Bucket, boto3 configuration + s3_bucket_name: Optional[str] = None, + s3_region_name: Optional[str] = None, + s3_api_version: Optional[str] = None, + s3_use_ssl: Optional[bool] = True, + s3_verify: Optional[Union[bool, str]] = None, + s3_endpoint_url: Optional[str] = None, + s3_aws_access_key_id: Optional[str] = None, + s3_aws_secret_access_key: Optional[str] = None, + s3_aws_session_token: Optional[str] = None, + s3_config: Optional[Any] = None, + s3_path: Optional[str] = None, + redis_semantic_cache_use_async=False, + redis_semantic_cache_embedding_model="text-embedding-ada-002", + redis_flush_size: Optional[int] = None, + redis_startup_nodes: Optional[List] = None, + disk_cache_dir=None, + qdrant_api_base: Optional[str] = None, + qdrant_api_key: Optional[str] = None, + qdrant_collection_name: Optional[str] = None, + qdrant_quantization_config: Optional[str] = None, + qdrant_semantic_cache_embedding_model="text-embedding-ada-002", + **kwargs, + ): + """ + Initializes the cache based on the given type. + + Args: + type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local". + + # Redis Cache Args + host (str, optional): The host address for the Redis cache. Required if type is "redis". + port (int, optional): The port number for the Redis cache. Required if type is "redis". + password (str, optional): The password for the Redis cache. Required if type is "redis". + namespace (str, optional): The namespace for the Redis cache. Required if type is "redis". + ttl (float, optional): The ttl for the Redis cache + redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used. + redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None. + + # Qdrant Cache Args + qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic". + qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster. + qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic". + similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic". + + # Disk Cache Args + disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None. + + # S3 Cache Args + s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None. + s3_region_name (str, optional): The region name for the s3 cache. Defaults to None. + s3_api_version (str, optional): The api version for the s3 cache. Defaults to None. + s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True. + s3_verify (bool, optional): The verify for the s3 cache. Defaults to None. + s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None. + s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None. + s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None. + s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None. + s3_config (dict, optional): The config for the s3 cache. Defaults to None. + + # Common Cache Args + supported_call_types (list, optional): List of call types to cache for. 
Defaults to cache == on for all call types. + **kwargs: Additional keyword arguments for redis.Redis() cache + + Raises: + ValueError: If an invalid cache type is provided. + + Returns: + None. Cache is set as a litellm param + """ + if type == LiteLLMCacheType.REDIS: + if redis_startup_nodes: + self.cache: BaseCache = RedisClusterCache( + host=host, + port=port, + password=password, + redis_flush_size=redis_flush_size, + startup_nodes=redis_startup_nodes, + **kwargs, + ) + else: + self.cache = RedisCache( + host=host, + port=port, + password=password, + redis_flush_size=redis_flush_size, + **kwargs, + ) + elif type == LiteLLMCacheType.REDIS_SEMANTIC: + self.cache = RedisSemanticCache( + host=host, + port=port, + password=password, + similarity_threshold=similarity_threshold, + use_async=redis_semantic_cache_use_async, + embedding_model=redis_semantic_cache_embedding_model, + **kwargs, + ) + elif type == LiteLLMCacheType.QDRANT_SEMANTIC: + self.cache = QdrantSemanticCache( + qdrant_api_base=qdrant_api_base, + qdrant_api_key=qdrant_api_key, + collection_name=qdrant_collection_name, + similarity_threshold=similarity_threshold, + quantization_config=qdrant_quantization_config, + embedding_model=qdrant_semantic_cache_embedding_model, + ) + elif type == LiteLLMCacheType.LOCAL: + self.cache = InMemoryCache() + elif type == LiteLLMCacheType.S3: + self.cache = S3Cache( + s3_bucket_name=s3_bucket_name, + s3_region_name=s3_region_name, + s3_api_version=s3_api_version, + s3_use_ssl=s3_use_ssl, + s3_verify=s3_verify, + s3_endpoint_url=s3_endpoint_url, + s3_aws_access_key_id=s3_aws_access_key_id, + s3_aws_secret_access_key=s3_aws_secret_access_key, + s3_aws_session_token=s3_aws_session_token, + s3_config=s3_config, + s3_path=s3_path, + **kwargs, + ) + elif type == LiteLLMCacheType.DISK: + self.cache = DiskCache(disk_cache_dir=disk_cache_dir) + if "cache" not in litellm.input_callback: + litellm.input_callback.append("cache") + if "cache" not in litellm.success_callback: + litellm.logging_callback_manager.add_litellm_success_callback("cache") + if "cache" not in litellm._async_success_callback: + litellm.logging_callback_manager.add_litellm_async_success_callback("cache") + self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"] + self.type = type + self.namespace = namespace + self.redis_flush_size = redis_flush_size + self.ttl = ttl + self.mode: CacheMode = mode or CacheMode.default_on + + if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None: + self.ttl = default_in_memory_ttl + + if ( + self.type == LiteLLMCacheType.REDIS + or self.type == LiteLLMCacheType.REDIS_SEMANTIC + ) and default_in_redis_ttl is not None: + self.ttl = default_in_redis_ttl + + if self.namespace is not None and isinstance(self.cache, RedisCache): + self.cache.namespace = self.namespace + + def get_cache_key(self, **kwargs) -> str: + """ + Get the cache key for the given arguments. + + Args: + **kwargs: kwargs to litellm.completion() or embedding() + + Returns: + str: The cache key generated from the arguments, or None if no cache key could be generated. + """ + cache_key = "" + # verbose_logger.debug("\nGetting Cache key. 
Kwargs: %s", kwargs) + + preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs) + if preset_cache_key is not None: + verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key) + return preset_cache_key + + combined_kwargs = ModelParamHelper._get_all_llm_api_params() + litellm_param_kwargs = all_litellm_params + for param in kwargs: + if param in combined_kwargs: + param_value: Optional[str] = self._get_param_value(param, kwargs) + if param_value is not None: + cache_key += f"{str(param)}: {str(param_value)}" + elif ( + param not in litellm_param_kwargs + ): # check if user passed in optional param - e.g. top_k + if ( + litellm.enable_caching_on_provider_specific_optional_params is True + ): # feature flagged for now + if kwargs[param] is None: + continue # ignore None params + param_value = kwargs[param] + cache_key += f"{str(param)}: {str(param_value)}" + + verbose_logger.debug("\nCreated cache key: %s", cache_key) + hashed_cache_key = Cache._get_hashed_cache_key(cache_key) + hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs) + self._set_preset_cache_key_in_kwargs( + preset_cache_key=hashed_cache_key, **kwargs + ) + return hashed_cache_key + + def _get_param_value( + self, + param: str, + kwargs: dict, + ) -> Optional[str]: + """ + Get the value for the given param from kwargs + """ + if param == "model": + return self._get_model_param_value(kwargs) + elif param == "file": + return self._get_file_param_value(kwargs) + return kwargs[param] + + def _get_model_param_value(self, kwargs: dict) -> str: + """ + Handles getting the value for the 'model' param from kwargs + + 1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups + 2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router() + 3. Else use the `model` passed in kwargs + """ + metadata: Dict = kwargs.get("metadata", {}) or {} + litellm_params: Dict = kwargs.get("litellm_params", {}) or {} + metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {} + model_group: Optional[str] = metadata.get( + "model_group" + ) or metadata_in_litellm_params.get("model_group") + caching_group = self._get_caching_group(metadata, model_group) + return caching_group or model_group or kwargs["model"] + + def _get_caching_group( + self, metadata: dict, model_group: Optional[str] + ) -> Optional[str]: + caching_groups: Optional[List] = metadata.get("caching_groups", []) + if caching_groups: + for group in caching_groups: + if model_group in group: + return str(group) + return None + + def _get_file_param_value(self, kwargs: dict) -> str: + """ + Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests + """ + file = kwargs.get("file") + metadata = kwargs.get("metadata", {}) + litellm_params = kwargs.get("litellm_params", {}) + return ( + metadata.get("file_checksum") + or getattr(file, "name", None) + or metadata.get("file_name") + or litellm_params.get("file_name") + ) + + def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]: + """ + Get the preset cache key from kwargs["litellm_params"] + + We use _get_preset_cache_keys for two reasons + + 1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens + 2. 
avoid doing duplicate / repeated work + """ + if kwargs: + if "litellm_params" in kwargs: + return kwargs["litellm_params"].get("preset_cache_key", None) + return None + + def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None: + """ + Set the calculated cache key in kwargs + + This is used to avoid doing duplicate / repeated work + + Placed in kwargs["litellm_params"] + """ + if kwargs: + if "litellm_params" in kwargs: + kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key + + @staticmethod + def _get_hashed_cache_key(cache_key: str) -> str: + """ + Get the hashed cache key for the given cache key. + + Use hashlib to create a sha256 hash of the cache key + + Args: + cache_key (str): The cache key to hash. + + Returns: + str: The hashed cache key. + """ + hash_object = hashlib.sha256(cache_key.encode()) + # Hexadecimal representation of the hash + hash_hex = hash_object.hexdigest() + verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex) + return hash_hex + + def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str: + """ + If a redis namespace is provided, add it to the cache key + + Args: + hash_hex (str): The hashed cache key. + **kwargs: Additional keyword arguments. + + Returns: + str: The final hashed cache key with the redis namespace. + """ + dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {}) + namespace = ( + dynamic_cache_control.get("namespace") + or kwargs.get("metadata", {}).get("redis_namespace") + or self.namespace + ) + if namespace: + hash_hex = f"{namespace}:{hash_hex}" + verbose_logger.debug("Final hashed key: %s", hash_hex) + return hash_hex + + def generate_streaming_content(self, content): + chunk_size = 5 # Adjust the chunk size as needed + for i in range(0, len(content), chunk_size): + yield { + "choices": [ + { + "delta": { + "role": "assistant", + "content": content[i : i + chunk_size], + } + } + ] + } + time.sleep(0.02) + + def _get_cache_logic( + self, + cached_result: Optional[Any], + max_age: Optional[float], + ): + """ + Common get cache logic across sync + async implementations + """ + # Check if a timestamp was stored with the cached response + if ( + cached_result is not None + and isinstance(cached_result, dict) + and "timestamp" in cached_result + ): + timestamp = cached_result["timestamp"] + current_time = time.time() + + # Calculate age of the cached response + response_age = current_time - timestamp + + # Check if the cached response is older than the max-age + if max_age is not None and response_age > max_age: + return None # Cached response is too old + + # If the response is fresh, or there's no max-age requirement, return the cached response + # cached_response is in `b{} convert it to ModelResponse + cached_response = cached_result.get("response") + try: + if isinstance(cached_response, dict): + pass + else: + cached_response = json.loads( + cached_response # type: ignore + ) # Convert string to dictionary + except Exception: + cached_response = ast.literal_eval(cached_response) # type: ignore + return cached_response + return cached_result + + def get_cache(self, **kwargs): + """ + Retrieves the cached result for the given arguments. + + Args: + *args: args to litellm.completion() or embedding() + **kwargs: kwargs to litellm.completion() or embedding() + + Returns: + The cached result if it exists, otherwise None. 
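The `s-maxage`/`s-max-age` keys consumed by `_get_cache_logic` and `get_cache` above come from the per-request `cache` dict passed to the completion call; a hedged sketch, with model and prompt as placeholders.

```python
# Only accept a cached response if it is younger than 60 seconds;
# older entries are treated as a cache miss by _get_cache_logic above.
import litellm

litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "What changed in the last minute?"}],
    cache={"s-maxage": 60},
)
```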
+ """ + try: # never block execution + if self.should_use_cache(**kwargs) is not True: + return + messages = kwargs.get("messages", []) + if "cache_key" in kwargs: + cache_key = kwargs["cache_key"] + else: + cache_key = self.get_cache_key(**kwargs) + if cache_key is not None: + cache_control_args: DynamicCacheControl = kwargs.get("cache", {}) + max_age = ( + cache_control_args.get("s-maxage") + or cache_control_args.get("s-max-age") + or float("inf") + ) + cached_result = self.cache.get_cache(cache_key, messages=messages) + cached_result = self.cache.get_cache(cache_key, messages=messages) + return self._get_cache_logic( + cached_result=cached_result, max_age=max_age + ) + except Exception: + print_verbose(f"An exception occurred: {traceback.format_exc()}") + return None + + async def async_get_cache(self, **kwargs): + """ + Async get cache implementation. + + Used for embedding calls in async wrapper + """ + + try: # never block execution + if self.should_use_cache(**kwargs) is not True: + return + + kwargs.get("messages", []) + if "cache_key" in kwargs: + cache_key = kwargs["cache_key"] + else: + cache_key = self.get_cache_key(**kwargs) + if cache_key is not None: + cache_control_args = kwargs.get("cache", {}) + max_age = cache_control_args.get( + "s-max-age", cache_control_args.get("s-maxage", float("inf")) + ) + cached_result = await self.cache.async_get_cache(cache_key, **kwargs) + return self._get_cache_logic( + cached_result=cached_result, max_age=max_age + ) + except Exception: + print_verbose(f"An exception occurred: {traceback.format_exc()}") + return None + + def _add_cache_logic(self, result, **kwargs): + """ + Common implementation across sync + async add_cache functions + """ + try: + if "cache_key" in kwargs: + cache_key = kwargs["cache_key"] + else: + cache_key = self.get_cache_key(**kwargs) + if cache_key is not None: + if isinstance(result, BaseModel): + result = result.model_dump_json() + + ## DEFAULT TTL ## + if self.ttl is not None: + kwargs["ttl"] = self.ttl + ## Get Cache-Controls ## + _cache_kwargs = kwargs.get("cache", None) + if isinstance(_cache_kwargs, dict): + for k, v in _cache_kwargs.items(): + if k == "ttl": + kwargs["ttl"] = v + + cached_data = {"timestamp": time.time(), "response": result} + return cache_key, cached_data, kwargs + else: + raise Exception("cache key is None") + except Exception as e: + raise e + + def add_cache(self, result, **kwargs): + """ + Adds a result to the cache. 
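`_add_cache_logic` above promotes a per-request `ttl` from the same `cache` dict onto the write path, overriding the cache-wide default; a hedged sketch with placeholder values.

```python
# Store this particular response for five minutes only.
import litellm

litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Cache me briefly"}],
    cache={"ttl": 300},
)
```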
+ + Args: + *args: args to litellm.completion() or embedding() + **kwargs: kwargs to litellm.completion() or embedding() + + Returns: + None + """ + try: + if self.should_use_cache(**kwargs) is not True: + return + cache_key, cached_data, kwargs = self._add_cache_logic( + result=result, **kwargs + ) + self.cache.set_cache(cache_key, cached_data, **kwargs) + except Exception as e: + verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}") + + async def async_add_cache(self, result, **kwargs): + """ + Async implementation of add_cache + """ + try: + if self.should_use_cache(**kwargs) is not True: + return + if self.type == "redis" and self.redis_flush_size is not None: + # high traffic - fill in results in memory and then flush + await self.batch_cache_write(result, **kwargs) + else: + cache_key, cached_data, kwargs = self._add_cache_logic( + result=result, **kwargs + ) + + await self.cache.async_set_cache(cache_key, cached_data, **kwargs) + except Exception as e: + verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}") + + async def async_add_cache_pipeline(self, result, **kwargs): + """ + Async implementation of add_cache for Embedding calls + + Does a bulk write, to prevent using too many clients + """ + try: + if self.should_use_cache(**kwargs) is not True: + return + + # set default ttl if not set + if self.ttl is not None: + kwargs["ttl"] = self.ttl + + cache_list = [] + for idx, i in enumerate(kwargs["input"]): + preset_cache_key = self.get_cache_key(**{**kwargs, "input": i}) + kwargs["cache_key"] = preset_cache_key + embedding_response = result.data[idx] + cache_key, cached_data, kwargs = self._add_cache_logic( + result=embedding_response, + **kwargs, + ) + cache_list.append((cache_key, cached_data)) + + await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs) + # if async_set_cache_pipeline: + # await async_set_cache_pipeline(cache_list=cache_list, **kwargs) + # else: + # tasks = [] + # for val in cache_list: + # tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs)) + # await asyncio.gather(*tasks) + except Exception as e: + verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}") + + def should_use_cache(self, **kwargs): + """ + Returns true if we should use the cache for LLM API calls + + If cache is default_on then this is True + If cache is default_off then this is only true when user has opted in to use cache + """ + if self.mode == CacheMode.default_on: + return True + + # when mode == default_off -> Cache is opt in only + _cache = kwargs.get("cache", None) + verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache) + if _cache and isinstance(_cache, dict): + if _cache.get("use-cache", False) is True: + return True + return False + + async def batch_cache_write(self, result, **kwargs): + cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs) + await self.cache.batch_cache_write(cache_key, cached_data, **kwargs) + + async def ping(self): + cache_ping = getattr(self.cache, "ping") + if cache_ping: + return await cache_ping() + return None + + async def delete_cache_keys(self, keys): + cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys") + if cache_delete_cache_keys: + return await cache_delete_cache_keys(keys) + return None + + async def disconnect(self): + if hasattr(self.cache, "disconnect"): + await self.cache.disconnect() + + def _supports_async(self) -> bool: + """ + Internal method to check if the cache type supports async get/set 
operations + + Only S3 Cache Does NOT support async operations + + """ + if self.type and self.type == LiteLLMCacheType.S3: + return False + return True + + +def enable_cache( + type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL, + host: Optional[str] = None, + port: Optional[str] = None, + password: Optional[str] = None, + supported_call_types: Optional[List[CachingSupportedCallTypes]] = [ + "completion", + "acompletion", + "embedding", + "aembedding", + "atranscription", + "transcription", + "atext_completion", + "text_completion", + "arerank", + "rerank", + ], + **kwargs, +): + """ + Enable cache with the specified configuration. + + Args: + type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local". + host (Optional[str]): The host address of the cache server. Defaults to None. + port (Optional[str]): The port number of the cache server. Defaults to None. + password (Optional[str]): The password for the cache server. Defaults to None. + supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]): + The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"]. + **kwargs: Additional keyword arguments. + + Returns: + None + + Raises: + None + """ + print_verbose("LiteLLM: Enabling Cache") + if "cache" not in litellm.input_callback: + litellm.input_callback.append("cache") + if "cache" not in litellm.success_callback: + litellm.logging_callback_manager.add_litellm_success_callback("cache") + if "cache" not in litellm._async_success_callback: + litellm.logging_callback_manager.add_litellm_async_success_callback("cache") + + if litellm.cache is None: + litellm.cache = Cache( + type=type, + host=host, + port=port, + password=password, + supported_call_types=supported_call_types, + **kwargs, + ) + print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}") + print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}") + + +def update_cache( + type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL, + host: Optional[str] = None, + port: Optional[str] = None, + password: Optional[str] = None, + supported_call_types: Optional[List[CachingSupportedCallTypes]] = [ + "completion", + "acompletion", + "embedding", + "aembedding", + "atranscription", + "transcription", + "atext_completion", + "text_completion", + "arerank", + "rerank", + ], + **kwargs, +): + """ + Update the cache for LiteLLM. + + Args: + type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local". + host (Optional[str]): The host of the cache. Defaults to None. + port (Optional[str]): The port of the cache. Defaults to None. + password (Optional[str]): The password for the cache. Defaults to None. + supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]): + The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"]. + **kwargs: Additional keyword arguments for the cache. + + Returns: + None + + """ + print_verbose("LiteLLM: Updating Cache") + litellm.cache = Cache( + type=type, + host=host, + port=port, + password=password, + supported_call_types=supported_call_types, + **kwargs, + ) + print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}") + print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}") + + +def disable_cache(): + """ + Disable the cache used by LiteLLM. + + This function disables the cache used by the LiteLLM module. 
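The module-level helpers shown above (`enable_cache`, `update_cache`, `disable_cache`) manage the global `litellm.cache`; a hedged lifecycle sketch, with the Redis credentials as placeholders.

```python
import litellm
from litellm.caching import LiteLLMCacheType
from litellm.caching.caching import disable_cache, enable_cache, update_cache

enable_cache()  # defaults to the local in-memory cache and registers the "cache" callbacks

# Reconfigure the global cache to Redis (placeholder credentials)
update_cache(type=LiteLLMCacheType.REDIS, host="localhost", port="6379", password="example")

disable_cache()  # removes the callbacks and resets the global cache
assert litellm.cache is None
```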
It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None. + + Parameters: + None + + Returns: + None + """ + from contextlib import suppress + + print_verbose("LiteLLM: Disabling Cache") + with suppress(ValueError): + litellm.input_callback.remove("cache") + litellm.success_callback.remove("cache") + litellm._async_success_callback.remove("cache") + + litellm.cache = None + print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}") diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/caching_handler.py b/.venv/lib/python3.12/site-packages/litellm/caching/caching_handler.py new file mode 100644 index 00000000..09fabf1c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/caching_handler.py @@ -0,0 +1,909 @@ +""" +This contains LLMCachingHandler + +This exposes two methods: + - async_get_cache + - async_set_cache + +This file is a wrapper around caching.py + +This class is used to handle caching logic specific for LLM API requests (completion / embedding / text_completion / transcription etc) + +It utilizes the (RedisCache, s3Cache, RedisSemanticCache, QdrantSemanticCache, InMemoryCache, DiskCache) based on what the user has setup + +In each method it will call the appropriate method from caching.py +""" + +import asyncio +import datetime +import inspect +import threading +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Callable, + Dict, + Generator, + List, + Optional, + Tuple, + Union, +) + +from pydantic import BaseModel + +import litellm +from litellm._logging import print_verbose, verbose_logger +from litellm.caching.caching import S3Cache +from litellm.litellm_core_utils.logging_utils import ( + _assemble_complete_response_from_streaming_chunks, +) +from litellm.types.rerank import RerankResponse +from litellm.types.utils import ( + CallTypes, + Embedding, + EmbeddingResponse, + ModelResponse, + TextCompletionResponse, + TranscriptionResponse, +) + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + from litellm.utils import CustomStreamWrapper +else: + LiteLLMLoggingObj = Any + CustomStreamWrapper = Any + + +class CachingHandlerResponse(BaseModel): + """ + This is the response object for the caching handler. 
We need to separate embedding cached responses and (completion / text_completion / transcription) cached responses + + For embeddings there can be a cache hit for some of the inputs in the list and a cache miss for others + """ + + cached_result: Optional[Any] = None + final_embedding_cached_response: Optional[EmbeddingResponse] = None + embedding_all_elements_cache_hit: bool = ( + False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call + ) + + +class LLMCachingHandler: + def __init__( + self, + original_function: Callable, + request_kwargs: Dict[str, Any], + start_time: datetime.datetime, + ): + self.async_streaming_chunks: List[ModelResponse] = [] + self.sync_streaming_chunks: List[ModelResponse] = [] + self.request_kwargs = request_kwargs + self.original_function = original_function + self.start_time = start_time + pass + + async def _async_get_cache( + self, + model: str, + original_function: Callable, + logging_obj: LiteLLMLoggingObj, + start_time: datetime.datetime, + call_type: str, + kwargs: Dict[str, Any], + args: Optional[Tuple[Any, ...]] = None, + ) -> CachingHandlerResponse: + """ + Internal method to get from the cache. + Handles different call types (embeddings, chat/completions, text_completion, transcription) + and accordingly returns the cached response + + Args: + model: str: + original_function: Callable: + logging_obj: LiteLLMLoggingObj: + start_time: datetime.datetime: + call_type: str: + kwargs: Dict[str, Any]: + args: Optional[Tuple[Any, ...]] = None: + + + Returns: + CachingHandlerResponse: + Raises: + None + """ + from litellm.utils import CustomStreamWrapper + + args = args or () + + final_embedding_cached_response: Optional[EmbeddingResponse] = None + embedding_all_elements_cache_hit: bool = False + cached_result: Optional[Any] = None + if ( + (kwargs.get("caching", None) is None and litellm.cache is not None) + or kwargs.get("caching", False) is True + ) and ( + kwargs.get("cache", {}).get("no-cache", False) is not True + ): # allow users to control returning cached responses from the completion function + if litellm.cache is not None and self._is_call_type_supported_by_cache( + original_function=original_function + ): + verbose_logger.debug("Checking Cache") + cached_result = await self._retrieve_from_cache( + call_type=call_type, + kwargs=kwargs, + args=args, + ) + + if cached_result is not None and not isinstance(cached_result, list): + verbose_logger.debug("Cache Hit!") + cache_hit = True + end_time = datetime.datetime.now() + model, _, _, _ = litellm.get_llm_provider( + model=model, + custom_llm_provider=kwargs.get("custom_llm_provider", None), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + self._update_litellm_logging_obj_environment( + logging_obj=logging_obj, + model=model, + kwargs=kwargs, + cached_result=cached_result, + is_async=True, + ) + + call_type = original_function.__name__ + + cached_result = self._convert_cached_result_to_model_response( + cached_result=cached_result, + call_type=call_type, + kwargs=kwargs, + logging_obj=logging_obj, + model=model, + custom_llm_provider=kwargs.get("custom_llm_provider", None), + args=args, + ) + if kwargs.get("stream", False) is False: + # LOG SUCCESS + self._async_log_cache_hit_on_callbacks( + logging_obj=logging_obj, + cached_result=cached_result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + ) + cache_key = 
litellm.cache._get_preset_cache_key_from_kwargs( + **kwargs + ) + if ( + isinstance(cached_result, BaseModel) + or isinstance(cached_result, CustomStreamWrapper) + ) and hasattr(cached_result, "_hidden_params"): + cached_result._hidden_params["cache_key"] = cache_key # type: ignore + return CachingHandlerResponse(cached_result=cached_result) + elif ( + call_type == CallTypes.aembedding.value + and cached_result is not None + and isinstance(cached_result, list) + and litellm.cache is not None + and not isinstance( + litellm.cache.cache, S3Cache + ) # s3 doesn't support bulk writing. Exclude. + ): + ( + final_embedding_cached_response, + embedding_all_elements_cache_hit, + ) = self._process_async_embedding_cached_response( + final_embedding_cached_response=final_embedding_cached_response, + cached_result=cached_result, + kwargs=kwargs, + logging_obj=logging_obj, + start_time=start_time, + model=model, + ) + return CachingHandlerResponse( + final_embedding_cached_response=final_embedding_cached_response, + embedding_all_elements_cache_hit=embedding_all_elements_cache_hit, + ) + verbose_logger.debug(f"CACHE RESULT: {cached_result}") + return CachingHandlerResponse( + cached_result=cached_result, + final_embedding_cached_response=final_embedding_cached_response, + ) + + def _sync_get_cache( + self, + model: str, + original_function: Callable, + logging_obj: LiteLLMLoggingObj, + start_time: datetime.datetime, + call_type: str, + kwargs: Dict[str, Any], + args: Optional[Tuple[Any, ...]] = None, + ) -> CachingHandlerResponse: + from litellm.utils import CustomStreamWrapper + + args = args or () + new_kwargs = kwargs.copy() + new_kwargs.update( + convert_args_to_kwargs( + self.original_function, + args, + ) + ) + cached_result: Optional[Any] = None + if litellm.cache is not None and self._is_call_type_supported_by_cache( + original_function=original_function + ): + print_verbose("Checking Cache") + cached_result = litellm.cache.get_cache(**new_kwargs) + if cached_result is not None: + if "detail" in cached_result: + # implies an error occurred + pass + else: + call_type = original_function.__name__ + cached_result = self._convert_cached_result_to_model_response( + cached_result=cached_result, + call_type=call_type, + kwargs=kwargs, + logging_obj=logging_obj, + model=model, + custom_llm_provider=kwargs.get("custom_llm_provider", None), + args=args, + ) + + # LOG SUCCESS + cache_hit = True + end_time = datetime.datetime.now() + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=model or "", + custom_llm_provider=kwargs.get("custom_llm_provider", None), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + self._update_litellm_logging_obj_environment( + logging_obj=logging_obj, + model=model, + kwargs=kwargs, + cached_result=cached_result, + is_async=False, + ) + + threading.Thread( + target=logging_obj.success_handler, + args=(cached_result, start_time, end_time, cache_hit), + ).start() + cache_key = litellm.cache._get_preset_cache_key_from_kwargs( + **kwargs + ) + if ( + isinstance(cached_result, BaseModel) + or isinstance(cached_result, CustomStreamWrapper) + ) and hasattr(cached_result, "_hidden_params"): + cached_result._hidden_params["cache_key"] = cache_key # type: ignore + return CachingHandlerResponse(cached_result=cached_result) + return CachingHandlerResponse(cached_result=cached_result) + + def _process_async_embedding_cached_response( + self, + final_embedding_cached_response: 
Optional[EmbeddingResponse], + cached_result: List[Optional[Dict[str, Any]]], + kwargs: Dict[str, Any], + logging_obj: LiteLLMLoggingObj, + start_time: datetime.datetime, + model: str, + ) -> Tuple[Optional[EmbeddingResponse], bool]: + """ + Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit + + For embedding responses, there can be a cache hit for some of the inputs in the list and a cache miss for others + This function processes the cached embedding responses and returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit + + Args: + final_embedding_cached_response: Optional[EmbeddingResponse]: + cached_result: List[Optional[Dict[str, Any]]]: + kwargs: Dict[str, Any]: + logging_obj: LiteLLMLoggingObj: + start_time: datetime.datetime: + model: str: + + Returns: + Tuple[Optional[EmbeddingResponse], bool]: + Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit + + + """ + embedding_all_elements_cache_hit: bool = False + remaining_list = [] + non_null_list = [] + for idx, cr in enumerate(cached_result): + if cr is None: + remaining_list.append(kwargs["input"][idx]) + else: + non_null_list.append((idx, cr)) + original_kwargs_input = kwargs["input"] + kwargs["input"] = remaining_list + if len(non_null_list) > 0: + print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}") + final_embedding_cached_response = EmbeddingResponse( + model=kwargs.get("model"), + data=[None] * len(original_kwargs_input), + ) + final_embedding_cached_response._hidden_params["cache_hit"] = True + + for val in non_null_list: + idx, cr = val # (idx, cr) tuple + if cr is not None: + final_embedding_cached_response.data[idx] = Embedding( + embedding=cr["embedding"], + index=idx, + object="embedding", + ) + if len(remaining_list) == 0: + # LOG SUCCESS + cache_hit = True + embedding_all_elements_cache_hit = True + end_time = datetime.datetime.now() + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=kwargs.get("custom_llm_provider", None), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + + self._update_litellm_logging_obj_environment( + logging_obj=logging_obj, + model=model, + kwargs=kwargs, + cached_result=final_embedding_cached_response, + is_async=True, + is_embedding=True, + ) + self._async_log_cache_hit_on_callbacks( + logging_obj=logging_obj, + cached_result=final_embedding_cached_response, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + ) + return final_embedding_cached_response, embedding_all_elements_cache_hit + return final_embedding_cached_response, embedding_all_elements_cache_hit + + def _combine_cached_embedding_response_with_api_result( + self, + _caching_handler_response: CachingHandlerResponse, + embedding_response: EmbeddingResponse, + start_time: datetime.datetime, + end_time: datetime.datetime, + ) -> EmbeddingResponse: + """ + Combines the cached embedding response with the API EmbeddingResponse + + For caching there can be a cache hit for some of the inputs in the list and a cache miss for others + This function combines the cached embedding response with the API EmbeddingResponse + + Args: + caching_handler_response: CachingHandlerResponse: + embedding_response: EmbeddingResponse: + + Returns: + EmbeddingResponse: + """ + if _caching_handler_response.final_embedding_cached_response 
is None: + return embedding_response + + idx = 0 + final_data_list = [] + for item in _caching_handler_response.final_embedding_cached_response.data: + if item is None and embedding_response.data is not None: + final_data_list.append(embedding_response.data[idx]) + idx += 1 + else: + final_data_list.append(item) + + _caching_handler_response.final_embedding_cached_response.data = final_data_list + _caching_handler_response.final_embedding_cached_response._hidden_params[ + "cache_hit" + ] = True + _caching_handler_response.final_embedding_cached_response._response_ms = ( + end_time - start_time + ).total_seconds() * 1000 + return _caching_handler_response.final_embedding_cached_response + + def _async_log_cache_hit_on_callbacks( + self, + logging_obj: LiteLLMLoggingObj, + cached_result: Any, + start_time: datetime.datetime, + end_time: datetime.datetime, + cache_hit: bool, + ): + """ + Helper function to log the success of a cached result on callbacks + + Args: + logging_obj (LiteLLMLoggingObj): The logging object. + cached_result: The cached result. + start_time (datetime): The start time of the operation. + end_time (datetime): The end time of the operation. + cache_hit (bool): Whether it was a cache hit. + """ + asyncio.create_task( + logging_obj.async_success_handler( + cached_result, start_time, end_time, cache_hit + ) + ) + threading.Thread( + target=logging_obj.success_handler, + args=(cached_result, start_time, end_time, cache_hit), + ).start() + + async def _retrieve_from_cache( + self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...] + ) -> Optional[Any]: + """ + Internal method to + - get cache key + - check what type of cache is used - Redis, RedisSemantic, Qdrant, S3 + - async get cache value + - return the cached value + + Args: + call_type: str: + kwargs: Dict[str, Any]: + args: Optional[Tuple[Any, ...]] = None: + + Returns: + Optional[Any]: + Raises: + None + """ + if litellm.cache is None: + return None + + new_kwargs = kwargs.copy() + new_kwargs.update( + convert_args_to_kwargs( + self.original_function, + args, + ) + ) + cached_result: Optional[Any] = None + if call_type == CallTypes.aembedding.value and isinstance( + new_kwargs["input"], list + ): + tasks = [] + for idx, i in enumerate(new_kwargs["input"]): + preset_cache_key = litellm.cache.get_cache_key( + **{**new_kwargs, "input": i} + ) + tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key)) + cached_result = await asyncio.gather(*tasks) + ## check if cached result is None ## + if cached_result is not None and isinstance(cached_result, list): + # set cached_result to None if all elements are None + if all(result is None for result in cached_result): + cached_result = None + else: + if litellm.cache._supports_async() is True: + cached_result = await litellm.cache.async_get_cache(**new_kwargs) + else: # for s3 caching. 
[NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync] + cached_result = litellm.cache.get_cache(**new_kwargs) + return cached_result + + def _convert_cached_result_to_model_response( + self, + cached_result: Any, + call_type: str, + kwargs: Dict[str, Any], + logging_obj: LiteLLMLoggingObj, + model: str, + args: Tuple[Any, ...], + custom_llm_provider: Optional[str] = None, + ) -> Optional[ + Union[ + ModelResponse, + TextCompletionResponse, + EmbeddingResponse, + RerankResponse, + TranscriptionResponse, + CustomStreamWrapper, + ] + ]: + """ + Internal method to process the cached result + + Checks the call type and converts the cached result to the appropriate model response object + example if call type is text_completion -> returns TextCompletionResponse object + + Args: + cached_result: Any: + call_type: str: + kwargs: Dict[str, Any]: + logging_obj: LiteLLMLoggingObj: + model: str: + custom_llm_provider: Optional[str] = None: + args: Optional[Tuple[Any, ...]] = None: + + Returns: + Optional[Any]: + """ + from litellm.utils import convert_to_model_response_object + + if ( + call_type == CallTypes.acompletion.value + or call_type == CallTypes.completion.value + ) and isinstance(cached_result, dict): + if kwargs.get("stream", False) is True: + cached_result = self._convert_cached_stream_response( + cached_result=cached_result, + call_type=call_type, + logging_obj=logging_obj, + model=model, + ) + else: + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=ModelResponse(), + ) + if ( + call_type == CallTypes.atext_completion.value + or call_type == CallTypes.text_completion.value + ) and isinstance(cached_result, dict): + if kwargs.get("stream", False) is True: + cached_result = self._convert_cached_stream_response( + cached_result=cached_result, + call_type=call_type, + logging_obj=logging_obj, + model=model, + ) + else: + cached_result = TextCompletionResponse(**cached_result) + elif ( + call_type == CallTypes.aembedding.value + or call_type == CallTypes.embedding.value + ) and isinstance(cached_result, dict): + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=EmbeddingResponse(), + response_type="embedding", + ) + + elif ( + call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value + ) and isinstance(cached_result, dict): + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=None, + response_type="rerank", + ) + elif ( + call_type == CallTypes.atranscription.value + or call_type == CallTypes.transcription.value + ) and isinstance(cached_result, dict): + hidden_params = { + "model": "whisper-1", + "custom_llm_provider": custom_llm_provider, + "cache_hit": True, + } + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=TranscriptionResponse(), + response_type="audio_transcription", + hidden_params=hidden_params, + ) + + if ( + hasattr(cached_result, "_hidden_params") + and cached_result._hidden_params is not None + and isinstance(cached_result._hidden_params, dict) + ): + cached_result._hidden_params["cache_hit"] = True + return cached_result + + def _convert_cached_stream_response( + self, + cached_result: Any, + call_type: str, + logging_obj: LiteLLMLoggingObj, + model: str, + ) -> CustomStreamWrapper: + from litellm.utils import ( + CustomStreamWrapper, + convert_to_streaming_response, + convert_to_streaming_response_async, + ) 
+ + _stream_cached_result: Union[AsyncGenerator, Generator] + if ( + call_type == CallTypes.acompletion.value + or call_type == CallTypes.atext_completion.value + ): + _stream_cached_result = convert_to_streaming_response_async( + response_object=cached_result, + ) + else: + _stream_cached_result = convert_to_streaming_response( + response_object=cached_result, + ) + return CustomStreamWrapper( + completion_stream=_stream_cached_result, + model=model, + custom_llm_provider="cached_response", + logging_obj=logging_obj, + ) + + async def async_set_cache( + self, + result: Any, + original_function: Callable, + kwargs: Dict[str, Any], + args: Optional[Tuple[Any, ...]] = None, + ): + """ + Internal method to check the type of the result & cache used and adds the result to the cache accordingly + + Args: + result: Any: + original_function: Callable: + kwargs: Dict[str, Any]: + args: Optional[Tuple[Any, ...]] = None: + + Returns: + None + Raises: + None + """ + if litellm.cache is None: + return + + new_kwargs = kwargs.copy() + new_kwargs.update( + convert_args_to_kwargs( + original_function, + args, + ) + ) + # [OPTIONAL] ADD TO CACHE + if self._should_store_result_in_cache( + original_function=original_function, kwargs=new_kwargs + ): + if ( + isinstance(result, litellm.ModelResponse) + or isinstance(result, litellm.EmbeddingResponse) + or isinstance(result, TranscriptionResponse) + or isinstance(result, RerankResponse) + ): + if ( + isinstance(result, EmbeddingResponse) + and isinstance(new_kwargs["input"], list) + and litellm.cache is not None + and not isinstance( + litellm.cache.cache, S3Cache + ) # s3 doesn't support bulk writing. Exclude. + ): + asyncio.create_task( + litellm.cache.async_add_cache_pipeline(result, **new_kwargs) + ) + elif isinstance(litellm.cache.cache, S3Cache): + threading.Thread( + target=litellm.cache.add_cache, + args=(result,), + kwargs=new_kwargs, + ).start() + else: + asyncio.create_task( + litellm.cache.async_add_cache( + result.model_dump_json(), **new_kwargs + ) + ) + else: + asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs)) + + def sync_set_cache( + self, + result: Any, + kwargs: Dict[str, Any], + args: Optional[Tuple[Any, ...]] = None, + ): + """ + Sync internal method to add the result to the cache + """ + + new_kwargs = kwargs.copy() + new_kwargs.update( + convert_args_to_kwargs( + self.original_function, + args, + ) + ) + if litellm.cache is None: + return + + if self._should_store_result_in_cache( + original_function=self.original_function, kwargs=new_kwargs + ): + + litellm.cache.add_cache(result, **new_kwargs) + + return + + def _should_store_result_in_cache( + self, original_function: Callable, kwargs: Dict[str, Any] + ) -> bool: + """ + Helper function to determine if the result should be stored in the cache. + + Returns: + bool: True if the result should be stored in the cache, False otherwise. + """ + return ( + (litellm.cache is not None) + and litellm.cache.supported_call_types is not None + and (str(original_function.__name__) in litellm.cache.supported_call_types) + and (kwargs.get("cache", {}).get("no-store", False) is not True) + ) + + def _is_call_type_supported_by_cache( + self, + original_function: Callable, + ) -> bool: + """ + Helper function to determine if the call type is supported by the cache. + + call types are acompletion, aembedding, atext_completion, atranscription, arerank + + Defined on `litellm.types.utils.CallTypes` + + Returns: + bool: True if the call type is supported by the cache, False otherwise. 
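`_should_store_result_in_cache` above honors a per-request `no-store` flag (the earlier read path similarly honors `no-cache`); a hedged sketch of opting a single call out of cache writes, with placeholders as before.

```python
# Read from the cache if possible, but never persist this response.
import litellm

litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Do not store this answer"}],
    cache={"no-store": True},
)
```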
+ """ + if ( + litellm.cache is not None + and litellm.cache.supported_call_types is not None + and str(original_function.__name__) in litellm.cache.supported_call_types + ): + return True + return False + + async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse): + """ + Internal method to add the streaming response to the cache + + + - If 'streaming_chunk' has a 'finish_reason' then assemble a litellm.ModelResponse object + - Else append the chunk to self.async_streaming_chunks + + """ + + complete_streaming_response: Optional[ + Union[ModelResponse, TextCompletionResponse] + ] = _assemble_complete_response_from_streaming_chunks( + result=processed_chunk, + start_time=self.start_time, + end_time=datetime.datetime.now(), + request_kwargs=self.request_kwargs, + streaming_chunks=self.async_streaming_chunks, + is_async=True, + ) + # if a complete_streaming_response is assembled, add it to the cache + if complete_streaming_response is not None: + await self.async_set_cache( + result=complete_streaming_response, + original_function=self.original_function, + kwargs=self.request_kwargs, + ) + + def _sync_add_streaming_response_to_cache(self, processed_chunk: ModelResponse): + """ + Sync internal method to add the streaming response to the cache + """ + complete_streaming_response: Optional[ + Union[ModelResponse, TextCompletionResponse] + ] = _assemble_complete_response_from_streaming_chunks( + result=processed_chunk, + start_time=self.start_time, + end_time=datetime.datetime.now(), + request_kwargs=self.request_kwargs, + streaming_chunks=self.sync_streaming_chunks, + is_async=False, + ) + + # if a complete_streaming_response is assembled, add it to the cache + if complete_streaming_response is not None: + self.sync_set_cache( + result=complete_streaming_response, + kwargs=self.request_kwargs, + ) + + def _update_litellm_logging_obj_environment( + self, + logging_obj: LiteLLMLoggingObj, + model: str, + kwargs: Dict[str, Any], + cached_result: Any, + is_async: bool, + is_embedding: bool = False, + ): + """ + Helper function to update the LiteLLMLoggingObj environment variables. + + Args: + logging_obj (LiteLLMLoggingObj): The logging object to update. + model (str): The model being used. + kwargs (Dict[str, Any]): The keyword arguments from the original function call. + cached_result (Any): The cached result to log. + is_async (bool): Whether the call is asynchronous or not. + is_embedding (bool): Whether the call is for embeddings or not. 
+ + Returns: + None + """ + litellm_params = { + "logger_fn": kwargs.get("logger_fn", None), + "acompletion": is_async, + "api_base": kwargs.get("api_base", ""), + "metadata": kwargs.get("metadata", {}), + "model_info": kwargs.get("model_info", {}), + "proxy_server_request": kwargs.get("proxy_server_request", None), + "stream_response": kwargs.get("stream_response", {}), + } + + if litellm.cache is not None: + litellm_params["preset_cache_key"] = ( + litellm.cache._get_preset_cache_key_from_kwargs(**kwargs) + ) + else: + litellm_params["preset_cache_key"] = None + + logging_obj.update_environment_variables( + model=model, + user=kwargs.get("user", None), + optional_params={}, + litellm_params=litellm_params, + input=( + kwargs.get("messages", "") + if not is_embedding + else kwargs.get("input", "") + ), + api_key=kwargs.get("api_key", None), + original_response=str(cached_result), + additional_args=None, + stream=kwargs.get("stream", False), + ) + + +def convert_args_to_kwargs( + original_function: Callable, + args: Optional[Tuple[Any, ...]] = None, +) -> Dict[str, Any]: + # Get the signature of the original function + signature = inspect.signature(original_function) + + # Get parameter names in the order they appear in the original function + param_names = list(signature.parameters.keys()) + + # Create a mapping of positional arguments to parameter names + args_to_kwargs = {} + if args: + for index, arg in enumerate(args): + if index < len(param_names): + param_name = param_names[index] + args_to_kwargs[param_name] = arg + + return args_to_kwargs diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py new file mode 100644 index 00000000..abf3203f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py @@ -0,0 +1,88 @@ +import json +from typing import TYPE_CHECKING, Any, Optional + +from .base_cache import BaseCache + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +class DiskCache(BaseCache): + def __init__(self, disk_cache_dir: Optional[str] = None): + import diskcache as dc + + # if users don't provider one, use the default litellm cache + if disk_cache_dir is None: + self.disk_cache = dc.Cache(".litellm_cache") + else: + self.disk_cache = dc.Cache(disk_cache_dir) + + def set_cache(self, key, value, **kwargs): + if "ttl" in kwargs: + self.disk_cache.set(key, value, expire=kwargs["ttl"]) + else: + self.disk_cache.set(key, value) + + async def async_set_cache(self, key, value, **kwargs): + self.set_cache(key=key, value=value, **kwargs) + + async def async_set_cache_pipeline(self, cache_list, **kwargs): + for cache_key, cache_value in cache_list: + if "ttl" in kwargs: + self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"]) + else: + self.set_cache(key=cache_key, value=cache_value) + + def get_cache(self, key, **kwargs): + original_cached_response = self.disk_cache.get(key) + if original_cached_response: + try: + cached_response = json.loads(original_cached_response) # type: ignore + except Exception: + cached_response = original_cached_response + return cached_response + return None + + def batch_get_cache(self, keys: list, **kwargs): + return_val = [] + for k in keys: + val = self.get_cache(key=k, **kwargs) + return_val.append(val) + return return_val + + def increment_cache(self, key, value: int, **kwargs) -> int: + # get the value + init_value = self.get_cache(key=key) or 0 + value = init_value + 
value # type: ignore + self.set_cache(key, value, **kwargs) + return value + + async def async_get_cache(self, key, **kwargs): + return self.get_cache(key=key, **kwargs) + + async def async_batch_get_cache(self, keys: list, **kwargs): + return_val = [] + for k in keys: + val = self.get_cache(key=k, **kwargs) + return_val.append(val) + return return_val + + async def async_increment(self, key, value: int, **kwargs) -> int: + # get the value + init_value = await self.async_get_cache(key=key) or 0 + value = init_value + value # type: ignore + await self.async_set_cache(key, value, **kwargs) + return value + + def flush_cache(self): + self.disk_cache.clear() + + async def disconnect(self): + pass + + def delete_cache(self, key): + self.disk_cache.pop(key) diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/dual_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/dual_cache.py new file mode 100644 index 00000000..5f598f7d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/dual_cache.py @@ -0,0 +1,434 @@ +""" +Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously. + +Has 4 primary methods: + - set_cache + - get_cache + - async_set_cache + - async_get_cache +""" + +import asyncio +import time +import traceback +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Any, List, Optional + +import litellm +from litellm._logging import print_verbose, verbose_logger + +from .base_cache import BaseCache +from .in_memory_cache import InMemoryCache +from .redis_cache import RedisCache + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + +from collections import OrderedDict + + +class LimitedSizeOrderedDict(OrderedDict): + def __init__(self, *args, max_size=100, **kwargs): + super().__init__(*args, **kwargs) + self.max_size = max_size + + def __setitem__(self, key, value): + # If inserting a new key exceeds max size, remove the oldest item + if len(self) >= self.max_size: + self.popitem(last=False) + super().__setitem__(key, value) + + +class DualCache(BaseCache): + """ + DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously. + When data is updated or inserted, it is written to both the in-memory cache + Redis. + This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data. 
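    A minimal usage sketch (assumes a reachable local Redis at localhost:6379;
    key names and TTL values are illustrative only):

        >>> from litellm.caching.in_memory_cache import InMemoryCache
        >>> from litellm.caching.redis_cache import RedisCache
        >>> from litellm.caching.dual_cache import DualCache
        >>> cache = DualCache(
        ...     in_memory_cache=InMemoryCache(),
        ...     redis_cache=RedisCache(host="localhost", port=6379),
        ...     default_in_memory_ttl=60,
        ...     default_redis_ttl=600,
        ... )
        >>> cache.set_cache("router:gpt-4:tpm", 100)   # written to both layers
        >>> cache.get_cache("router:gpt-4:tpm")        # served from memory when possible
        100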
+ """ + + def __init__( + self, + in_memory_cache: Optional[InMemoryCache] = None, + redis_cache: Optional[RedisCache] = None, + default_in_memory_ttl: Optional[float] = None, + default_redis_ttl: Optional[float] = None, + default_redis_batch_cache_expiry: Optional[float] = None, + default_max_redis_batch_cache_size: int = 100, + ) -> None: + super().__init__() + # If in_memory_cache is not provided, use the default InMemoryCache + self.in_memory_cache = in_memory_cache or InMemoryCache() + # If redis_cache is not provided, use the default RedisCache + self.redis_cache = redis_cache + self.last_redis_batch_access_time = LimitedSizeOrderedDict( + max_size=default_max_redis_batch_cache_size + ) + self.redis_batch_cache_expiry = ( + default_redis_batch_cache_expiry + or litellm.default_redis_batch_cache_expiry + or 10 + ) + self.default_in_memory_ttl = ( + default_in_memory_ttl or litellm.default_in_memory_ttl + ) + self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl + + def update_cache_ttl( + self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float] + ): + if default_in_memory_ttl is not None: + self.default_in_memory_ttl = default_in_memory_ttl + + if default_redis_ttl is not None: + self.default_redis_ttl = default_redis_ttl + + def set_cache(self, key, value, local_only: bool = False, **kwargs): + # Update both Redis and in-memory cache + try: + if self.in_memory_cache is not None: + if "ttl" not in kwargs and self.default_in_memory_ttl is not None: + kwargs["ttl"] = self.default_in_memory_ttl + + self.in_memory_cache.set_cache(key, value, **kwargs) + + if self.redis_cache is not None and local_only is False: + self.redis_cache.set_cache(key, value, **kwargs) + except Exception as e: + print_verbose(e) + + def increment_cache( + self, key, value: int, local_only: bool = False, **kwargs + ) -> int: + """ + Key - the key in cache + + Value - int - the value you want to increment by + + Returns - int - the incremented value + """ + try: + result: int = value + if self.in_memory_cache is not None: + result = self.in_memory_cache.increment_cache(key, value, **kwargs) + + if self.redis_cache is not None and local_only is False: + result = self.redis_cache.increment_cache(key, value, **kwargs) + + return result + except Exception as e: + verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") + raise e + + def get_cache( + self, + key, + parent_otel_span: Optional[Span] = None, + local_only: bool = False, + **kwargs, + ): + # Try to fetch from in-memory cache first + try: + result = None + if self.in_memory_cache is not None: + in_memory_result = self.in_memory_cache.get_cache(key, **kwargs) + + if in_memory_result is not None: + result = in_memory_result + + if result is None and self.redis_cache is not None and local_only is False: + # If not found in in-memory cache, try fetching from Redis + redis_result = self.redis_cache.get_cache( + key, parent_otel_span=parent_otel_span + ) + + if redis_result is not None: + # Update in-memory cache with the value from Redis + self.in_memory_cache.set_cache(key, redis_result, **kwargs) + + result = redis_result + + print_verbose(f"get cache: cache result: {result}") + return result + except Exception: + verbose_logger.error(traceback.format_exc()) + + def batch_get_cache( + self, + keys: list, + parent_otel_span: Optional[Span] = None, + local_only: bool = False, + **kwargs, + ): + received_args = locals() + received_args.pop("self") + + def run_in_new_loop(): + """Run the coroutine in a new 
event loop within this thread.""" + new_loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(new_loop) + return new_loop.run_until_complete( + self.async_batch_get_cache(**received_args) + ) + finally: + new_loop.close() + asyncio.set_event_loop(None) + + try: + # First, try to get the current event loop + _ = asyncio.get_running_loop() + # If we're already in an event loop, run in a separate thread + # to avoid nested event loop issues + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_in_new_loop) + return future.result() + + except RuntimeError: + # No running event loop, we can safely run in this thread + return run_in_new_loop() + + async def async_get_cache( + self, + key, + parent_otel_span: Optional[Span] = None, + local_only: bool = False, + **kwargs, + ): + # Try to fetch from in-memory cache first + try: + print_verbose( + f"async get cache: cache key: {key}; local_only: {local_only}" + ) + result = None + if self.in_memory_cache is not None: + in_memory_result = await self.in_memory_cache.async_get_cache( + key, **kwargs + ) + + print_verbose(f"in_memory_result: {in_memory_result}") + if in_memory_result is not None: + result = in_memory_result + + if result is None and self.redis_cache is not None and local_only is False: + # If not found in in-memory cache, try fetching from Redis + redis_result = await self.redis_cache.async_get_cache( + key, parent_otel_span=parent_otel_span + ) + + if redis_result is not None: + # Update in-memory cache with the value from Redis + await self.in_memory_cache.async_set_cache( + key, redis_result, **kwargs + ) + + result = redis_result + + print_verbose(f"get cache: cache result: {result}") + return result + except Exception: + verbose_logger.error(traceback.format_exc()) + + def get_redis_batch_keys( + self, + current_time: float, + keys: List[str], + result: List[Any], + ) -> List[str]: + sublist_keys = [] + for key, value in zip(keys, result): + if value is None: + if ( + key not in self.last_redis_batch_access_time + or current_time - self.last_redis_batch_access_time[key] + >= self.redis_batch_cache_expiry + ): + sublist_keys.append(key) + return sublist_keys + + async def async_batch_get_cache( + self, + keys: list, + parent_otel_span: Optional[Span] = None, + local_only: bool = False, + **kwargs, + ): + try: + result = [None for _ in range(len(keys))] + if self.in_memory_cache is not None: + in_memory_result = await self.in_memory_cache.async_batch_get_cache( + keys, **kwargs + ) + + if in_memory_result is not None: + result = in_memory_result + + if None in result and self.redis_cache is not None and local_only is False: + """ + - for the none values in the result + - check the redis cache + """ + current_time = time.time() + sublist_keys = self.get_redis_batch_keys(current_time, keys, result) + + # Only hit Redis if the last access time was more than 5 seconds ago + if len(sublist_keys) > 0: + # If not found in in-memory cache, try fetching from Redis + redis_result = await self.redis_cache.async_batch_get_cache( + sublist_keys, parent_otel_span=parent_otel_span + ) + + if redis_result is not None: + # Update in-memory cache with the value from Redis + for key, value in redis_result.items(): + if value is not None: + await self.in_memory_cache.async_set_cache( + key, redis_result[key], **kwargs + ) + # Update the last access time for each key fetched from Redis + self.last_redis_batch_access_time[key] = current_time + + for key, value in redis_result.items(): + index = keys.index(key) + 
result[index] = value + + return result + except Exception: + verbose_logger.error(traceback.format_exc()) + + async def async_set_cache(self, key, value, local_only: bool = False, **kwargs): + print_verbose( + f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}" + ) + try: + if self.in_memory_cache is not None: + await self.in_memory_cache.async_set_cache(key, value, **kwargs) + + if self.redis_cache is not None and local_only is False: + await self.redis_cache.async_set_cache(key, value, **kwargs) + except Exception as e: + verbose_logger.exception( + f"LiteLLM Cache: Excepton async add_cache: {str(e)}" + ) + + # async_batch_set_cache + async def async_set_cache_pipeline( + self, cache_list: list, local_only: bool = False, **kwargs + ): + """ + Batch write values to the cache + """ + print_verbose( + f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}" + ) + try: + if self.in_memory_cache is not None: + await self.in_memory_cache.async_set_cache_pipeline( + cache_list=cache_list, **kwargs + ) + + if self.redis_cache is not None and local_only is False: + await self.redis_cache.async_set_cache_pipeline( + cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs + ) + except Exception as e: + verbose_logger.exception( + f"LiteLLM Cache: Excepton async add_cache: {str(e)}" + ) + + async def async_increment_cache( + self, + key, + value: float, + parent_otel_span: Optional[Span] = None, + local_only: bool = False, + **kwargs, + ) -> float: + """ + Key - the key in cache + + Value - float - the value you want to increment by + + Returns - float - the incremented value + """ + try: + result: float = value + if self.in_memory_cache is not None: + result = await self.in_memory_cache.async_increment( + key, value, **kwargs + ) + + if self.redis_cache is not None and local_only is False: + result = await self.redis_cache.async_increment( + key, + value, + parent_otel_span=parent_otel_span, + ttl=kwargs.get("ttl", None), + ) + + return result + except Exception as e: + raise e # don't log if exception is raised + + async def async_set_cache_sadd( + self, key, value: List, local_only: bool = False, **kwargs + ) -> None: + """ + Add value to a set + + Key - the key in cache + + Value - str - the value you want to add to the set + + Returns - None + """ + try: + if self.in_memory_cache is not None: + _ = await self.in_memory_cache.async_set_cache_sadd( + key, value, ttl=kwargs.get("ttl", None) + ) + + if self.redis_cache is not None and local_only is False: + _ = await self.redis_cache.async_set_cache_sadd( + key, value, ttl=kwargs.get("ttl", None) + ) + + return None + except Exception as e: + raise e # don't log, if exception is raised + + def flush_cache(self): + if self.in_memory_cache is not None: + self.in_memory_cache.flush_cache() + if self.redis_cache is not None: + self.redis_cache.flush_cache() + + def delete_cache(self, key): + """ + Delete a key from the cache + """ + if self.in_memory_cache is not None: + self.in_memory_cache.delete_cache(key) + if self.redis_cache is not None: + self.redis_cache.delete_cache(key) + + async def async_delete_cache(self, key: str): + """ + Delete a key from the cache + """ + if self.in_memory_cache is not None: + self.in_memory_cache.delete_cache(key) + if self.redis_cache is not None: + await self.redis_cache.async_delete_cache(key) + + async def async_get_ttl(self, key: str) -> Optional[int]: + """ + Get the remaining TTL of a key in in-memory cache or redis + """ + ttl = await 
self.in_memory_cache.async_get_ttl(key) + if ttl is None and self.redis_cache is not None: + ttl = await self.redis_cache.async_get_ttl(key) + return ttl diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/in_memory_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/in_memory_cache.py new file mode 100644 index 00000000..5e09fe84 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/in_memory_cache.py @@ -0,0 +1,202 @@ +""" +In-Memory Cache implementation + +Has 4 methods: + - set_cache + - get_cache + - async_set_cache + - async_get_cache +""" + +import json +import sys +import time +from typing import Any, List, Optional + +from pydantic import BaseModel + +from ..constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB +from .base_cache import BaseCache + + +class InMemoryCache(BaseCache): + def __init__( + self, + max_size_in_memory: Optional[int] = 200, + default_ttl: Optional[ + int + ] = 600, # default ttl is 10 minutes. At maximum litellm rate limiting logic requires objects to be in memory for 1 minute + max_size_per_item: Optional[int] = 1024, # 1MB = 1024KB + ): + """ + max_size_in_memory [int]: Maximum number of items in cache. done to prevent memory leaks. Use 200 items as a default + """ + self.max_size_in_memory = ( + max_size_in_memory or 200 + ) # set an upper bound of 200 items in-memory + self.default_ttl = default_ttl or 600 + self.max_size_per_item = ( + max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB + ) # 1MB = 1024KB + + # in-memory cache + self.cache_dict: dict = {} + self.ttl_dict: dict = {} + + def check_value_size(self, value: Any): + """ + Check if value size exceeds max_size_per_item (1MB) + Returns True if value size is acceptable, False otherwise + """ + try: + # Fast path for common primitive types that are typically small + if ( + isinstance(value, (bool, int, float, str)) + and len(str(value)) < self.max_size_per_item * 512 + ): # Conservative estimate + return True + + # Direct size check for bytes objects + if isinstance(value, bytes): + return sys.getsizeof(value) / 1024 <= self.max_size_per_item + + # Handle special types without full conversion when possible + if hasattr(value, "__sizeof__"): # Use __sizeof__ if available + size = value.__sizeof__() / 1024 + return size <= self.max_size_per_item + + # Fallback for complex types + if isinstance(value, BaseModel) and hasattr( + value, "model_dump" + ): # Pydantic v2 + value = value.model_dump() + elif hasattr(value, "isoformat"): # datetime objects + return True # datetime strings are always small + + # Only convert to JSON if absolutely necessary + if not isinstance(value, (str, bytes)): + value = json.dumps(value, default=str) + + return sys.getsizeof(value) / 1024 <= self.max_size_per_item + + except Exception: + return False + + def evict_cache(self): + """ + Eviction policy: + - check if any items in ttl_dict are expired -> remove them from ttl_dict and cache_dict + + + This guarantees the following: + - 1. When item ttl not set: At minimumm each item will remain in memory for 5 minutes + - 2. When ttl is set: the item will remain in memory for at least that amount of time + - 3. 
the size of in-memory cache is bounded + + """ + for key in list(self.ttl_dict.keys()): + if time.time() > self.ttl_dict[key]: + self.cache_dict.pop(key, None) + self.ttl_dict.pop(key, None) + + # de-reference the removed item + # https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/ + # One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used. + # This can occur when an object is referenced by another object, but the reference is never removed. + + def set_cache(self, key, value, **kwargs): + if len(self.cache_dict) >= self.max_size_in_memory: + # only evict when cache is full + self.evict_cache() + if not self.check_value_size(value): + return + + self.cache_dict[key] = value + if "ttl" in kwargs and kwargs["ttl"] is not None: + self.ttl_dict[key] = time.time() + kwargs["ttl"] + else: + self.ttl_dict[key] = time.time() + self.default_ttl + + async def async_set_cache(self, key, value, **kwargs): + self.set_cache(key=key, value=value, **kwargs) + + async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs): + for cache_key, cache_value in cache_list: + if ttl is not None: + self.set_cache(key=cache_key, value=cache_value, ttl=ttl) + else: + self.set_cache(key=cache_key, value=cache_value) + + async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]): + """ + Add value to set + """ + # get the value + init_value = self.get_cache(key=key) or set() + for val in value: + init_value.add(val) + self.set_cache(key, init_value, ttl=ttl) + return value + + def get_cache(self, key, **kwargs): + if key in self.cache_dict: + if key in self.ttl_dict: + if time.time() > self.ttl_dict[key]: + self.cache_dict.pop(key, None) + return None + original_cached_response = self.cache_dict[key] + try: + cached_response = json.loads(original_cached_response) + except Exception: + cached_response = original_cached_response + return cached_response + return None + + def batch_get_cache(self, keys: list, **kwargs): + return_val = [] + for k in keys: + val = self.get_cache(key=k, **kwargs) + return_val.append(val) + return return_val + + def increment_cache(self, key, value: int, **kwargs) -> int: + # get the value + init_value = self.get_cache(key=key) or 0 + value = init_value + value + self.set_cache(key, value, **kwargs) + return value + + async def async_get_cache(self, key, **kwargs): + return self.get_cache(key=key, **kwargs) + + async def async_batch_get_cache(self, keys: list, **kwargs): + return_val = [] + for k in keys: + val = self.get_cache(key=k, **kwargs) + return_val.append(val) + return return_val + + async def async_increment(self, key, value: float, **kwargs) -> float: + # get the value + init_value = await self.async_get_cache(key=key) or 0 + value = init_value + value + await self.async_set_cache(key, value, **kwargs) + + return value + + def flush_cache(self): + self.cache_dict.clear() + self.ttl_dict.clear() + + async def disconnect(self): + pass + + def delete_cache(self, key): + self.cache_dict.pop(key, None) + self.ttl_dict.pop(key, None) + + async def async_get_ttl(self, key: str) -> Optional[int]: + """ + Get the remaining TTL of a key in in-memory cache + """ + return self.ttl_dict.get(key, None) diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/llm_caching_handler.py b/.venv/lib/python3.12/site-packages/litellm/caching/llm_caching_handler.py new file mode 100644 index 00000000..429634b7 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/litellm/caching/llm_caching_handler.py @@ -0,0 +1,40 @@ +""" +Add the event loop to the cache key, to prevent event loop closed errors. +""" + +import asyncio + +from .in_memory_cache import InMemoryCache + + +class LLMClientCache(InMemoryCache): + + def update_cache_key_with_event_loop(self, key): + """ + Add the event loop to the cache key, to prevent event loop closed errors. + If none, use the key as is. + """ + try: + event_loop = asyncio.get_event_loop() + stringified_event_loop = str(id(event_loop)) + return f"{key}-{stringified_event_loop}" + except Exception: # handle no current event loop + return key + + def set_cache(self, key, value, **kwargs): + key = self.update_cache_key_with_event_loop(key) + return super().set_cache(key, value, **kwargs) + + async def async_set_cache(self, key, value, **kwargs): + key = self.update_cache_key_with_event_loop(key) + return await super().async_set_cache(key, value, **kwargs) + + def get_cache(self, key, **kwargs): + key = self.update_cache_key_with_event_loop(key) + + return super().get_cache(key, **kwargs) + + async def async_get_cache(self, key, **kwargs): + key = self.update_cache_key_with_event_loop(key) + + return await super().async_get_cache(key, **kwargs) diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/qdrant_semantic_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/qdrant_semantic_cache.py new file mode 100644 index 00000000..bdfd3770 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/qdrant_semantic_cache.py @@ -0,0 +1,430 @@ +""" +Qdrant Semantic Cache implementation + +Has 4 methods: + - set_cache + - get_cache + - async_set_cache + - async_get_cache +""" + +import ast +import asyncio +import json +from typing import Any + +import litellm +from litellm._logging import print_verbose + +from .base_cache import BaseCache + + +class QdrantSemanticCache(BaseCache): + def __init__( # noqa: PLR0915 + self, + qdrant_api_base=None, + qdrant_api_key=None, + collection_name=None, + similarity_threshold=None, + quantization_config=None, + embedding_model="text-embedding-ada-002", + host_type=None, + ): + import os + + from litellm.llms.custom_httpx.http_handler import ( + _get_httpx_client, + get_async_httpx_client, + httpxSpecialProvider, + ) + from litellm.secret_managers.main import get_secret_str + + if collection_name is None: + raise Exception("collection_name must be provided, passed None") + + self.collection_name = collection_name + print_verbose( + f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}" + ) + + if similarity_threshold is None: + raise Exception("similarity_threshold must be provided, passed None") + self.similarity_threshold = similarity_threshold + self.embedding_model = embedding_model + headers = {} + + # check if defined as os.environ/ variable + if qdrant_api_base: + if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith( + "os.environ/" + ): + qdrant_api_base = get_secret_str(qdrant_api_base) + if qdrant_api_key: + if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith( + "os.environ/" + ): + qdrant_api_key = get_secret_str(qdrant_api_key) + + qdrant_api_base = ( + qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE") + ) + qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY") + headers = {"Content-Type": "application/json"} + if qdrant_api_key: + headers["api-key"] = qdrant_api_key + + if qdrant_api_base is None: + raise ValueError("Qdrant url must be 
provided") + + self.qdrant_api_base = qdrant_api_base + self.qdrant_api_key = qdrant_api_key + print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}") + + self.headers = headers + + self.sync_client = _get_httpx_client() + self.async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.Caching + ) + + if quantization_config is None: + print_verbose( + "Quantization config is not provided. Default binary quantization will be used." + ) + collection_exists = self.sync_client.get( + url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists", + headers=self.headers, + ) + if collection_exists.status_code != 200: + raise ValueError( + f"Error from qdrant checking if /collections exist {collection_exists.text}" + ) + + if collection_exists.json()["result"]["exists"]: + collection_details = self.sync_client.get( + url=f"{self.qdrant_api_base}/collections/{self.collection_name}", + headers=self.headers, + ) + self.collection_info = collection_details.json() + print_verbose( + f"Collection already exists.\nCollection details:{self.collection_info}" + ) + else: + if quantization_config is None or quantization_config == "binary": + quantization_params = { + "binary": { + "always_ram": False, + } + } + elif quantization_config == "scalar": + quantization_params = { + "scalar": {"type": "int8", "quantile": 0.99, "always_ram": False} + } + elif quantization_config == "product": + quantization_params = { + "product": {"compression": "x16", "always_ram": False} + } + else: + raise Exception( + "Quantization config must be one of 'scalar', 'binary' or 'product'" + ) + + new_collection_status = self.sync_client.put( + url=f"{self.qdrant_api_base}/collections/{self.collection_name}", + json={ + "vectors": {"size": 1536, "distance": "Cosine"}, + "quantization_config": quantization_params, + }, + headers=self.headers, + ) + if new_collection_status.json()["result"]: + collection_details = self.sync_client.get( + url=f"{self.qdrant_api_base}/collections/{self.collection_name}", + headers=self.headers, + ) + self.collection_info = collection_details.json() + print_verbose( + f"New collection created.\nCollection details:{self.collection_info}" + ) + else: + raise Exception("Error while creating new collection") + + def _get_cache_logic(self, cached_response: Any): + if cached_response is None: + return cached_response + try: + cached_response = json.loads( + cached_response + ) # Convert string to dictionary + except Exception: + cached_response = ast.literal_eval(cached_response) + return cached_response + + def set_cache(self, key, value, **kwargs): + print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}") + import uuid + + # get the prompt + messages = kwargs["messages"] + prompt = "" + for message in messages: + prompt += message["content"] + + # create an embedding for prompt + embedding_response = litellm.embedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + value = str(value) + assert isinstance(value, str) + + data = { + "points": [ + { + "id": str(uuid.uuid4()), + "vector": embedding, + "payload": { + "text": prompt, + "response": value, + }, + }, + ] + } + self.sync_client.put( + url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points", + headers=self.headers, + json=data, + ) + return + + def get_cache(self, key, **kwargs): + print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: 
{kwargs}") + + # get the messages + messages = kwargs["messages"] + prompt = "" + for message in messages: + prompt += message["content"] + + # convert to embedding + embedding_response = litellm.embedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + data = { + "vector": embedding, + "params": { + "quantization": { + "ignore": False, + "rescore": True, + "oversampling": 3.0, + } + }, + "limit": 1, + "with_payload": True, + } + + search_response = self.sync_client.post( + url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search", + headers=self.headers, + json=data, + ) + results = search_response.json()["result"] + + if results is None: + return None + if isinstance(results, list): + if len(results) == 0: + return None + + similarity = results[0]["score"] + cached_prompt = results[0]["payload"]["text"] + + # check similarity, if more than self.similarity_threshold, return results + print_verbose( + f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" + ) + if similarity >= self.similarity_threshold: + # cache hit ! + cached_value = results[0]["payload"]["response"] + print_verbose( + f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" + ) + return self._get_cache_logic(cached_response=cached_value) + else: + # cache miss ! + return None + pass + + async def async_set_cache(self, key, value, **kwargs): + import uuid + + from litellm.proxy.proxy_server import llm_model_list, llm_router + + print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}") + + # get the prompt + messages = kwargs["messages"] + prompt = "" + for message in messages: + prompt += message["content"] + # create an embedding for prompt + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if llm_router is not None and self.embedding_model in router_model_names: + user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") + embedding_response = await llm_router.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + metadata={ + "user_api_key": user_api_key, + "semantic-cache-embedding": True, + "trace_id": kwargs.get("metadata", {}).get("trace_id", None), + }, + ) + else: + # convert to embedding + embedding_response = await litellm.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + value = str(value) + assert isinstance(value, str) + + data = { + "points": [ + { + "id": str(uuid.uuid4()), + "vector": embedding, + "payload": { + "text": prompt, + "response": value, + }, + }, + ] + } + + await self.async_client.put( + url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points", + headers=self.headers, + json=data, + ) + return + + async def async_get_cache(self, key, **kwargs): + print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}") + from litellm.proxy.proxy_server import llm_model_list, llm_router + + # get the messages + messages = kwargs["messages"] + prompt = "" + for message in messages: + prompt += message["content"] + + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is 
not None + else [] + ) + if llm_router is not None and self.embedding_model in router_model_names: + user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") + embedding_response = await llm_router.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + metadata={ + "user_api_key": user_api_key, + "semantic-cache-embedding": True, + "trace_id": kwargs.get("metadata", {}).get("trace_id", None), + }, + ) + else: + # convert to embedding + embedding_response = await litellm.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + data = { + "vector": embedding, + "params": { + "quantization": { + "ignore": False, + "rescore": True, + "oversampling": 3.0, + } + }, + "limit": 1, + "with_payload": True, + } + + search_response = await self.async_client.post( + url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search", + headers=self.headers, + json=data, + ) + + results = search_response.json()["result"] + + if results is None: + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 + return None + if isinstance(results, list): + if len(results) == 0: + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 + return None + + similarity = results[0]["score"] + cached_prompt = results[0]["payload"]["text"] + + # check similarity, if more than self.similarity_threshold, return results + print_verbose( + f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" + ) + + # update kwargs["metadata"] with similarity, don't rewrite the original metadata + kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity + + if similarity >= self.similarity_threshold: + # cache hit ! + cached_value = results[0]["payload"]["response"] + print_verbose( + f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" + ) + return self._get_cache_logic(cached_response=cached_value) + else: + # cache miss ! 
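# the hit/miss rule above reduces to one comparison against the configured
# threshold; a sketch with hypothetical numbers:
#
#     similarity_threshold = 0.8   # value passed to QdrantSemanticCache(...)
#     similarity = 0.74            # Qdrant score of the closest stored prompt
#     similarity >= similarity_threshold   # False -> fall through here and return None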
+ return None + pass + + async def _collection_info(self): + return self.collection_info + + async def async_set_cache_pipeline(self, cache_list, **kwargs): + tasks = [] + for val in cache_list: + tasks.append(self.async_set_cache(val[0], val[1], **kwargs)) + await asyncio.gather(*tasks) diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/redis_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/redis_cache.py new file mode 100644 index 00000000..0571ac9f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/redis_cache.py @@ -0,0 +1,1047 @@ +""" +Redis Cache implementation + +Has 4 primary methods: + - set_cache + - get_cache + - async_set_cache + - async_get_cache +""" + +import ast +import asyncio +import inspect +import json +import time +from datetime import timedelta +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +import litellm +from litellm._logging import print_verbose, verbose_logger +from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs +from litellm.types.caching import RedisPipelineIncrementOperation +from litellm.types.services import ServiceTypes + +from .base_cache import BaseCache + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + from redis.asyncio import Redis, RedisCluster + from redis.asyncio.client import Pipeline + from redis.asyncio.cluster import ClusterPipeline + + pipeline = Pipeline + cluster_pipeline = ClusterPipeline + async_redis_client = Redis + async_redis_cluster_client = RedisCluster + Span = _Span +else: + pipeline = Any + cluster_pipeline = Any + async_redis_client = Any + async_redis_cluster_client = Any + Span = Any + + +class RedisCache(BaseCache): + # if users don't provider one, use the default litellm cache + + def __init__( + self, + host=None, + port=None, + password=None, + redis_flush_size: Optional[int] = 100, + namespace: Optional[str] = None, + startup_nodes: Optional[List] = None, # for redis-cluster + socket_timeout: Optional[float] = 5.0, # default 5 second timeout + **kwargs, + ): + + from litellm._service_logger import ServiceLogging + + from .._redis import get_redis_client, get_redis_connection_pool + + redis_kwargs = {} + if host is not None: + redis_kwargs["host"] = host + if port is not None: + redis_kwargs["port"] = port + if password is not None: + redis_kwargs["password"] = password + if startup_nodes is not None: + redis_kwargs["startup_nodes"] = startup_nodes + if socket_timeout is not None: + redis_kwargs["socket_timeout"] = socket_timeout + + ### HEALTH MONITORING OBJECT ### + if kwargs.get("service_logger_obj", None) is not None and isinstance( + kwargs["service_logger_obj"], ServiceLogging + ): + self.service_logger_obj = kwargs.pop("service_logger_obj") + else: + self.service_logger_obj = ServiceLogging() + + redis_kwargs.update(kwargs) + self.redis_client = get_redis_client(**redis_kwargs) + self.redis_async_client: Optional[async_redis_client] = None + self.redis_kwargs = redis_kwargs + self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs) + + # redis namespaces + self.namespace = namespace + # for high traffic, we store the redis results in memory and then batch write to redis + self.redis_batch_writing_buffer: list = [] + if redis_flush_size is None: + self.redis_flush_size: int = 100 + else: + self.redis_flush_size = redis_flush_size + self.redis_version = "Unknown" + try: + if not inspect.iscoroutinefunction(self.redis_client): + self.redis_version = 
self.redis_client.info()["redis_version"] # type: ignore + except Exception: + pass + + ### ASYNC HEALTH PING ### + try: + # asyncio.get_running_loop().create_task(self.ping()) + _ = asyncio.get_running_loop().create_task(self.ping()) + except Exception as e: + if "no running event loop" in str(e): + verbose_logger.debug( + "Ignoring async redis ping. No running event loop." + ) + else: + verbose_logger.error( + "Error connecting to Async Redis client - {}".format(str(e)), + extra={"error": str(e)}, + ) + + ### SYNC HEALTH PING ### + try: + if hasattr(self.redis_client, "ping"): + self.redis_client.ping() # type: ignore + except Exception as e: + verbose_logger.error( + "Error connecting to Sync Redis client", extra={"error": str(e)} + ) + + if litellm.default_redis_ttl is not None: + super().__init__(default_ttl=int(litellm.default_redis_ttl)) + else: + super().__init__() # defaults to 60s + + def init_async_client( + self, + ) -> Union[async_redis_client, async_redis_cluster_client]: + from .._redis import get_redis_async_client + + if self.redis_async_client is None: + self.redis_async_client = get_redis_async_client( + connection_pool=self.async_redis_conn_pool, **self.redis_kwargs + ) + return self.redis_async_client + + def check_and_fix_namespace(self, key: str) -> str: + """ + Make sure each key starts with the given namespace + """ + if self.namespace is not None and not key.startswith(self.namespace): + key = self.namespace + ":" + key + + return key + + def set_cache(self, key, value, **kwargs): + ttl = self.get_ttl(**kwargs) + print_verbose( + f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}" + ) + key = self.check_and_fix_namespace(key=key) + try: + start_time = time.time() + self.redis_client.set(name=key, value=str(value), ex=ttl) + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="set_cache", + start_time=start_time, + end_time=end_time, + ) + except Exception as e: + # NON blocking - notify users Redis is throwing an exception + print_verbose( + f"litellm.caching.caching: set() - Got exception from REDIS : {str(e)}" + ) + + def increment_cache( + self, key, value: int, ttl: Optional[float] = None, **kwargs + ) -> int: + _redis_client = self.redis_client + start_time = time.time() + set_ttl = self.get_ttl(ttl=ttl) + try: + start_time = time.time() + result: int = _redis_client.incr(name=key, amount=value) # type: ignore + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="increment_cache", + start_time=start_time, + end_time=end_time, + ) + + if set_ttl is not None: + # check if key already has ttl, if not -> set ttl + start_time = time.time() + current_ttl = _redis_client.ttl(key) + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="increment_cache_ttl", + start_time=start_time, + end_time=end_time, + ) + if current_ttl == -1: + # Key has no expiration + start_time = time.time() + _redis_client.expire(key, set_ttl) # type: ignore + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="increment_cache_expire", + start_time=start_time, + end_time=end_time, + ) 
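# the branch above is roughly this redis-cli sequence (key name and ttl are
# hypothetical):
#
#     INCRBY litellm:spend:api-key-123 5   -> returns the new counter value
#     TTL litellm:spend:api-key-123        -> -1 means no expiry is set yet
#     EXPIRE litellm:spend:api-key-123 60  -> only issued when TTL returned -1
#
# so an expiry that is already ticking is never reset by later increments.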
+ return result + except Exception as e: + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + verbose_logger.error( + "LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + raise e + + async def async_scan_iter(self, pattern: str, count: int = 100) -> list: + from redis.asyncio import Redis + + start_time = time.time() + try: + keys = [] + _redis_client: Redis = self.init_async_client() # type: ignore + + async for key in _redis_client.scan_iter(match=pattern + "*", count=count): + keys.append(key) + if len(keys) >= count: + break + + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_scan_iter", + start_time=start_time, + end_time=end_time, + ) + ) # DO NOT SLOW DOWN CALL B/C OF THIS + return keys + except Exception as e: + # NON blocking - notify users Redis is throwing an exception + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="async_scan_iter", + start_time=start_time, + end_time=end_time, + ) + ) + raise e + + async def async_set_cache(self, key, value, **kwargs): + from redis.asyncio import Redis + + start_time = time.time() + try: + _redis_client: Redis = self.init_async_client() # type: ignore + except Exception as e: + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + call_type="async_set_cache", + ) + ) + verbose_logger.error( + "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + raise e + + key = self.check_and_fix_namespace(key=key) + ttl = self.get_ttl(**kwargs) + print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}") + + try: + if not hasattr(_redis_client, "set"): + raise Exception("Redis client cannot set cache. 
Attribute not found.") + await _redis_client.set(name=key, value=json.dumps(value), ex=ttl) + print_verbose( + f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" + ) + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_set_cache", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + event_metadata={"key": key}, + ) + ) + except Exception as e: + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="async_set_cache", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + event_metadata={"key": key}, + ) + ) + verbose_logger.error( + "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + + async def _pipeline_helper( + self, + pipe: Union[pipeline, cluster_pipeline], + cache_list: List[Tuple[Any, Any]], + ttl: Optional[float], + ) -> List: + """ + Helper function for executing a pipeline of set operations on Redis + """ + ttl = self.get_ttl(ttl=ttl) + # Iterate through each key-value pair in the cache_list and set them in the pipeline. + for cache_key, cache_value in cache_list: + cache_key = self.check_and_fix_namespace(key=cache_key) + print_verbose( + f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}" + ) + json_cache_value = json.dumps(cache_value) + # Set the value with a TTL if it's provided. + _td: Optional[timedelta] = None + if ttl is not None: + _td = timedelta(seconds=ttl) + pipe.set( # type: ignore + name=cache_key, + value=json_cache_value, + ex=_td, + ) + # Execute the pipeline and return the results. + results = await pipe.execute() + return results + + async def async_set_cache_pipeline( + self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs + ): + """ + Use Redis Pipelines for bulk write operations + """ + # don't waste a network request if there's nothing to set + if len(cache_list) == 0: + return + + _redis_client = self.init_async_client() + start_time = time.time() + + print_verbose( + f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}" + ) + cache_value: Any = None + try: + async with _redis_client.pipeline(transaction=False) as pipe: + results = await self._pipeline_helper(pipe, cache_list, ttl) + + print_verbose(f"pipeline results: {results}") + # Optionally, you could process 'results' to make sure that all set operations were successful. 
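# caller-side sketch (keys, values and the ttl are hypothetical):
#
#     cache_list = [("chat:abc", response_a), ("chat:def", response_b)]
#     await redis_cache.async_set_cache_pipeline(cache_list, ttl=120)
#
# each (key, value) pair becomes one SET inside a single non-transactional
# pipeline, so the whole batch costs one network round trip.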
+ ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_set_cache_pipeline", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + return None + except Exception as e: + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="async_set_cache_pipeline", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + + verbose_logger.error( + "LiteLLM Redis Caching: async set_cache_pipeline() - Got exception from REDIS %s, Writing value=%s", + str(e), + cache_value, + ) + + async def _set_cache_sadd_helper( + self, + redis_client: async_redis_client, + key: str, + value: List, + ttl: Optional[float], + ) -> None: + """Helper function for async_set_cache_sadd. Separated for testing.""" + ttl = self.get_ttl(ttl=ttl) + try: + await redis_client.sadd(key, *value) # type: ignore + if ttl is not None: + _td = timedelta(seconds=ttl) + await redis_client.expire(key, _td) + except Exception: + raise + + async def async_set_cache_sadd( + self, key, value: List, ttl: Optional[float], **kwargs + ): + from redis.asyncio import Redis + + start_time = time.time() + try: + _redis_client: Redis = self.init_async_client() # type: ignore + except Exception as e: + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + call_type="async_set_cache_sadd", + ) + ) + # NON blocking - notify users Redis is throwing an exception + verbose_logger.error( + "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + raise e + + key = self.check_and_fix_namespace(key=key) + print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}") + try: + await self._set_cache_sadd_helper( + redis_client=_redis_client, key=key, value=value, ttl=ttl + ) + print_verbose( + f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}" + ) + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_set_cache_sadd", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + except Exception as e: + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="async_set_cache_sadd", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + # NON blocking - notify users Redis is throwing an exception + verbose_logger.error( + "LiteLLM Redis Caching: async set_cache_sadd() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + + async def batch_cache_write(self, key, value, **kwargs): + print_verbose( + f"in batch cache writing 
for redis buffer size={len(self.redis_batch_writing_buffer)}", + ) + key = self.check_and_fix_namespace(key=key) + self.redis_batch_writing_buffer.append((key, value)) + if len(self.redis_batch_writing_buffer) >= self.redis_flush_size: + await self.flush_cache_buffer() # logging done in here + + async def async_increment( + self, + key, + value: float, + ttl: Optional[int] = None, + parent_otel_span: Optional[Span] = None, + ) -> float: + from redis.asyncio import Redis + + _redis_client: Redis = self.init_async_client() # type: ignore + start_time = time.time() + _used_ttl = self.get_ttl(ttl=ttl) + key = self.check_and_fix_namespace(key=key) + try: + result = await _redis_client.incrbyfloat(name=key, amount=value) + if _used_ttl is not None: + # check if key already has ttl, if not -> set ttl + current_ttl = await _redis_client.ttl(key) + if current_ttl == -1: + # Key has no expiration + await _redis_client.expire(key, _used_ttl) + + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_increment", + start_time=start_time, + end_time=end_time, + parent_otel_span=parent_otel_span, + ) + ) + return result + except Exception as e: + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="async_increment", + start_time=start_time, + end_time=end_time, + parent_otel_span=parent_otel_span, + ) + ) + verbose_logger.error( + "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + raise e + + async def flush_cache_buffer(self): + print_verbose( + f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}" + ) + await self.async_set_cache_pipeline(self.redis_batch_writing_buffer) + self.redis_batch_writing_buffer = [] + + def _get_cache_logic(self, cached_response: Any): + """ + Common 'get_cache_logic' across sync + async redis client implementations + """ + if cached_response is None: + return cached_response + # cached_response is in `b{} convert it to ModelResponse + cached_response = cached_response.decode("utf-8") # Convert bytes to string + try: + cached_response = json.loads( + cached_response + ) # Convert string to dictionary + except Exception: + cached_response = ast.literal_eval(cached_response) + return cached_response + + def get_cache(self, key, parent_otel_span: Optional[Span] = None, **kwargs): + try: + key = self.check_and_fix_namespace(key=key) + print_verbose(f"Get Redis Cache: key: {key}") + start_time = time.time() + cached_response = self.redis_client.get(key) + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="get_cache", + start_time=start_time, + end_time=end_time, + parent_otel_span=parent_otel_span, + ) + print_verbose( + f"Got Redis Cache: key: {key}, cached_response {cached_response}" + ) + return self._get_cache_logic(cached_response=cached_response) + except Exception as e: + # NON blocking - notify users Redis is throwing an exception + verbose_logger.error( + "litellm.caching.caching: get() - Got exception from REDIS: ", e + ) + + def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]: + """ + Wrapper to 
call `mget` on the redis client + + We use a wrapper so RedisCluster can override this method + """ + return self.redis_client.mget(keys=keys) # type: ignore + + async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]: + """ + Wrapper to call `mget` on the redis client + + We use a wrapper so RedisCluster can override this method + """ + async_redis_client = self.init_async_client() + return await async_redis_client.mget(keys=keys) # type: ignore + + def batch_get_cache( + self, + key_list: Union[List[str], List[Optional[str]]], + parent_otel_span: Optional[Span] = None, + ) -> dict: + """ + Use Redis for bulk read operations + + Args: + key_list: List of keys to get from Redis + parent_otel_span: Optional parent OpenTelemetry span + + Returns: + dict: A dictionary mapping keys to their cached values + """ + key_value_dict = {} + _key_list = [key for key in key_list if key is not None] + + try: + _keys = [] + for cache_key in _key_list: + cache_key = self.check_and_fix_namespace(key=cache_key or "") + _keys.append(cache_key) + start_time = time.time() + results: List = self._run_redis_mget_operation(keys=_keys) + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="batch_get_cache", + start_time=start_time, + end_time=end_time, + parent_otel_span=parent_otel_span, + ) + + # Associate the results back with their keys. + # 'results' is a list of values corresponding to the order of keys in '_key_list'. + key_value_dict = dict(zip(_key_list, results)) + + decoded_results = {} + for k, v in key_value_dict.items(): + if isinstance(k, bytes): + k = k.decode("utf-8") + v = self._get_cache_logic(v) + decoded_results[k] = v + + return decoded_results + except Exception as e: + verbose_logger.error(f"Error occurred in batch get cache - {str(e)}") + return key_value_dict + + async def async_get_cache( + self, key, parent_otel_span: Optional[Span] = None, **kwargs + ): + from redis.asyncio import Redis + + _redis_client: Redis = self.init_async_client() # type: ignore + key = self.check_and_fix_namespace(key=key) + start_time = time.time() + + try: + print_verbose(f"Get Async Redis Cache: key: {key}") + cached_response = await _redis_client.get(key) + print_verbose( + f"Got Async Redis Cache: key: {key}, cached_response {cached_response}" + ) + response = self._get_cache_logic(cached_response=cached_response) + + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_get_cache", + start_time=start_time, + end_time=end_time, + parent_otel_span=parent_otel_span, + event_metadata={"key": key}, + ) + ) + return response + except Exception as e: + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="async_get_cache", + start_time=start_time, + end_time=end_time, + parent_otel_span=parent_otel_span, + event_metadata={"key": key}, + ) + ) + print_verbose( + f"litellm.caching.caching: async get() - Got exception from REDIS: {str(e)}" + ) + + async def async_batch_get_cache( + self, + key_list: Union[List[str], List[Optional[str]]], + parent_otel_span: Optional[Span] = None, + ) -> dict: + """ + Use Redis for bulk read operations + + Args: + key_list: List of keys to 
get from Redis + parent_otel_span: Optional parent OpenTelemetry span + + Returns: + dict: A dictionary mapping keys to their cached values + + `.mget` does not support None keys. This will filter out None keys. + """ + # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget` + key_value_dict = {} + start_time = time.time() + _key_list = [key for key in key_list if key is not None] + try: + _keys = [] + for cache_key in _key_list: + cache_key = self.check_and_fix_namespace(key=cache_key) + _keys.append(cache_key) + results = await self._async_run_redis_mget_operation(keys=_keys) + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_batch_get_cache", + start_time=start_time, + end_time=end_time, + parent_otel_span=parent_otel_span, + ) + ) + + # Associate the results back with their keys. + # 'results' is a list of values corresponding to the order of keys in 'key_list'. + key_value_dict = dict(zip(_key_list, results)) + + decoded_results = {} + for k, v in key_value_dict.items(): + if isinstance(k, bytes): + k = k.decode("utf-8") + v = self._get_cache_logic(v) + decoded_results[k] = v + + return decoded_results + except Exception as e: + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="async_batch_get_cache", + start_time=start_time, + end_time=end_time, + parent_otel_span=parent_otel_span, + ) + ) + verbose_logger.error(f"Error occurred in async batch get cache - {str(e)}") + return key_value_dict + + def sync_ping(self) -> bool: + """ + Tests if the sync redis client is correctly setup. 
+ """ + print_verbose("Pinging Sync Redis Cache") + start_time = time.time() + try: + response: bool = self.redis_client.ping() # type: ignore + print_verbose(f"Redis Cache PING: {response}") + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="sync_ping", + start_time=start_time, + end_time=end_time, + ) + return response + except Exception as e: + # NON blocking - notify users Redis is throwing an exception + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="sync_ping", + ) + verbose_logger.error( + f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}" + ) + raise e + + async def ping(self) -> bool: + # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ping` + _redis_client: Any = self.init_async_client() + start_time = time.time() + print_verbose("Pinging Async Redis Cache") + try: + response = await _redis_client.ping() + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_ping", + ) + ) + return response + except Exception as e: + # NON blocking - notify users Redis is throwing an exception + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="async_ping", + ) + ) + verbose_logger.error( + f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}" + ) + raise e + + async def delete_cache_keys(self, keys): + # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete` + _redis_client: Any = self.init_async_client() + # keys is a list, unpack it so it gets passed as individual elements to delete + await _redis_client.delete(*keys) + + def client_list(self) -> List: + client_list: List = self.redis_client.client_list() # type: ignore + return client_list + + def info(self): + info = self.redis_client.info() + return info + + def flush_cache(self): + self.redis_client.flushall() + + def flushall(self): + self.redis_client.flushall() + + async def disconnect(self): + await self.async_redis_conn_pool.disconnect(inuse_connections=True) + + async def async_delete_cache(self, key: str): + # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete` + _redis_client: Any = self.init_async_client() + # keys is str + await _redis_client.delete(key) + + def delete_cache(self, key): + self.redis_client.delete(key) + + async def _pipeline_increment_helper( + self, + pipe: pipeline, + increment_list: List[RedisPipelineIncrementOperation], + ) -> Optional[List[float]]: + """Helper function for pipeline increment operations""" + # Iterate through each increment operation and add commands to pipeline + for increment_op in increment_list: + cache_key = self.check_and_fix_namespace(key=increment_op["key"]) + print_verbose( + f"Increment ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {increment_op['increment_value']}\nttl={increment_op['ttl']}" + ) + pipe.incrbyfloat(cache_key, increment_op["increment_value"]) + if 
increment_op["ttl"] is not None: + _td = timedelta(seconds=increment_op["ttl"]) + pipe.expire(cache_key, _td) + # Execute the pipeline and return results + results = await pipe.execute() + print_verbose(f"Increment ASYNC Redis Cache PIPELINE: results: {results}") + return results + + async def async_increment_pipeline( + self, increment_list: List[RedisPipelineIncrementOperation], **kwargs + ) -> Optional[List[float]]: + """ + Use Redis Pipelines for bulk increment operations + Args: + increment_list: List of RedisPipelineIncrementOperation dicts containing: + - key: str + - increment_value: float + - ttl_seconds: int + """ + # don't waste a network request if there's nothing to increment + if len(increment_list) == 0: + return None + + from redis.asyncio import Redis + + _redis_client: Redis = self.init_async_client() # type: ignore + start_time = time.time() + + print_verbose( + f"Increment Async Redis Cache Pipeline: increment list: {increment_list}" + ) + + try: + async with _redis_client.pipeline(transaction=False) as pipe: + results = await self._pipeline_increment_helper(pipe, increment_list) + + print_verbose(f"pipeline increment results: {results}") + + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_increment_pipeline", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + return results + except Exception as e: + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="async_increment_pipeline", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + verbose_logger.error( + "LiteLLM Redis Caching: async increment_pipeline() - Got exception from REDIS %s", + str(e), + ) + raise e + + async def async_get_ttl(self, key: str) -> Optional[int]: + """ + Get the remaining TTL of a key in Redis + + Args: + key (str): The key to get TTL for + + Returns: + Optional[int]: The remaining TTL in seconds, or None if key doesn't exist + + Redis ref: https://redis.io/docs/latest/commands/ttl/ + """ + try: + # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ttl` + _redis_client: Any = self.init_async_client() + ttl = await _redis_client.ttl(key) + if ttl <= -1: # -1 means the key does not exist, -2 key does not exist + return None + return ttl + except Exception as e: + verbose_logger.debug(f"Redis TTL Error: {e}") + return None diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/redis_cluster_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/redis_cluster_cache.py new file mode 100644 index 00000000..2e7d1de1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/redis_cluster_cache.py @@ -0,0 +1,59 @@ +""" +Redis Cluster Cache implementation + +Key differences: +- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created +""" + +from typing import TYPE_CHECKING, Any, List, Optional + +from litellm.caching.redis_cache import RedisCache + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + from redis.asyncio import Redis, RedisCluster + from redis.asyncio.client import Pipeline + + pipeline = Pipeline 
+ async_redis_client = Redis + Span = _Span +else: + pipeline = Any + async_redis_client = Any + Span = Any + + +class RedisClusterCache(RedisCache): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.redis_async_redis_cluster_client: Optional[RedisCluster] = None + self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None + + def init_async_client(self): + from redis.asyncio import RedisCluster + + from .._redis import get_redis_async_client + + if self.redis_async_redis_cluster_client: + return self.redis_async_redis_cluster_client + + _redis_client = get_redis_async_client( + connection_pool=self.async_redis_conn_pool, **self.redis_kwargs + ) + if isinstance(_redis_client, RedisCluster): + self.redis_async_redis_cluster_client = _redis_client + + return _redis_client + + def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]: + """ + Overrides `_run_redis_mget_operation` in redis_cache.py + """ + return self.redis_client.mget_nonatomic(keys=keys) # type: ignore + + async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]: + """ + Overrides `_async_run_redis_mget_operation` in redis_cache.py + """ + async_redis_cluster_client = self.init_async_client() + return await async_redis_cluster_client.mget_nonatomic(keys=keys) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py new file mode 100644 index 00000000..b609286a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py @@ -0,0 +1,337 @@ +""" +Redis Semantic Cache implementation + +Has 4 methods: + - set_cache + - get_cache + - async_set_cache + - async_get_cache +""" + +import ast +import asyncio +import json +from typing import Any + +import litellm +from litellm._logging import print_verbose + +from .base_cache import BaseCache + + +class RedisSemanticCache(BaseCache): + def __init__( + self, + host=None, + port=None, + password=None, + redis_url=None, + similarity_threshold=None, + use_async=False, + embedding_model="text-embedding-ada-002", + **kwargs, + ): + from redisvl.index import SearchIndex + + print_verbose( + "redis semantic-cache initializing INDEX - litellm_semantic_cache_index" + ) + if similarity_threshold is None: + raise Exception("similarity_threshold must be provided, passed None") + self.similarity_threshold = similarity_threshold + self.embedding_model = embedding_model + schema = { + "index": { + "name": "litellm_semantic_cache_index", + "prefix": "litellm", + "storage_type": "hash", + }, + "fields": { + "text": [{"name": "response"}], + "vector": [ + { + "name": "litellm_embedding", + "dims": 1536, + "distance_metric": "cosine", + "algorithm": "flat", + "datatype": "float32", + } + ], + }, + } + if redis_url is None: + # if no url passed, check if host, port and password are passed, if not raise an Exception + if host is None or port is None or password is None: + # try checking env for host, port and password + import os + + host = os.getenv("REDIS_HOST") + port = os.getenv("REDIS_PORT") + password = os.getenv("REDIS_PASSWORD") + if host is None or port is None or password is None: + raise Exception("Redis host, port, and password must be provided") + + redis_url = "redis://:" + password + "@" + host + ":" + port + print_verbose(f"redis semantic-cache redis_url: {redis_url}") + if use_async is False: + self.index = SearchIndex.from_dict(schema) + 
self.index.connect(redis_url=redis_url) + try: + self.index.create(overwrite=False) # don't overwrite existing index + except Exception as e: + print_verbose(f"Got exception creating semantic cache index: {str(e)}") + elif use_async is True: + schema["index"]["name"] = "litellm_semantic_cache_index_async" + self.index = SearchIndex.from_dict(schema) + self.index.connect(redis_url=redis_url, use_async=True) + + # + def _get_cache_logic(self, cached_response: Any): + """ + Common 'get_cache_logic' across sync + async redis client implementations + """ + if cached_response is None: + return cached_response + + # check if cached_response is bytes + if isinstance(cached_response, bytes): + cached_response = cached_response.decode("utf-8") + + try: + cached_response = json.loads( + cached_response + ) # Convert string to dictionary + except Exception: + cached_response = ast.literal_eval(cached_response) + return cached_response + + def set_cache(self, key, value, **kwargs): + import numpy as np + + print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}") + + # get the prompt + messages = kwargs["messages"] + prompt = "".join(message["content"] for message in messages) + + # create an embedding for prompt + embedding_response = litellm.embedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + # make the embedding a numpy array, convert to bytes + embedding_bytes = np.array(embedding, dtype=np.float32).tobytes() + value = str(value) + assert isinstance(value, str) + + new_data = [ + {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes} + ] + + # Add more data + self.index.load(new_data) + + return + + def get_cache(self, key, **kwargs): + print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}") + from redisvl.query import VectorQuery + + # query + # get the messages + messages = kwargs["messages"] + prompt = "".join(message["content"] for message in messages) + + # convert to embedding + embedding_response = litellm.embedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + query = VectorQuery( + vector=embedding, + vector_field_name="litellm_embedding", + return_fields=["response", "prompt", "vector_distance"], + num_results=1, + ) + + results = self.index.query(query) + if results is None: + return None + if isinstance(results, list): + if len(results) == 0: + return None + + vector_distance = results[0]["vector_distance"] + vector_distance = float(vector_distance) + similarity = 1 - vector_distance + cached_prompt = results[0]["prompt"] + + # check similarity, if more than self.similarity_threshold, return results + print_verbose( + f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" + ) + if similarity > self.similarity_threshold: + # cache hit ! + cached_value = results[0]["response"] + print_verbose( + f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" + ) + return self._get_cache_logic(cached_response=cached_value) + else: + # cache miss ! 
+ return None + + pass + + async def async_set_cache(self, key, value, **kwargs): + import numpy as np + + from litellm.proxy.proxy_server import llm_model_list, llm_router + + try: + await self.index.acreate(overwrite=False) # don't overwrite existing index + except Exception as e: + print_verbose(f"Got exception creating semantic cache index: {str(e)}") + print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}") + + # get the prompt + messages = kwargs["messages"] + prompt = "".join(message["content"] for message in messages) + # create an embedding for prompt + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if llm_router is not None and self.embedding_model in router_model_names: + user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") + embedding_response = await llm_router.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + metadata={ + "user_api_key": user_api_key, + "semantic-cache-embedding": True, + "trace_id": kwargs.get("metadata", {}).get("trace_id", None), + }, + ) + else: + # convert to embedding + embedding_response = await litellm.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + # make the embedding a numpy array, convert to bytes + embedding_bytes = np.array(embedding, dtype=np.float32).tobytes() + value = str(value) + assert isinstance(value, str) + + new_data = [ + {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes} + ] + + # Add more data + await self.index.aload(new_data) + return + + async def async_get_cache(self, key, **kwargs): + print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}") + from redisvl.query import VectorQuery + + from litellm.proxy.proxy_server import llm_model_list, llm_router + + # query + # get the messages + messages = kwargs["messages"] + prompt = "".join(message["content"] for message in messages) + + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if llm_router is not None and self.embedding_model in router_model_names: + user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") + embedding_response = await llm_router.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + metadata={ + "user_api_key": user_api_key, + "semantic-cache-embedding": True, + "trace_id": kwargs.get("metadata", {}).get("trace_id", None), + }, + ) + else: + # convert to embedding + embedding_response = await litellm.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + query = VectorQuery( + vector=embedding, + vector_field_name="litellm_embedding", + return_fields=["response", "prompt", "vector_distance"], + ) + results = await self.index.aquery(query) + if results is None: + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 + return None + if isinstance(results, list): + if len(results) == 0: + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 + return None + + vector_distance = results[0]["vector_distance"] + vector_distance = float(vector_distance) + similarity = 1 - vector_distance + cached_prompt = results[0]["prompt"] + + # check similarity, if more 
than self.similarity_threshold, return results + print_verbose( + f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" + ) + + # update kwargs["metadata"] with similarity, don't rewrite the original metadata + kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity + + if similarity > self.similarity_threshold: + # cache hit ! + cached_value = results[0]["response"] + print_verbose( + f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" + ) + return self._get_cache_logic(cached_response=cached_value) + else: + # cache miss ! + return None + pass + + async def _index_info(self): + return await self.index.ainfo() + + async def async_set_cache_pipeline(self, cache_list, **kwargs): + tasks = [] + for val in cache_list: + tasks.append(self.async_set_cache(val[0], val[1], **kwargs)) + await asyncio.gather(*tasks) diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/s3_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/s3_cache.py new file mode 100644 index 00000000..301591c6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/s3_cache.py @@ -0,0 +1,159 @@ +""" +S3 Cache implementation +WARNING: DO NOT USE THIS IN PRODUCTION - This is not ASYNC + +Has 4 methods: + - set_cache + - get_cache + - async_set_cache + - async_get_cache +""" + +import ast +import asyncio +import json +from typing import Optional + +from litellm._logging import print_verbose, verbose_logger + +from .base_cache import BaseCache + + +class S3Cache(BaseCache): + def __init__( + self, + s3_bucket_name, + s3_region_name=None, + s3_api_version=None, + s3_use_ssl: Optional[bool] = True, + s3_verify=None, + s3_endpoint_url=None, + s3_aws_access_key_id=None, + s3_aws_secret_access_key=None, + s3_aws_session_token=None, + s3_config=None, + s3_path=None, + **kwargs, + ): + import boto3 + + self.bucket_name = s3_bucket_name + self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else "" + # Create an S3 client with custom endpoint URL + + self.s3_client = boto3.client( + "s3", + region_name=s3_region_name, + endpoint_url=s3_endpoint_url, + api_version=s3_api_version, + use_ssl=s3_use_ssl, + verify=s3_verify, + aws_access_key_id=s3_aws_access_key_id, + aws_secret_access_key=s3_aws_secret_access_key, + aws_session_token=s3_aws_session_token, + config=s3_config, + **kwargs, + ) + + def set_cache(self, key, value, **kwargs): + try: + print_verbose(f"LiteLLM SET Cache - S3. Key={key}. 
Value={value}") + ttl = kwargs.get("ttl", None) + # Convert value to JSON before storing in S3 + serialized_value = json.dumps(value) + key = self.key_prefix + key + + if ttl is not None: + cache_control = f"immutable, max-age={ttl}, s-maxage={ttl}" + import datetime + + # Calculate expiration time + expiration_time = datetime.datetime.now() + ttl + + # Upload the data to S3 with the calculated expiration time + self.s3_client.put_object( + Bucket=self.bucket_name, + Key=key, + Body=serialized_value, + Expires=expiration_time, + CacheControl=cache_control, + ContentType="application/json", + ContentLanguage="en", + ContentDisposition=f'inline; filename="{key}.json"', + ) + else: + cache_control = "immutable, max-age=31536000, s-maxage=31536000" + # Upload the data to S3 without specifying Expires + self.s3_client.put_object( + Bucket=self.bucket_name, + Key=key, + Body=serialized_value, + CacheControl=cache_control, + ContentType="application/json", + ContentLanguage="en", + ContentDisposition=f'inline; filename="{key}.json"', + ) + except Exception as e: + # NON blocking - notify users S3 is throwing an exception + print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}") + + async def async_set_cache(self, key, value, **kwargs): + self.set_cache(key=key, value=value, **kwargs) + + def get_cache(self, key, **kwargs): + import botocore + + try: + key = self.key_prefix + key + + print_verbose(f"Get S3 Cache: key: {key}") + # Download the data from S3 + cached_response = self.s3_client.get_object( + Bucket=self.bucket_name, Key=key + ) + + if cached_response is not None: + # cached_response is in `b{} convert it to ModelResponse + cached_response = ( + cached_response["Body"].read().decode("utf-8") + ) # Convert bytes to string + try: + cached_response = json.loads( + cached_response + ) # Convert string to dictionary + except Exception: + cached_response = ast.literal_eval(cached_response) + if type(cached_response) is not dict: + cached_response = dict(cached_response) + verbose_logger.debug( + f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}" + ) + + return cached_response + except botocore.exceptions.ClientError as e: # type: ignore + if e.response["Error"]["Code"] == "NoSuchKey": + verbose_logger.debug( + f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket." + ) + return None + + except Exception as e: + # NON blocking - notify users S3 is throwing an exception + verbose_logger.error( + f"S3 Caching: get_cache() - Got exception from S3: {e}" + ) + + async def async_get_cache(self, key, **kwargs): + return self.get_cache(key=key, **kwargs) + + def flush_cache(self): + pass + + async def disconnect(self): + pass + + async def async_set_cache_pipeline(self, cache_list, **kwargs): + tasks = [] + for val in cache_list: + tasks.append(self.async_set_cache(val[0], val[1], **kwargs)) + await asyncio.gather(*tasks) |