path: root/.venv/lib/python3.12/site-packages/litellm/caching
author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/caching
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two version of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/caching')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/Readme.md                |   40
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/__init__.py              |    9
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/_internal_lru_cache.py   |   30
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/base_cache.py            |   55
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/caching.py               |  797
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/caching_handler.py       |  909
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py            |   88
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/dual_cache.py            |  434
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/in_memory_cache.py       |  202
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/llm_caching_handler.py   |   40
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/qdrant_semantic_cache.py |  430
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/redis_cache.py           | 1047
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/redis_cluster_cache.py   |   59
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py  |  337
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/s3_cache.py              |  159
15 files changed, 4636 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/Readme.md b/.venv/lib/python3.12/site-packages/litellm/caching/Readme.md
new file mode 100644
index 00000000..6b0210a6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/Readme.md
@@ -0,0 +1,40 @@
+# Caching on LiteLLM
+
+LiteLLM supports multiple caching mechanisms. This allows users to choose the most suitable caching solution for their use case.
+
+The following caching mechanisms are supported:
+
+1. **RedisCache**
+2. **RedisSemanticCache**
+3. **QdrantSemanticCache**
+4. **InMemoryCache**
+5. **DiskCache**
+6. **S3Cache**
+7. **DualCache** (updates both Redis and an in-memory cache simultaneously)
+
+## Folder Structure
+
+```
+litellm/caching/
+├── base_cache.py
+├── caching.py
+├── caching_handler.py
+├── disk_cache.py
+├── dual_cache.py
+├── in_memory_cache.py
+├── qdrant_semantic_cache.py
+├── redis_cache.py
+├── redis_semantic_cache.py
+├── s3_cache.py
+```
+
+## Documentation
+- [Caching on LiteLLM Gateway](https://docs.litellm.ai/docs/proxy/caching)
+- [Caching on LiteLLM Python](https://docs.litellm.ai/docs/caching/all_caches)
+
+
+
+
+
+
+
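As a quick orientation alongside the Readme above, here is a minimal, illustrative sketch of wiring up the in-memory (`local`) cache using the `Cache` class and `LiteLLMCacheType` enum added later in this patch (`caching.py` / `__init__.py`). Treat it as a hedged example rather than canonical documentation; the docs links above cover the supported configurations.

```python
import litellm
from litellm.caching import Cache, LiteLLMCacheType  # re-exported by __init__.py below

# Attach an in-memory cache to litellm; repeated identical completion/embedding
# calls are then served from the cache instead of hitting the provider.
litellm.cache = Cache(type=LiteLLMCacheType.LOCAL, ttl=120)

# response = litellm.completion(
#     model="gpt-3.5-turbo",
#     messages=[{"role": "user", "content": "Hello"}],
# )
```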
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/__init__.py b/.venv/lib/python3.12/site-packages/litellm/caching/__init__.py
new file mode 100644
index 00000000..e10d01ff
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/__init__.py
@@ -0,0 +1,9 @@
+from .caching import Cache, LiteLLMCacheType
+from .disk_cache import DiskCache
+from .dual_cache import DualCache
+from .in_memory_cache import InMemoryCache
+from .qdrant_semantic_cache import QdrantSemanticCache
+from .redis_cache import RedisCache
+from .redis_cluster_cache import RedisClusterCache
+from .redis_semantic_cache import RedisSemanticCache
+from .s3_cache import S3Cache
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/_internal_lru_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/_internal_lru_cache.py
new file mode 100644
index 00000000..54b0fe96
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/_internal_lru_cache.py
@@ -0,0 +1,30 @@
+from functools import lru_cache
+from typing import Callable, Optional, TypeVar
+
+T = TypeVar("T")
+
+
+def lru_cache_wrapper(
+ maxsize: Optional[int] = None,
+) -> Callable[[Callable[..., T]], Callable[..., T]]:
+ """
+ Wrapper around functools.lru_cache that caches both successful results and raised exceptions
+ """
+
+ def decorator(f: Callable[..., T]) -> Callable[..., T]:
+ @lru_cache(maxsize=maxsize)
+ def wrapper(*args, **kwargs):
+ try:
+ return ("success", f(*args, **kwargs))
+ except Exception as e:
+ return ("error", e)
+
+ def wrapped(*args, **kwargs):
+ result = wrapper(*args, **kwargs)
+ if result[0] == "error":
+ raise result[1]
+ return result[1]
+
+ return wrapped
+
+ return decorator
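A short usage sketch of `lru_cache_wrapper` above (the `flaky` function is hypothetical, for illustration only): because the inner `wrapper` returns a `("success", value)` or `("error", exc)` tuple, exceptions are memoized alongside successes and re-raised on repeated calls without re-executing the function.

```python
from litellm.caching._internal_lru_cache import lru_cache_wrapper

call_count = 0

@lru_cache_wrapper(maxsize=16)
def flaky(x: int) -> int:
    global call_count
    call_count += 1
    if x < 0:
        raise ValueError("negative input")
    return x * 2

assert flaky(3) == 6 and flaky(3) == 6 and call_count == 1  # success cached

for _ in range(2):
    try:
        flaky(-1)  # the second call re-raises the cached ValueError without running flaky again
    except ValueError:
        pass
assert call_count == 2
```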
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/base_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/base_cache.py
new file mode 100644
index 00000000..7109951d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/base_cache.py
@@ -0,0 +1,55 @@
+"""
+Base Cache implementation. All cache implementations should inherit from this class.
+
+Has 4 methods:
+ - set_cache
+ - get_cache
+ - async_set_cache
+ - async_get_cache
+"""
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Optional
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ Span = _Span
+else:
+ Span = Any
+
+
+class BaseCache(ABC):
+ def __init__(self, default_ttl: int = 60):
+ self.default_ttl = default_ttl
+
+ def get_ttl(self, **kwargs) -> Optional[int]:
+ kwargs_ttl: Optional[int] = kwargs.get("ttl")
+ if kwargs_ttl is not None:
+ try:
+ return int(kwargs_ttl)
+ except ValueError:
+ return self.default_ttl
+ return self.default_ttl
+
+ def set_cache(self, key, value, **kwargs):
+ raise NotImplementedError
+
+ async def async_set_cache(self, key, value, **kwargs):
+ raise NotImplementedError
+
+ @abstractmethod
+ async def async_set_cache_pipeline(self, cache_list, **kwargs):
+ pass
+
+ def get_cache(self, key, **kwargs):
+ raise NotImplementedError
+
+ async def async_get_cache(self, key, **kwargs):
+ raise NotImplementedError
+
+ async def batch_cache_write(self, key, value, **kwargs):
+ raise NotImplementedError
+
+ async def disconnect(self):
+ raise NotImplementedError
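For context, a minimal subclass sketch (hypothetical `DictCache`, not part of litellm) showing how a backend plugs into `BaseCache`: only `async_set_cache_pipeline` is abstract, the sync/async get/set methods are overridden as needed, and `get_ttl()` supplies the TTL fallback.

```python
from typing import Optional

from litellm.caching.base_cache import BaseCache


class DictCache(BaseCache):
    """Illustrative in-process cache backed by a plain dict."""

    def __init__(self, default_ttl: int = 60):
        super().__init__(default_ttl=default_ttl)
        self._store: dict = {}

    def set_cache(self, key, value, **kwargs):
        ttl: Optional[int] = self.get_ttl(**kwargs)  # falls back to default_ttl
        self._store[key] = value  # TTL expiry bookkeeping omitted in this sketch

    def get_cache(self, key, **kwargs):
        return self._store.get(key)

    async def async_set_cache(self, key, value, **kwargs):
        self.set_cache(key, value, **kwargs)

    async def async_get_cache(self, key, **kwargs):
        return self.get_cache(key, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        # the only abstract method: bulk-write a list of (key, value) pairs
        for k, v in cache_list:
            self.set_cache(k, v, **kwargs)
```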
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/caching.py b/.venv/lib/python3.12/site-packages/litellm/caching/caching.py
new file mode 100644
index 00000000..415c49ed
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/caching.py
@@ -0,0 +1,797 @@
+# +-----------------------------------------------+
+# | |
+# | Give Feedback / Get Help |
+# | https://github.com/BerriAI/litellm/issues/new |
+# | |
+# +-----------------------------------------------+
+#
+# Thank you users! We ❤️ you! - Krrish & Ishaan
+
+import ast
+import hashlib
+import json
+import time
+import traceback
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
+from litellm.types.caching import *
+from litellm.types.utils import all_litellm_params
+
+from .base_cache import BaseCache
+from .disk_cache import DiskCache
+from .dual_cache import DualCache # noqa
+from .in_memory_cache import InMemoryCache
+from .qdrant_semantic_cache import QdrantSemanticCache
+from .redis_cache import RedisCache
+from .redis_cluster_cache import RedisClusterCache
+from .redis_semantic_cache import RedisSemanticCache
+from .s3_cache import S3Cache
+
+
+def print_verbose(print_statement):
+ try:
+ verbose_logger.debug(print_statement)
+ if litellm.set_verbose:
+ print(print_statement) # noqa
+ except Exception:
+ pass
+
+
+class CacheMode(str, Enum):
+ default_on = "default_on"
+ default_off = "default_off"
+
+
+#### LiteLLM.Completion / Embedding Cache ####
+class Cache:
+ def __init__(
+ self,
+ type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
+ mode: Optional[
+ CacheMode
+ ] = CacheMode.default_on, # default_on: cache is always on; default_off: caching is opt-in per request
+ host: Optional[str] = None,
+ port: Optional[str] = None,
+ password: Optional[str] = None,
+ namespace: Optional[str] = None,
+ ttl: Optional[float] = None,
+ default_in_memory_ttl: Optional[float] = None,
+ default_in_redis_ttl: Optional[float] = None,
+ similarity_threshold: Optional[float] = None,
+ supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
+ "completion",
+ "acompletion",
+ "embedding",
+ "aembedding",
+ "atranscription",
+ "transcription",
+ "atext_completion",
+ "text_completion",
+ "arerank",
+ "rerank",
+ ],
+ # s3 Bucket, boto3 configuration
+ s3_bucket_name: Optional[str] = None,
+ s3_region_name: Optional[str] = None,
+ s3_api_version: Optional[str] = None,
+ s3_use_ssl: Optional[bool] = True,
+ s3_verify: Optional[Union[bool, str]] = None,
+ s3_endpoint_url: Optional[str] = None,
+ s3_aws_access_key_id: Optional[str] = None,
+ s3_aws_secret_access_key: Optional[str] = None,
+ s3_aws_session_token: Optional[str] = None,
+ s3_config: Optional[Any] = None,
+ s3_path: Optional[str] = None,
+ redis_semantic_cache_use_async=False,
+ redis_semantic_cache_embedding_model="text-embedding-ada-002",
+ redis_flush_size: Optional[int] = None,
+ redis_startup_nodes: Optional[List] = None,
+ disk_cache_dir=None,
+ qdrant_api_base: Optional[str] = None,
+ qdrant_api_key: Optional[str] = None,
+ qdrant_collection_name: Optional[str] = None,
+ qdrant_quantization_config: Optional[str] = None,
+ qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
+ **kwargs,
+ ):
+ """
+ Initializes the cache based on the given type.
+
+ Args:
+ type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
+
+ # Redis Cache Args
+ host (str, optional): The host address for the Redis cache. Required if type is "redis".
+ port (int, optional): The port number for the Redis cache. Required if type is "redis".
+ password (str, optional): The password for the Redis cache. Required if type is "redis".
+ namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
+ ttl (float, optional): The ttl for the Redis cache
+ redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
+ redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
+
+ # Qdrant Cache Args
+ qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
+ qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
+ qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
+ similarity_threshold (float, optional): The similarity threshold for semantic caching. Required if type is "redis-semantic" or "qdrant-semantic".
+
+ # Disk Cache Args
+ disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
+
+ # S3 Cache Args
+ s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
+ s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
+ s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
+ s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True.
+ s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
+ s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
+ s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
+ s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
+ s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
+ s3_config (dict, optional): The config for the s3 cache. Defaults to None.
+
+ # Common Cache Args
+ supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
+ **kwargs: Additional keyword arguments for redis.Redis() cache
+
+ Raises:
+ ValueError: If an invalid cache type is provided.
+
+ Returns:
+ None. Cache is set as a litellm param
+ """
+ if type == LiteLLMCacheType.REDIS:
+ if redis_startup_nodes:
+ self.cache: BaseCache = RedisClusterCache(
+ host=host,
+ port=port,
+ password=password,
+ redis_flush_size=redis_flush_size,
+ startup_nodes=redis_startup_nodes,
+ **kwargs,
+ )
+ else:
+ self.cache = RedisCache(
+ host=host,
+ port=port,
+ password=password,
+ redis_flush_size=redis_flush_size,
+ **kwargs,
+ )
+ elif type == LiteLLMCacheType.REDIS_SEMANTIC:
+ self.cache = RedisSemanticCache(
+ host=host,
+ port=port,
+ password=password,
+ similarity_threshold=similarity_threshold,
+ use_async=redis_semantic_cache_use_async,
+ embedding_model=redis_semantic_cache_embedding_model,
+ **kwargs,
+ )
+ elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
+ self.cache = QdrantSemanticCache(
+ qdrant_api_base=qdrant_api_base,
+ qdrant_api_key=qdrant_api_key,
+ collection_name=qdrant_collection_name,
+ similarity_threshold=similarity_threshold,
+ quantization_config=qdrant_quantization_config,
+ embedding_model=qdrant_semantic_cache_embedding_model,
+ )
+ elif type == LiteLLMCacheType.LOCAL:
+ self.cache = InMemoryCache()
+ elif type == LiteLLMCacheType.S3:
+ self.cache = S3Cache(
+ s3_bucket_name=s3_bucket_name,
+ s3_region_name=s3_region_name,
+ s3_api_version=s3_api_version,
+ s3_use_ssl=s3_use_ssl,
+ s3_verify=s3_verify,
+ s3_endpoint_url=s3_endpoint_url,
+ s3_aws_access_key_id=s3_aws_access_key_id,
+ s3_aws_secret_access_key=s3_aws_secret_access_key,
+ s3_aws_session_token=s3_aws_session_token,
+ s3_config=s3_config,
+ s3_path=s3_path,
+ **kwargs,
+ )
+ elif type == LiteLLMCacheType.DISK:
+ self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
+ if "cache" not in litellm.input_callback:
+ litellm.input_callback.append("cache")
+ if "cache" not in litellm.success_callback:
+ litellm.logging_callback_manager.add_litellm_success_callback("cache")
+ if "cache" not in litellm._async_success_callback:
+ litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
+ self.supported_call_types = supported_call_types # defaults to caching all supported call types
+ self.type = type
+ self.namespace = namespace
+ self.redis_flush_size = redis_flush_size
+ self.ttl = ttl
+ self.mode: CacheMode = mode or CacheMode.default_on
+
+ if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
+ self.ttl = default_in_memory_ttl
+
+ if (
+ self.type == LiteLLMCacheType.REDIS
+ or self.type == LiteLLMCacheType.REDIS_SEMANTIC
+ ) and default_in_redis_ttl is not None:
+ self.ttl = default_in_redis_ttl
+
+ if self.namespace is not None and isinstance(self.cache, RedisCache):
+ self.cache.namespace = self.namespace
+
+ def get_cache_key(self, **kwargs) -> str:
+ """
+ Get the cache key for the given arguments.
+
+ Args:
+ **kwargs: kwargs to litellm.completion() or embedding()
+
+ Returns:
+ str: The cache key generated from the arguments, or None if no cache key could be generated.
+ """
+ cache_key = ""
+ # verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
+
+ preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
+ if preset_cache_key is not None:
+ verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
+ return preset_cache_key
+
+ combined_kwargs = ModelParamHelper._get_all_llm_api_params()
+ litellm_param_kwargs = all_litellm_params
+ for param in kwargs:
+ if param in combined_kwargs:
+ param_value: Optional[str] = self._get_param_value(param, kwargs)
+ if param_value is not None:
+ cache_key += f"{str(param)}: {str(param_value)}"
+ elif (
+ param not in litellm_param_kwargs
+ ): # check if user passed in optional param - e.g. top_k
+ if (
+ litellm.enable_caching_on_provider_specific_optional_params is True
+ ): # feature flagged for now
+ if kwargs[param] is None:
+ continue # ignore None params
+ param_value = kwargs[param]
+ cache_key += f"{str(param)}: {str(param_value)}"
+
+ verbose_logger.debug("\nCreated cache key: %s", cache_key)
+ hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
+ hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
+ self._set_preset_cache_key_in_kwargs(
+ preset_cache_key=hashed_cache_key, **kwargs
+ )
+ return hashed_cache_key
+
+ def _get_param_value(
+ self,
+ param: str,
+ kwargs: dict,
+ ) -> Optional[str]:
+ """
+ Get the value for the given param from kwargs
+ """
+ if param == "model":
+ return self._get_model_param_value(kwargs)
+ elif param == "file":
+ return self._get_file_param_value(kwargs)
+ return kwargs[param]
+
+ def _get_model_param_value(self, kwargs: dict) -> str:
+ """
+ Handles getting the value for the 'model' param from kwargs
+
+ 1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
+ 2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
+ 3. Else use the `model` passed in kwargs
+ """
+ metadata: Dict = kwargs.get("metadata", {}) or {}
+ litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
+ metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
+ model_group: Optional[str] = metadata.get(
+ "model_group"
+ ) or metadata_in_litellm_params.get("model_group")
+ caching_group = self._get_caching_group(metadata, model_group)
+ return caching_group or model_group or kwargs["model"]
+
+ def _get_caching_group(
+ self, metadata: dict, model_group: Optional[str]
+ ) -> Optional[str]:
+ caching_groups: Optional[List] = metadata.get("caching_groups", [])
+ if caching_groups:
+ for group in caching_groups:
+ if model_group in group:
+ return str(group)
+ return None
+
+ def _get_file_param_value(self, kwargs: dict) -> str:
+ """
+ Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
+ """
+ file = kwargs.get("file")
+ metadata = kwargs.get("metadata", {})
+ litellm_params = kwargs.get("litellm_params", {})
+ return (
+ metadata.get("file_checksum")
+ or getattr(file, "name", None)
+ or metadata.get("file_name")
+ or litellm_params.get("file_name")
+ )
+
+ def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
+ """
+ Get the preset cache key from kwargs["litellm_params"]
+
+ We use preset cache keys for two reasons
+
+ 1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
+ 2. avoid doing duplicate / repeated work
+ """
+ if kwargs:
+ if "litellm_params" in kwargs:
+ return kwargs["litellm_params"].get("preset_cache_key", None)
+ return None
+
+ def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
+ """
+ Set the calculated cache key in kwargs
+
+ This is used to avoid doing duplicate / repeated work
+
+ Placed in kwargs["litellm_params"]
+ """
+ if kwargs:
+ if "litellm_params" in kwargs:
+ kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
+
+ @staticmethod
+ def _get_hashed_cache_key(cache_key: str) -> str:
+ """
+ Get the hashed cache key for the given cache key.
+
+ Use hashlib to create a sha256 hash of the cache key
+
+ Args:
+ cache_key (str): The cache key to hash.
+
+ Returns:
+ str: The hashed cache key.
+ """
+ hash_object = hashlib.sha256(cache_key.encode())
+ # Hexadecimal representation of the hash
+ hash_hex = hash_object.hexdigest()
+ verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
+ return hash_hex
+
+ def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
+ """
+ If a redis namespace is provided, add it to the cache key
+
+ Args:
+ hash_hex (str): The hashed cache key.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The final hashed cache key with the redis namespace.
+ """
+ dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
+ namespace = (
+ dynamic_cache_control.get("namespace")
+ or kwargs.get("metadata", {}).get("redis_namespace")
+ or self.namespace
+ )
+ if namespace:
+ hash_hex = f"{namespace}:{hash_hex}"
+ verbose_logger.debug("Final hashed key: %s", hash_hex)
+ return hash_hex
+
+ def generate_streaming_content(self, content):
+ chunk_size = 5 # Adjust the chunk size as needed
+ for i in range(0, len(content), chunk_size):
+ yield {
+ "choices": [
+ {
+ "delta": {
+ "role": "assistant",
+ "content": content[i : i + chunk_size],
+ }
+ }
+ ]
+ }
+ time.sleep(0.02)
+
+ def _get_cache_logic(
+ self,
+ cached_result: Optional[Any],
+ max_age: Optional[float],
+ ):
+ """
+ Common get cache logic across sync + async implementations
+ """
+ # Check if a timestamp was stored with the cached response
+ if (
+ cached_result is not None
+ and isinstance(cached_result, dict)
+ and "timestamp" in cached_result
+ ):
+ timestamp = cached_result["timestamp"]
+ current_time = time.time()
+
+ # Calculate age of the cached response
+ response_age = current_time - timestamp
+
+ # Check if the cached response is older than the max-age
+ if max_age is not None and response_age > max_age:
+ return None # Cached response is too old
+
+ # If the response is fresh, or there's no max-age requirement, return the cached response
+ # cached_response may be a serialized (byte)string like b"{...}"; convert it to a dict
+ cached_response = cached_result.get("response")
+ try:
+ if isinstance(cached_response, dict):
+ pass
+ else:
+ cached_response = json.loads(
+ cached_response # type: ignore
+ ) # Convert string to dictionary
+ except Exception:
+ cached_response = ast.literal_eval(cached_response) # type: ignore
+ return cached_response
+ return cached_result
+
+ def get_cache(self, **kwargs):
+ """
+ Retrieves the cached result for the given arguments.
+
+ Args:
+ *args: args to litellm.completion() or embedding()
+ **kwargs: kwargs to litellm.completion() or embedding()
+
+ Returns:
+ The cached result if it exists, otherwise None.
+ """
+ try: # never block execution
+ if self.should_use_cache(**kwargs) is not True:
+ return
+ messages = kwargs.get("messages", [])
+ if "cache_key" in kwargs:
+ cache_key = kwargs["cache_key"]
+ else:
+ cache_key = self.get_cache_key(**kwargs)
+ if cache_key is not None:
+ cache_control_args: DynamicCacheControl = kwargs.get("cache", {})
+ max_age = (
+ cache_control_args.get("s-maxage")
+ or cache_control_args.get("s-max-age")
+ or float("inf")
+ )
+ cached_result = self.cache.get_cache(cache_key, messages=messages)
+ return self._get_cache_logic(
+ cached_result=cached_result, max_age=max_age
+ )
+ except Exception:
+ print_verbose(f"An exception occurred: {traceback.format_exc()}")
+ return None
+
+ async def async_get_cache(self, **kwargs):
+ """
+ Async get cache implementation.
+
+ Used for embedding calls in async wrapper
+ """
+
+ try: # never block execution
+ if self.should_use_cache(**kwargs) is not True:
+ return
+
+ kwargs.get("messages", [])
+ if "cache_key" in kwargs:
+ cache_key = kwargs["cache_key"]
+ else:
+ cache_key = self.get_cache_key(**kwargs)
+ if cache_key is not None:
+ cache_control_args = kwargs.get("cache", {})
+ max_age = cache_control_args.get(
+ "s-max-age", cache_control_args.get("s-maxage", float("inf"))
+ )
+ cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
+ return self._get_cache_logic(
+ cached_result=cached_result, max_age=max_age
+ )
+ except Exception:
+ print_verbose(f"An exception occurred: {traceback.format_exc()}")
+ return None
+
+ def _add_cache_logic(self, result, **kwargs):
+ """
+ Common implementation across sync + async add_cache functions
+ """
+ try:
+ if "cache_key" in kwargs:
+ cache_key = kwargs["cache_key"]
+ else:
+ cache_key = self.get_cache_key(**kwargs)
+ if cache_key is not None:
+ if isinstance(result, BaseModel):
+ result = result.model_dump_json()
+
+ ## DEFAULT TTL ##
+ if self.ttl is not None:
+ kwargs["ttl"] = self.ttl
+ ## Get Cache-Controls ##
+ _cache_kwargs = kwargs.get("cache", None)
+ if isinstance(_cache_kwargs, dict):
+ for k, v in _cache_kwargs.items():
+ if k == "ttl":
+ kwargs["ttl"] = v
+
+ cached_data = {"timestamp": time.time(), "response": result}
+ return cache_key, cached_data, kwargs
+ else:
+ raise Exception("cache key is None")
+ except Exception as e:
+ raise e
+
+ def add_cache(self, result, **kwargs):
+ """
+ Adds a result to the cache.
+
+ Args:
+ *args: args to litellm.completion() or embedding()
+ **kwargs: kwargs to litellm.completion() or embedding()
+
+ Returns:
+ None
+ """
+ try:
+ if self.should_use_cache(**kwargs) is not True:
+ return
+ cache_key, cached_data, kwargs = self._add_cache_logic(
+ result=result, **kwargs
+ )
+ self.cache.set_cache(cache_key, cached_data, **kwargs)
+ except Exception as e:
+ verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
+
+ async def async_add_cache(self, result, **kwargs):
+ """
+ Async implementation of add_cache
+ """
+ try:
+ if self.should_use_cache(**kwargs) is not True:
+ return
+ if self.type == "redis" and self.redis_flush_size is not None:
+ # high traffic - fill in results in memory and then flush
+ await self.batch_cache_write(result, **kwargs)
+ else:
+ cache_key, cached_data, kwargs = self._add_cache_logic(
+ result=result, **kwargs
+ )
+
+ await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+ except Exception as e:
+ verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
+
+ async def async_add_cache_pipeline(self, result, **kwargs):
+ """
+ Async implementation of add_cache for Embedding calls
+
+ Does a bulk write, to prevent using too many clients
+ """
+ try:
+ if self.should_use_cache(**kwargs) is not True:
+ return
+
+ # set default ttl if not set
+ if self.ttl is not None:
+ kwargs["ttl"] = self.ttl
+
+ cache_list = []
+ for idx, i in enumerate(kwargs["input"]):
+ preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
+ kwargs["cache_key"] = preset_cache_key
+ embedding_response = result.data[idx]
+ cache_key, cached_data, kwargs = self._add_cache_logic(
+ result=embedding_response,
+ **kwargs,
+ )
+ cache_list.append((cache_key, cached_data))
+
+ await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
+ # if async_set_cache_pipeline:
+ # await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
+ # else:
+ # tasks = []
+ # for val in cache_list:
+ # tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
+ # await asyncio.gather(*tasks)
+ except Exception as e:
+ verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
+
+ def should_use_cache(self, **kwargs):
+ """
+ Returns true if we should use the cache for LLM API calls
+
+ If cache is default_on then this is True
+ If cache is default_off then this is only true when user has opted in to use cache
+ """
+ if self.mode == CacheMode.default_on:
+ return True
+
+ # when mode == default_off -> Cache is opt in only
+ _cache = kwargs.get("cache", None)
+ verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
+ if _cache and isinstance(_cache, dict):
+ if _cache.get("use-cache", False) is True:
+ return True
+ return False
+
+ async def batch_cache_write(self, result, **kwargs):
+ cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
+ await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
+
+ async def ping(self):
+ cache_ping = getattr(self.cache, "ping", None)
+ if cache_ping:
+ return await cache_ping()
+ return None
+
+ async def delete_cache_keys(self, keys):
+ cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys", None)
+ if cache_delete_cache_keys:
+ return await cache_delete_cache_keys(keys)
+ return None
+
+ async def disconnect(self):
+ if hasattr(self.cache, "disconnect"):
+ await self.cache.disconnect()
+
+ def _supports_async(self) -> bool:
+ """
+ Internal method to check if the cache type supports async get/set operations
+
+ Only the S3 cache does NOT support async operations
+
+ """
+ if self.type and self.type == LiteLLMCacheType.S3:
+ return False
+ return True
+
+
+def enable_cache(
+ type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
+ host: Optional[str] = None,
+ port: Optional[str] = None,
+ password: Optional[str] = None,
+ supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
+ "completion",
+ "acompletion",
+ "embedding",
+ "aembedding",
+ "atranscription",
+ "transcription",
+ "atext_completion",
+ "text_completion",
+ "arerank",
+ "rerank",
+ ],
+ **kwargs,
+):
+ """
+ Enable cache with the specified configuration.
+
+ Args:
+ type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
+ host (Optional[str]): The host address of the cache server. Defaults to None.
+ port (Optional[str]): The port number of the cache server. Defaults to None.
+ password (Optional[str]): The password for the cache server. Defaults to None.
+ supported_call_types (Optional[List[CachingSupportedCallTypes]]):
+ The call types to cache for. Defaults to all supported call types (completion, embedding, transcription, text_completion, rerank and their async variants).
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ None
+
+ Raises:
+ None
+ """
+ print_verbose("LiteLLM: Enabling Cache")
+ if "cache" not in litellm.input_callback:
+ litellm.input_callback.append("cache")
+ if "cache" not in litellm.success_callback:
+ litellm.logging_callback_manager.add_litellm_success_callback("cache")
+ if "cache" not in litellm._async_success_callback:
+ litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
+
+ if litellm.cache is None:
+ litellm.cache = Cache(
+ type=type,
+ host=host,
+ port=port,
+ password=password,
+ supported_call_types=supported_call_types,
+ **kwargs,
+ )
+ print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
+ print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
+
+
+def update_cache(
+ type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
+ host: Optional[str] = None,
+ port: Optional[str] = None,
+ password: Optional[str] = None,
+ supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
+ "completion",
+ "acompletion",
+ "embedding",
+ "aembedding",
+ "atranscription",
+ "transcription",
+ "atext_completion",
+ "text_completion",
+ "arerank",
+ "rerank",
+ ],
+ **kwargs,
+):
+ """
+ Update the cache for LiteLLM.
+
+ Args:
+ type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
+ host (Optional[str]): The host of the cache. Defaults to None.
+ port (Optional[str]): The port of the cache. Defaults to None.
+ password (Optional[str]): The password for the cache. Defaults to None.
+ supported_call_types (Optional[List[CachingSupportedCallTypes]]):
+ The call types to cache for. Defaults to all supported call types (completion, embedding, transcription, text_completion, rerank and their async variants).
+ **kwargs: Additional keyword arguments for the cache.
+
+ Returns:
+ None
+
+ """
+ print_verbose("LiteLLM: Updating Cache")
+ litellm.cache = Cache(
+ type=type,
+ host=host,
+ port=port,
+ password=password,
+ supported_call_types=supported_call_types,
+ **kwargs,
+ )
+ print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
+ print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
+
+
+def disable_cache():
+ """
+ Disable the cache used by LiteLLM.
+
+ This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.
+
+ Parameters:
+ None
+
+ Returns:
+ None
+ """
+ from contextlib import suppress
+
+ print_verbose("LiteLLM: Disabling Cache")
+ with suppress(ValueError):
+ litellm.input_callback.remove("cache")
+ litellm.success_callback.remove("cache")
+ litellm._async_success_callback.remove("cache")
+
+ litellm.cache = None
+ print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/caching_handler.py b/.venv/lib/python3.12/site-packages/litellm/caching/caching_handler.py
new file mode 100644
index 00000000..09fabf1c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/caching_handler.py
@@ -0,0 +1,909 @@
+"""
+This contains LLMCachingHandler
+
+This exposes two methods:
+ - async_get_cache
+ - async_set_cache
+
+This file is a wrapper around caching.py
+
+This class handles caching logic specific to LLM API requests (completion / embedding / text_completion / transcription, etc.)
+
+It uses the configured cache backend (RedisCache, S3Cache, RedisSemanticCache, QdrantSemanticCache, InMemoryCache, DiskCache) based on what the user has set up
+
+In each method it will call the appropriate method from caching.py
+"""
+
+import asyncio
+import datetime
+import inspect
+import threading
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AsyncGenerator,
+ Callable,
+ Dict,
+ Generator,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
+
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import print_verbose, verbose_logger
+from litellm.caching.caching import S3Cache
+from litellm.litellm_core_utils.logging_utils import (
+ _assemble_complete_response_from_streaming_chunks,
+)
+from litellm.types.rerank import RerankResponse
+from litellm.types.utils import (
+ CallTypes,
+ Embedding,
+ EmbeddingResponse,
+ ModelResponse,
+ TextCompletionResponse,
+ TranscriptionResponse,
+)
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+ from litellm.utils import CustomStreamWrapper
+else:
+ LiteLLMLoggingObj = Any
+ CustomStreamWrapper = Any
+
+
+class CachingHandlerResponse(BaseModel):
+ """
+ This is the response object for the caching handler. We need to separate embedding cached responses and (completion / text_completion / transcription) cached responses
+
+ For embeddings there can be a cache hit for some of the inputs in the list and a cache miss for others
+ """
+
+ cached_result: Optional[Any] = None
+ final_embedding_cached_response: Optional[EmbeddingResponse] = None
+ embedding_all_elements_cache_hit: bool = (
+ False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
+ )
+
+
+class LLMCachingHandler:
+ def __init__(
+ self,
+ original_function: Callable,
+ request_kwargs: Dict[str, Any],
+ start_time: datetime.datetime,
+ ):
+ self.async_streaming_chunks: List[ModelResponse] = []
+ self.sync_streaming_chunks: List[ModelResponse] = []
+ self.request_kwargs = request_kwargs
+ self.original_function = original_function
+ self.start_time = start_time
+ pass
+
+ async def _async_get_cache(
+ self,
+ model: str,
+ original_function: Callable,
+ logging_obj: LiteLLMLoggingObj,
+ start_time: datetime.datetime,
+ call_type: str,
+ kwargs: Dict[str, Any],
+ args: Optional[Tuple[Any, ...]] = None,
+ ) -> CachingHandlerResponse:
+ """
+ Internal method to get from the cache.
+ Handles different call types (embeddings, chat/completions, text_completion, transcription)
+ and accordingly returns the cached response
+
+ Args:
+ model: str:
+ original_function: Callable:
+ logging_obj: LiteLLMLoggingObj:
+ start_time: datetime.datetime:
+ call_type: str:
+ kwargs: Dict[str, Any]:
+ args: Optional[Tuple[Any, ...]] = None:
+
+
+ Returns:
+ CachingHandlerResponse:
+ Raises:
+ None
+ """
+ from litellm.utils import CustomStreamWrapper
+
+ args = args or ()
+
+ final_embedding_cached_response: Optional[EmbeddingResponse] = None
+ embedding_all_elements_cache_hit: bool = False
+ cached_result: Optional[Any] = None
+ if (
+ (kwargs.get("caching", None) is None and litellm.cache is not None)
+ or kwargs.get("caching", False) is True
+ ) and (
+ kwargs.get("cache", {}).get("no-cache", False) is not True
+ ): # allow users to control returning cached responses from the completion function
+ if litellm.cache is not None and self._is_call_type_supported_by_cache(
+ original_function=original_function
+ ):
+ verbose_logger.debug("Checking Cache")
+ cached_result = await self._retrieve_from_cache(
+ call_type=call_type,
+ kwargs=kwargs,
+ args=args,
+ )
+
+ if cached_result is not None and not isinstance(cached_result, list):
+ verbose_logger.debug("Cache Hit!")
+ cache_hit = True
+ end_time = datetime.datetime.now()
+ model, _, _, _ = litellm.get_llm_provider(
+ model=model,
+ custom_llm_provider=kwargs.get("custom_llm_provider", None),
+ api_base=kwargs.get("api_base", None),
+ api_key=kwargs.get("api_key", None),
+ )
+ self._update_litellm_logging_obj_environment(
+ logging_obj=logging_obj,
+ model=model,
+ kwargs=kwargs,
+ cached_result=cached_result,
+ is_async=True,
+ )
+
+ call_type = original_function.__name__
+
+ cached_result = self._convert_cached_result_to_model_response(
+ cached_result=cached_result,
+ call_type=call_type,
+ kwargs=kwargs,
+ logging_obj=logging_obj,
+ model=model,
+ custom_llm_provider=kwargs.get("custom_llm_provider", None),
+ args=args,
+ )
+ if kwargs.get("stream", False) is False:
+ # LOG SUCCESS
+ self._async_log_cache_hit_on_callbacks(
+ logging_obj=logging_obj,
+ cached_result=cached_result,
+ start_time=start_time,
+ end_time=end_time,
+ cache_hit=cache_hit,
+ )
+ cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
+ **kwargs
+ )
+ if (
+ isinstance(cached_result, BaseModel)
+ or isinstance(cached_result, CustomStreamWrapper)
+ ) and hasattr(cached_result, "_hidden_params"):
+ cached_result._hidden_params["cache_key"] = cache_key # type: ignore
+ return CachingHandlerResponse(cached_result=cached_result)
+ elif (
+ call_type == CallTypes.aembedding.value
+ and cached_result is not None
+ and isinstance(cached_result, list)
+ and litellm.cache is not None
+ and not isinstance(
+ litellm.cache.cache, S3Cache
+ ) # s3 doesn't support bulk writing. Exclude.
+ ):
+ (
+ final_embedding_cached_response,
+ embedding_all_elements_cache_hit,
+ ) = self._process_async_embedding_cached_response(
+ final_embedding_cached_response=final_embedding_cached_response,
+ cached_result=cached_result,
+ kwargs=kwargs,
+ logging_obj=logging_obj,
+ start_time=start_time,
+ model=model,
+ )
+ return CachingHandlerResponse(
+ final_embedding_cached_response=final_embedding_cached_response,
+ embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
+ )
+ verbose_logger.debug(f"CACHE RESULT: {cached_result}")
+ return CachingHandlerResponse(
+ cached_result=cached_result,
+ final_embedding_cached_response=final_embedding_cached_response,
+ )
+
+ def _sync_get_cache(
+ self,
+ model: str,
+ original_function: Callable,
+ logging_obj: LiteLLMLoggingObj,
+ start_time: datetime.datetime,
+ call_type: str,
+ kwargs: Dict[str, Any],
+ args: Optional[Tuple[Any, ...]] = None,
+ ) -> CachingHandlerResponse:
+ from litellm.utils import CustomStreamWrapper
+
+ args = args or ()
+ new_kwargs = kwargs.copy()
+ new_kwargs.update(
+ convert_args_to_kwargs(
+ self.original_function,
+ args,
+ )
+ )
+ cached_result: Optional[Any] = None
+ if litellm.cache is not None and self._is_call_type_supported_by_cache(
+ original_function=original_function
+ ):
+ print_verbose("Checking Cache")
+ cached_result = litellm.cache.get_cache(**new_kwargs)
+ if cached_result is not None:
+ if "detail" in cached_result:
+ # implies an error occurred
+ pass
+ else:
+ call_type = original_function.__name__
+ cached_result = self._convert_cached_result_to_model_response(
+ cached_result=cached_result,
+ call_type=call_type,
+ kwargs=kwargs,
+ logging_obj=logging_obj,
+ model=model,
+ custom_llm_provider=kwargs.get("custom_llm_provider", None),
+ args=args,
+ )
+
+ # LOG SUCCESS
+ cache_hit = True
+ end_time = datetime.datetime.now()
+ (
+ model,
+ custom_llm_provider,
+ dynamic_api_key,
+ api_base,
+ ) = litellm.get_llm_provider(
+ model=model or "",
+ custom_llm_provider=kwargs.get("custom_llm_provider", None),
+ api_base=kwargs.get("api_base", None),
+ api_key=kwargs.get("api_key", None),
+ )
+ self._update_litellm_logging_obj_environment(
+ logging_obj=logging_obj,
+ model=model,
+ kwargs=kwargs,
+ cached_result=cached_result,
+ is_async=False,
+ )
+
+ threading.Thread(
+ target=logging_obj.success_handler,
+ args=(cached_result, start_time, end_time, cache_hit),
+ ).start()
+ cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
+ **kwargs
+ )
+ if (
+ isinstance(cached_result, BaseModel)
+ or isinstance(cached_result, CustomStreamWrapper)
+ ) and hasattr(cached_result, "_hidden_params"):
+ cached_result._hidden_params["cache_key"] = cache_key # type: ignore
+ return CachingHandlerResponse(cached_result=cached_result)
+ return CachingHandlerResponse(cached_result=cached_result)
+
+ def _process_async_embedding_cached_response(
+ self,
+ final_embedding_cached_response: Optional[EmbeddingResponse],
+ cached_result: List[Optional[Dict[str, Any]]],
+ kwargs: Dict[str, Any],
+ logging_obj: LiteLLMLoggingObj,
+ start_time: datetime.datetime,
+ model: str,
+ ) -> Tuple[Optional[EmbeddingResponse], bool]:
+ """
+ Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
+
+ For embedding responses, there can be a cache hit for some of the inputs in the list and a cache miss for others
+ This function processes the cached embedding responses and returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
+
+ Args:
+ final_embedding_cached_response: Optional[EmbeddingResponse]:
+ cached_result: List[Optional[Dict[str, Any]]]:
+ kwargs: Dict[str, Any]:
+ logging_obj: LiteLLMLoggingObj:
+ start_time: datetime.datetime:
+ model: str:
+
+ Returns:
+ Tuple[Optional[EmbeddingResponse], bool]:
+ Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
+
+
+ """
+ embedding_all_elements_cache_hit: bool = False
+ remaining_list = []
+ non_null_list = []
+ for idx, cr in enumerate(cached_result):
+ if cr is None:
+ remaining_list.append(kwargs["input"][idx])
+ else:
+ non_null_list.append((idx, cr))
+ original_kwargs_input = kwargs["input"]
+ kwargs["input"] = remaining_list
+ if len(non_null_list) > 0:
+ print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}")
+ final_embedding_cached_response = EmbeddingResponse(
+ model=kwargs.get("model"),
+ data=[None] * len(original_kwargs_input),
+ )
+ final_embedding_cached_response._hidden_params["cache_hit"] = True
+
+ for val in non_null_list:
+ idx, cr = val # (idx, cr) tuple
+ if cr is not None:
+ final_embedding_cached_response.data[idx] = Embedding(
+ embedding=cr["embedding"],
+ index=idx,
+ object="embedding",
+ )
+ if len(remaining_list) == 0:
+ # LOG SUCCESS
+ cache_hit = True
+ embedding_all_elements_cache_hit = True
+ end_time = datetime.datetime.now()
+ (
+ model,
+ custom_llm_provider,
+ dynamic_api_key,
+ api_base,
+ ) = litellm.get_llm_provider(
+ model=model,
+ custom_llm_provider=kwargs.get("custom_llm_provider", None),
+ api_base=kwargs.get("api_base", None),
+ api_key=kwargs.get("api_key", None),
+ )
+
+ self._update_litellm_logging_obj_environment(
+ logging_obj=logging_obj,
+ model=model,
+ kwargs=kwargs,
+ cached_result=final_embedding_cached_response,
+ is_async=True,
+ is_embedding=True,
+ )
+ self._async_log_cache_hit_on_callbacks(
+ logging_obj=logging_obj,
+ cached_result=final_embedding_cached_response,
+ start_time=start_time,
+ end_time=end_time,
+ cache_hit=cache_hit,
+ )
+ return final_embedding_cached_response, embedding_all_elements_cache_hit
+ return final_embedding_cached_response, embedding_all_elements_cache_hit
+
+ def _combine_cached_embedding_response_with_api_result(
+ self,
+ _caching_handler_response: CachingHandlerResponse,
+ embedding_response: EmbeddingResponse,
+ start_time: datetime.datetime,
+ end_time: datetime.datetime,
+ ) -> EmbeddingResponse:
+ """
+ Combines the cached embedding response with the API EmbeddingResponse
+
+ For caching there can be a cache hit for some of the inputs in the list and a cache miss for others
+ This function combines the cached embedding response with the API EmbeddingResponse
+
+ Args:
+ caching_handler_response: CachingHandlerResponse:
+ embedding_response: EmbeddingResponse:
+
+ Returns:
+ EmbeddingResponse:
+ """
+ if _caching_handler_response.final_embedding_cached_response is None:
+ return embedding_response
+
+ idx = 0
+ final_data_list = []
+ for item in _caching_handler_response.final_embedding_cached_response.data:
+ if item is None and embedding_response.data is not None:
+ final_data_list.append(embedding_response.data[idx])
+ idx += 1
+ else:
+ final_data_list.append(item)
+
+ _caching_handler_response.final_embedding_cached_response.data = final_data_list
+ _caching_handler_response.final_embedding_cached_response._hidden_params[
+ "cache_hit"
+ ] = True
+ _caching_handler_response.final_embedding_cached_response._response_ms = (
+ end_time - start_time
+ ).total_seconds() * 1000
+ return _caching_handler_response.final_embedding_cached_response
+
+ def _async_log_cache_hit_on_callbacks(
+ self,
+ logging_obj: LiteLLMLoggingObj,
+ cached_result: Any,
+ start_time: datetime.datetime,
+ end_time: datetime.datetime,
+ cache_hit: bool,
+ ):
+ """
+ Helper function to log the success of a cached result on callbacks
+
+ Args:
+ logging_obj (LiteLLMLoggingObj): The logging object.
+ cached_result: The cached result.
+ start_time (datetime): The start time of the operation.
+ end_time (datetime): The end time of the operation.
+ cache_hit (bool): Whether it was a cache hit.
+ """
+ asyncio.create_task(
+ logging_obj.async_success_handler(
+ cached_result, start_time, end_time, cache_hit
+ )
+ )
+ threading.Thread(
+ target=logging_obj.success_handler,
+ args=(cached_result, start_time, end_time, cache_hit),
+ ).start()
+
+ async def _retrieve_from_cache(
+ self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
+ ) -> Optional[Any]:
+ """
+ Internal method to
+ - get cache key
+ - check what type of cache is used - Redis, RedisSemantic, Qdrant, S3
+ - async get cache value
+ - return the cached value
+
+ Args:
+ call_type: str:
+ kwargs: Dict[str, Any]:
+ args: Optional[Tuple[Any, ...]] = None:
+
+ Returns:
+ Optional[Any]:
+ Raises:
+ None
+ """
+ if litellm.cache is None:
+ return None
+
+ new_kwargs = kwargs.copy()
+ new_kwargs.update(
+ convert_args_to_kwargs(
+ self.original_function,
+ args,
+ )
+ )
+ cached_result: Optional[Any] = None
+ if call_type == CallTypes.aembedding.value and isinstance(
+ new_kwargs["input"], list
+ ):
+ tasks = []
+ for idx, i in enumerate(new_kwargs["input"]):
+ preset_cache_key = litellm.cache.get_cache_key(
+ **{**new_kwargs, "input": i}
+ )
+ tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
+ cached_result = await asyncio.gather(*tasks)
+ ## check if cached result is None ##
+ if cached_result is not None and isinstance(cached_result, list):
+ # set cached_result to None if all elements are None
+ if all(result is None for result in cached_result):
+ cached_result = None
+ else:
+ if litellm.cache._supports_async() is True:
+ cached_result = await litellm.cache.async_get_cache(**new_kwargs)
+ else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
+ cached_result = litellm.cache.get_cache(**new_kwargs)
+ return cached_result
+
+ def _convert_cached_result_to_model_response(
+ self,
+ cached_result: Any,
+ call_type: str,
+ kwargs: Dict[str, Any],
+ logging_obj: LiteLLMLoggingObj,
+ model: str,
+ args: Tuple[Any, ...],
+ custom_llm_provider: Optional[str] = None,
+ ) -> Optional[
+ Union[
+ ModelResponse,
+ TextCompletionResponse,
+ EmbeddingResponse,
+ RerankResponse,
+ TranscriptionResponse,
+ CustomStreamWrapper,
+ ]
+ ]:
+ """
+ Internal method to process the cached result
+
+ Checks the call type and converts the cached result to the appropriate model response object
+ example if call type is text_completion -> returns TextCompletionResponse object
+
+ Args:
+ cached_result: Any:
+ call_type: str:
+ kwargs: Dict[str, Any]:
+ logging_obj: LiteLLMLoggingObj:
+ model: str:
+ custom_llm_provider: Optional[str] = None:
+ args: Optional[Tuple[Any, ...]] = None:
+
+ Returns:
+ Optional[Any]:
+ """
+ from litellm.utils import convert_to_model_response_object
+
+ if (
+ call_type == CallTypes.acompletion.value
+ or call_type == CallTypes.completion.value
+ ) and isinstance(cached_result, dict):
+ if kwargs.get("stream", False) is True:
+ cached_result = self._convert_cached_stream_response(
+ cached_result=cached_result,
+ call_type=call_type,
+ logging_obj=logging_obj,
+ model=model,
+ )
+ else:
+ cached_result = convert_to_model_response_object(
+ response_object=cached_result,
+ model_response_object=ModelResponse(),
+ )
+ if (
+ call_type == CallTypes.atext_completion.value
+ or call_type == CallTypes.text_completion.value
+ ) and isinstance(cached_result, dict):
+ if kwargs.get("stream", False) is True:
+ cached_result = self._convert_cached_stream_response(
+ cached_result=cached_result,
+ call_type=call_type,
+ logging_obj=logging_obj,
+ model=model,
+ )
+ else:
+ cached_result = TextCompletionResponse(**cached_result)
+ elif (
+ call_type == CallTypes.aembedding.value
+ or call_type == CallTypes.embedding.value
+ ) and isinstance(cached_result, dict):
+ cached_result = convert_to_model_response_object(
+ response_object=cached_result,
+ model_response_object=EmbeddingResponse(),
+ response_type="embedding",
+ )
+
+ elif (
+ call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
+ ) and isinstance(cached_result, dict):
+ cached_result = convert_to_model_response_object(
+ response_object=cached_result,
+ model_response_object=None,
+ response_type="rerank",
+ )
+ elif (
+ call_type == CallTypes.atranscription.value
+ or call_type == CallTypes.transcription.value
+ ) and isinstance(cached_result, dict):
+ hidden_params = {
+ "model": "whisper-1",
+ "custom_llm_provider": custom_llm_provider,
+ "cache_hit": True,
+ }
+ cached_result = convert_to_model_response_object(
+ response_object=cached_result,
+ model_response_object=TranscriptionResponse(),
+ response_type="audio_transcription",
+ hidden_params=hidden_params,
+ )
+
+ if (
+ hasattr(cached_result, "_hidden_params")
+ and cached_result._hidden_params is not None
+ and isinstance(cached_result._hidden_params, dict)
+ ):
+ cached_result._hidden_params["cache_hit"] = True
+ return cached_result
+
+ def _convert_cached_stream_response(
+ self,
+ cached_result: Any,
+ call_type: str,
+ logging_obj: LiteLLMLoggingObj,
+ model: str,
+ ) -> CustomStreamWrapper:
+ from litellm.utils import (
+ CustomStreamWrapper,
+ convert_to_streaming_response,
+ convert_to_streaming_response_async,
+ )
+
+ _stream_cached_result: Union[AsyncGenerator, Generator]
+ if (
+ call_type == CallTypes.acompletion.value
+ or call_type == CallTypes.atext_completion.value
+ ):
+ _stream_cached_result = convert_to_streaming_response_async(
+ response_object=cached_result,
+ )
+ else:
+ _stream_cached_result = convert_to_streaming_response(
+ response_object=cached_result,
+ )
+ return CustomStreamWrapper(
+ completion_stream=_stream_cached_result,
+ model=model,
+ custom_llm_provider="cached_response",
+ logging_obj=logging_obj,
+ )
+
+ async def async_set_cache(
+ self,
+ result: Any,
+ original_function: Callable,
+ kwargs: Dict[str, Any],
+ args: Optional[Tuple[Any, ...]] = None,
+ ):
+ """
+ Internal method to check the type of the result & cache used and adds the result to the cache accordingly
+
+ Args:
+ result: Any:
+ original_function: Callable:
+ kwargs: Dict[str, Any]:
+ args: Optional[Tuple[Any, ...]] = None:
+
+ Returns:
+ None
+ Raises:
+ None
+ """
+ if litellm.cache is None:
+ return
+
+ new_kwargs = kwargs.copy()
+ new_kwargs.update(
+ convert_args_to_kwargs(
+ original_function,
+ args,
+ )
+ )
+ # [OPTIONAL] ADD TO CACHE
+ if self._should_store_result_in_cache(
+ original_function=original_function, kwargs=new_kwargs
+ ):
+ if (
+ isinstance(result, litellm.ModelResponse)
+ or isinstance(result, litellm.EmbeddingResponse)
+ or isinstance(result, TranscriptionResponse)
+ or isinstance(result, RerankResponse)
+ ):
+ if (
+ isinstance(result, EmbeddingResponse)
+ and isinstance(new_kwargs["input"], list)
+ and litellm.cache is not None
+ and not isinstance(
+ litellm.cache.cache, S3Cache
+ ) # s3 doesn't support bulk writing. Exclude.
+ ):
+ asyncio.create_task(
+ litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
+ )
+ elif isinstance(litellm.cache.cache, S3Cache):
+ threading.Thread(
+ target=litellm.cache.add_cache,
+ args=(result,),
+ kwargs=new_kwargs,
+ ).start()
+ else:
+ asyncio.create_task(
+ litellm.cache.async_add_cache(
+ result.model_dump_json(), **new_kwargs
+ )
+ )
+ else:
+ asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))
+
+ def sync_set_cache(
+ self,
+ result: Any,
+ kwargs: Dict[str, Any],
+ args: Optional[Tuple[Any, ...]] = None,
+ ):
+ """
+ Sync internal method to add the result to the cache
+ """
+
+ new_kwargs = kwargs.copy()
+ new_kwargs.update(
+ convert_args_to_kwargs(
+ self.original_function,
+ args,
+ )
+ )
+ if litellm.cache is None:
+ return
+
+ if self._should_store_result_in_cache(
+ original_function=self.original_function, kwargs=new_kwargs
+ ):
+
+ litellm.cache.add_cache(result, **new_kwargs)
+
+ return
+
+ def _should_store_result_in_cache(
+ self, original_function: Callable, kwargs: Dict[str, Any]
+ ) -> bool:
+ """
+ Helper function to determine if the result should be stored in the cache.
+
+ Returns:
+ bool: True if the result should be stored in the cache, False otherwise.
+ """
+ return (
+ (litellm.cache is not None)
+ and litellm.cache.supported_call_types is not None
+ and (str(original_function.__name__) in litellm.cache.supported_call_types)
+ and (kwargs.get("cache", {}).get("no-store", False) is not True)
+ )
+
+ def _is_call_type_supported_by_cache(
+ self,
+ original_function: Callable,
+ ) -> bool:
+ """
+ Helper function to determine if the call type is supported by the cache.
+
+ call types are acompletion, aembedding, atext_completion, atranscription, arerank
+
+ Defined on `litellm.types.utils.CallTypes`
+
+ Returns:
+ bool: True if the call type is supported by the cache, False otherwise.
+ """
+ if (
+ litellm.cache is not None
+ and litellm.cache.supported_call_types is not None
+ and str(original_function.__name__) in litellm.cache.supported_call_types
+ ):
+ return True
+ return False
+
+ async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
+ """
+ Internal method to add the streaming response to the cache
+
+
+ - If 'streaming_chunk' has a 'finish_reason' then assemble a litellm.ModelResponse object
+ - Else append the chunk to self.async_streaming_chunks
+
+ """
+
+ complete_streaming_response: Optional[
+ Union[ModelResponse, TextCompletionResponse]
+ ] = _assemble_complete_response_from_streaming_chunks(
+ result=processed_chunk,
+ start_time=self.start_time,
+ end_time=datetime.datetime.now(),
+ request_kwargs=self.request_kwargs,
+ streaming_chunks=self.async_streaming_chunks,
+ is_async=True,
+ )
+ # if a complete_streaming_response is assembled, add it to the cache
+ if complete_streaming_response is not None:
+ await self.async_set_cache(
+ result=complete_streaming_response,
+ original_function=self.original_function,
+ kwargs=self.request_kwargs,
+ )
+
+ def _sync_add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
+ """
+ Sync internal method to add the streaming response to the cache
+ """
+ complete_streaming_response: Optional[
+ Union[ModelResponse, TextCompletionResponse]
+ ] = _assemble_complete_response_from_streaming_chunks(
+ result=processed_chunk,
+ start_time=self.start_time,
+ end_time=datetime.datetime.now(),
+ request_kwargs=self.request_kwargs,
+ streaming_chunks=self.sync_streaming_chunks,
+ is_async=False,
+ )
+
+ # if a complete_streaming_response is assembled, add it to the cache
+ if complete_streaming_response is not None:
+ self.sync_set_cache(
+ result=complete_streaming_response,
+ kwargs=self.request_kwargs,
+ )
+
+ def _update_litellm_logging_obj_environment(
+ self,
+ logging_obj: LiteLLMLoggingObj,
+ model: str,
+ kwargs: Dict[str, Any],
+ cached_result: Any,
+ is_async: bool,
+ is_embedding: bool = False,
+ ):
+ """
+ Helper function to update the LiteLLMLoggingObj environment variables.
+
+ Args:
+ logging_obj (LiteLLMLoggingObj): The logging object to update.
+ model (str): The model being used.
+ kwargs (Dict[str, Any]): The keyword arguments from the original function call.
+ cached_result (Any): The cached result to log.
+ is_async (bool): Whether the call is asynchronous or not.
+ is_embedding (bool): Whether the call is for embeddings or not.
+
+ Returns:
+ None
+ """
+ litellm_params = {
+ "logger_fn": kwargs.get("logger_fn", None),
+ "acompletion": is_async,
+ "api_base": kwargs.get("api_base", ""),
+ "metadata": kwargs.get("metadata", {}),
+ "model_info": kwargs.get("model_info", {}),
+ "proxy_server_request": kwargs.get("proxy_server_request", None),
+ "stream_response": kwargs.get("stream_response", {}),
+ }
+
+ if litellm.cache is not None:
+ litellm_params["preset_cache_key"] = (
+ litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
+ )
+ else:
+ litellm_params["preset_cache_key"] = None
+
+ logging_obj.update_environment_variables(
+ model=model,
+ user=kwargs.get("user", None),
+ optional_params={},
+ litellm_params=litellm_params,
+ input=(
+ kwargs.get("messages", "")
+ if not is_embedding
+ else kwargs.get("input", "")
+ ),
+ api_key=kwargs.get("api_key", None),
+ original_response=str(cached_result),
+ additional_args=None,
+ stream=kwargs.get("stream", False),
+ )
+
+
+def convert_args_to_kwargs(
+ original_function: Callable,
+ args: Optional[Tuple[Any, ...]] = None,
+) -> Dict[str, Any]:
+ # Get the signature of the original function
+ signature = inspect.signature(original_function)
+
+ # Get parameter names in the order they appear in the original function
+ param_names = list(signature.parameters.keys())
+
+ # Create a mapping of positional arguments to parameter names
+ args_to_kwargs = {}
+ if args:
+ for index, arg in enumerate(args):
+ if index < len(param_names):
+ param_name = param_names[index]
+ args_to_kwargs[param_name] = arg
+
+ return args_to_kwargs
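For reference, a standalone sketch of how `convert_args_to_kwargs` maps positional arguments onto parameter names; the `sample_fn` stand-in is illustrative.

```python
# convert_args_to_kwargs gives cache-key construction a uniform kwargs dict,
# regardless of whether the caller used positional or keyword arguments.
from litellm.caching.caching_handler import convert_args_to_kwargs

def sample_fn(model, messages, stream=False):  # illustrative stand-in
    ...

print(convert_args_to_kwargs(sample_fn, ("gpt-4", [{"role": "user", "content": "hi"}])))
# {'model': 'gpt-4', 'messages': [{'role': 'user', 'content': 'hi'}]}
```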
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py
new file mode 100644
index 00000000..abf3203f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py
@@ -0,0 +1,88 @@
+import json
+from typing import TYPE_CHECKING, Any, Optional
+
+from .base_cache import BaseCache
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ Span = _Span
+else:
+ Span = Any
+
+
+class DiskCache(BaseCache):
+ def __init__(self, disk_cache_dir: Optional[str] = None):
+ import diskcache as dc
+
+ # if users don't provide one, use the default litellm cache directory
+ if disk_cache_dir is None:
+ self.disk_cache = dc.Cache(".litellm_cache")
+ else:
+ self.disk_cache = dc.Cache(disk_cache_dir)
+
+ def set_cache(self, key, value, **kwargs):
+ if "ttl" in kwargs:
+ self.disk_cache.set(key, value, expire=kwargs["ttl"])
+ else:
+ self.disk_cache.set(key, value)
+
+ async def async_set_cache(self, key, value, **kwargs):
+ self.set_cache(key=key, value=value, **kwargs)
+
+ async def async_set_cache_pipeline(self, cache_list, **kwargs):
+ for cache_key, cache_value in cache_list:
+ if "ttl" in kwargs:
+ self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
+ else:
+ self.set_cache(key=cache_key, value=cache_value)
+
+ def get_cache(self, key, **kwargs):
+ original_cached_response = self.disk_cache.get(key)
+ if original_cached_response:
+ try:
+ cached_response = json.loads(original_cached_response) # type: ignore
+ except Exception:
+ cached_response = original_cached_response
+ return cached_response
+ return None
+
+ def batch_get_cache(self, keys: list, **kwargs):
+ return_val = []
+ for k in keys:
+ val = self.get_cache(key=k, **kwargs)
+ return_val.append(val)
+ return return_val
+
+ def increment_cache(self, key, value: int, **kwargs) -> int:
+ # get the value
+ init_value = self.get_cache(key=key) or 0
+ value = init_value + value # type: ignore
+ self.set_cache(key, value, **kwargs)
+ return value
+
+ async def async_get_cache(self, key, **kwargs):
+ return self.get_cache(key=key, **kwargs)
+
+ async def async_batch_get_cache(self, keys: list, **kwargs):
+ return_val = []
+ for k in keys:
+ val = self.get_cache(key=k, **kwargs)
+ return_val.append(val)
+ return return_val
+
+ async def async_increment(self, key, value: int, **kwargs) -> int:
+ # get the value
+ init_value = await self.async_get_cache(key=key) or 0
+ value = init_value + value # type: ignore
+ await self.async_set_cache(key, value, **kwargs)
+ return value
+
+ def flush_cache(self):
+ self.disk_cache.clear()
+
+ async def disconnect(self):
+ pass
+
+ def delete_cache(self, key):
+ self.disk_cache.pop(key)
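A minimal usage sketch for `DiskCache`, assuming the `diskcache` package is installed and the given directory is writable; the path, keys, and values are illustrative.

```python
from litellm.caching.disk_cache import DiskCache

cache = DiskCache(disk_cache_dir="/tmp/litellm_disk_cache")
cache.set_cache("greeting", "hello", ttl=60)   # expires after 60 seconds
print(cache.get_cache("greeting"))             # "hello"
print(cache.increment_cache("counter", 5))     # 5 (a missing key counts as 0)
cache.delete_cache("greeting")
```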
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/dual_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/dual_cache.py
new file mode 100644
index 00000000..5f598f7d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/dual_cache.py
@@ -0,0 +1,434 @@
+"""
+Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously.
+
+Has 4 primary methods:
+ - set_cache
+ - get_cache
+ - async_set_cache
+ - async_get_cache
+"""
+
+import asyncio
+import time
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+from typing import TYPE_CHECKING, Any, List, Optional
+
+import litellm
+from litellm._logging import print_verbose, verbose_logger
+
+from .base_cache import BaseCache
+from .in_memory_cache import InMemoryCache
+from .redis_cache import RedisCache
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ Span = _Span
+else:
+ Span = Any
+
+from collections import OrderedDict
+
+
+class LimitedSizeOrderedDict(OrderedDict):
+ def __init__(self, *args, max_size=100, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.max_size = max_size
+
+ def __setitem__(self, key, value):
+ # If inserting a new key exceeds max size, remove the oldest item
+ if len(self) >= self.max_size:
+ self.popitem(last=False)
+ super().__setitem__(key, value)
+
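A small sketch of the FIFO bound this class provides; it is used below to cap the map of Redis batch-access timestamps.

```python
# Once max_size entries exist, inserting another key evicts the oldest
# insertion (insertion order, not access order).
from litellm.caching.dual_cache import LimitedSizeOrderedDict

d = LimitedSizeOrderedDict(max_size=2)
d["a"] = 1
d["b"] = 2
d["c"] = 3          # "a" (the oldest insertion) is evicted
print(list(d))      # ['b', 'c']
```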
+
+class DualCache(BaseCache):
+ """
+ DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
+ When data is updated or inserted, it is written to both the in-memory cache + Redis.
+ This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
+ """
+
+ def __init__(
+ self,
+ in_memory_cache: Optional[InMemoryCache] = None,
+ redis_cache: Optional[RedisCache] = None,
+ default_in_memory_ttl: Optional[float] = None,
+ default_redis_ttl: Optional[float] = None,
+ default_redis_batch_cache_expiry: Optional[float] = None,
+ default_max_redis_batch_cache_size: int = 100,
+ ) -> None:
+ super().__init__()
+ # If in_memory_cache is not provided, use the default InMemoryCache
+ self.in_memory_cache = in_memory_cache or InMemoryCache()
+ # If redis_cache is not provided, use the default RedisCache
+ self.redis_cache = redis_cache
+ self.last_redis_batch_access_time = LimitedSizeOrderedDict(
+ max_size=default_max_redis_batch_cache_size
+ )
+ self.redis_batch_cache_expiry = (
+ default_redis_batch_cache_expiry
+ or litellm.default_redis_batch_cache_expiry
+ or 10
+ )
+ self.default_in_memory_ttl = (
+ default_in_memory_ttl or litellm.default_in_memory_ttl
+ )
+ self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl
+
+ def update_cache_ttl(
+ self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float]
+ ):
+ if default_in_memory_ttl is not None:
+ self.default_in_memory_ttl = default_in_memory_ttl
+
+ if default_redis_ttl is not None:
+ self.default_redis_ttl = default_redis_ttl
+
+ def set_cache(self, key, value, local_only: bool = False, **kwargs):
+ # Update both Redis and in-memory cache
+ try:
+ if self.in_memory_cache is not None:
+ if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
+ kwargs["ttl"] = self.default_in_memory_ttl
+
+ self.in_memory_cache.set_cache(key, value, **kwargs)
+
+ if self.redis_cache is not None and local_only is False:
+ self.redis_cache.set_cache(key, value, **kwargs)
+ except Exception as e:
+ print_verbose(e)
+
+ def increment_cache(
+ self, key, value: int, local_only: bool = False, **kwargs
+ ) -> int:
+ """
+ Key - the key in cache
+
+ Value - int - the value you want to increment by
+
+ Returns - int - the incremented value
+ """
+ try:
+ result: int = value
+ if self.in_memory_cache is not None:
+ result = self.in_memory_cache.increment_cache(key, value, **kwargs)
+
+ if self.redis_cache is not None and local_only is False:
+ result = self.redis_cache.increment_cache(key, value, **kwargs)
+
+ return result
+ except Exception as e:
+ verbose_logger.error(f"LiteLLM Cache: Exception in increment_cache: {str(e)}")
+ raise e
+
+ def get_cache(
+ self,
+ key,
+ parent_otel_span: Optional[Span] = None,
+ local_only: bool = False,
+ **kwargs,
+ ):
+ # Try to fetch from in-memory cache first
+ try:
+ result = None
+ if self.in_memory_cache is not None:
+ in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
+
+ if in_memory_result is not None:
+ result = in_memory_result
+
+ if result is None and self.redis_cache is not None and local_only is False:
+ # If not found in in-memory cache, try fetching from Redis
+ redis_result = self.redis_cache.get_cache(
+ key, parent_otel_span=parent_otel_span
+ )
+
+ if redis_result is not None:
+ # Update in-memory cache with the value from Redis
+ self.in_memory_cache.set_cache(key, redis_result, **kwargs)
+
+ result = redis_result
+
+ print_verbose(f"get cache: cache result: {result}")
+ return result
+ except Exception:
+ verbose_logger.error(traceback.format_exc())
+
+ def batch_get_cache(
+ self,
+ keys: list,
+ parent_otel_span: Optional[Span] = None,
+ local_only: bool = False,
+ **kwargs,
+ ):
+ received_args = locals()
+ received_args.pop("self")
+
+ def run_in_new_loop():
+ """Run the coroutine in a new event loop within this thread."""
+ new_loop = asyncio.new_event_loop()
+ try:
+ asyncio.set_event_loop(new_loop)
+ return new_loop.run_until_complete(
+ self.async_batch_get_cache(**received_args)
+ )
+ finally:
+ new_loop.close()
+ asyncio.set_event_loop(None)
+
+ try:
+ # First, try to get the current event loop
+ _ = asyncio.get_running_loop()
+ # If we're already in an event loop, run in a separate thread
+ # to avoid nested event loop issues
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(run_in_new_loop)
+ return future.result()
+
+ except RuntimeError:
+ # No running event loop, we can safely run in this thread
+ return run_in_new_loop()
+
+ async def async_get_cache(
+ self,
+ key,
+ parent_otel_span: Optional[Span] = None,
+ local_only: bool = False,
+ **kwargs,
+ ):
+ # Try to fetch from in-memory cache first
+ try:
+ print_verbose(
+ f"async get cache: cache key: {key}; local_only: {local_only}"
+ )
+ result = None
+ if self.in_memory_cache is not None:
+ in_memory_result = await self.in_memory_cache.async_get_cache(
+ key, **kwargs
+ )
+
+ print_verbose(f"in_memory_result: {in_memory_result}")
+ if in_memory_result is not None:
+ result = in_memory_result
+
+ if result is None and self.redis_cache is not None and local_only is False:
+ # If not found in in-memory cache, try fetching from Redis
+ redis_result = await self.redis_cache.async_get_cache(
+ key, parent_otel_span=parent_otel_span
+ )
+
+ if redis_result is not None:
+ # Update in-memory cache with the value from Redis
+ await self.in_memory_cache.async_set_cache(
+ key, redis_result, **kwargs
+ )
+
+ result = redis_result
+
+ print_verbose(f"get cache: cache result: {result}")
+ return result
+ except Exception:
+ verbose_logger.error(traceback.format_exc())
+
+ def get_redis_batch_keys(
+ self,
+ current_time: float,
+ keys: List[str],
+ result: List[Any],
+ ) -> List[str]:
+ sublist_keys = []
+ for key, value in zip(keys, result):
+ if value is None:
+ if (
+ key not in self.last_redis_batch_access_time
+ or current_time - self.last_redis_batch_access_time[key]
+ >= self.redis_batch_cache_expiry
+ ):
+ sublist_keys.append(key)
+ return sublist_keys
+
+ async def async_batch_get_cache(
+ self,
+ keys: list,
+ parent_otel_span: Optional[Span] = None,
+ local_only: bool = False,
+ **kwargs,
+ ):
+ try:
+ result = [None for _ in range(len(keys))]
+ if self.in_memory_cache is not None:
+ in_memory_result = await self.in_memory_cache.async_batch_get_cache(
+ keys, **kwargs
+ )
+
+ if in_memory_result is not None:
+ result = in_memory_result
+
+ if None in result and self.redis_cache is not None and local_only is False:
+ """
+ - for the none values in the result
+ - check the redis cache
+ """
+ current_time = time.time()
+ sublist_keys = self.get_redis_batch_keys(current_time, keys, result)
+
+ # Only hit Redis if the last access time was more than redis_batch_cache_expiry seconds ago
+ if len(sublist_keys) > 0:
+ # If not found in in-memory cache, try fetching from Redis
+ redis_result = await self.redis_cache.async_batch_get_cache(
+ sublist_keys, parent_otel_span=parent_otel_span
+ )
+
+ if redis_result is not None:
+ # Update in-memory cache with the value from Redis
+ for key, value in redis_result.items():
+ if value is not None:
+ await self.in_memory_cache.async_set_cache(
+ key, redis_result[key], **kwargs
+ )
+ # Update the last access time for each key fetched from Redis
+ self.last_redis_batch_access_time[key] = current_time
+
+ for key, value in redis_result.items():
+ index = keys.index(key)
+ result[index] = value
+
+ return result
+ except Exception:
+ verbose_logger.error(traceback.format_exc())
+
+ async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
+ print_verbose(
+ f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
+ )
+ try:
+ if self.in_memory_cache is not None:
+ await self.in_memory_cache.async_set_cache(key, value, **kwargs)
+
+ if self.redis_cache is not None and local_only is False:
+ await self.redis_cache.async_set_cache(key, value, **kwargs)
+ except Exception as e:
+ verbose_logger.exception(
+ f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
+ )
+
+ # async_batch_set_cache
+ async def async_set_cache_pipeline(
+ self, cache_list: list, local_only: bool = False, **kwargs
+ ):
+ """
+ Batch write values to the cache
+ """
+ print_verbose(
+ f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
+ )
+ try:
+ if self.in_memory_cache is not None:
+ await self.in_memory_cache.async_set_cache_pipeline(
+ cache_list=cache_list, **kwargs
+ )
+
+ if self.redis_cache is not None and local_only is False:
+ await self.redis_cache.async_set_cache_pipeline(
+ cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
+ )
+ except Exception as e:
+ verbose_logger.exception(
+ f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
+ )
+
+ async def async_increment_cache(
+ self,
+ key,
+ value: float,
+ parent_otel_span: Optional[Span] = None,
+ local_only: bool = False,
+ **kwargs,
+ ) -> float:
+ """
+ Key - the key in cache
+
+ Value - float - the value you want to increment by
+
+ Returns - float - the incremented value
+ """
+ try:
+ result: float = value
+ if self.in_memory_cache is not None:
+ result = await self.in_memory_cache.async_increment(
+ key, value, **kwargs
+ )
+
+ if self.redis_cache is not None and local_only is False:
+ result = await self.redis_cache.async_increment(
+ key,
+ value,
+ parent_otel_span=parent_otel_span,
+ ttl=kwargs.get("ttl", None),
+ )
+
+ return result
+ except Exception as e:
+ raise e # don't log if exception is raised
+
+ async def async_set_cache_sadd(
+ self, key, value: List, local_only: bool = False, **kwargs
+ ) -> None:
+ """
+ Add value to a set
+
+ Key - the key in cache
+
+ Value - str - the value you want to add to the set
+
+ Returns - None
+ """
+ try:
+ if self.in_memory_cache is not None:
+ _ = await self.in_memory_cache.async_set_cache_sadd(
+ key, value, ttl=kwargs.get("ttl", None)
+ )
+
+ if self.redis_cache is not None and local_only is False:
+ _ = await self.redis_cache.async_set_cache_sadd(
+ key, value, ttl=kwargs.get("ttl", None)
+ )
+
+ return None
+ except Exception as e:
+ raise e # don't log, if exception is raised
+
+ def flush_cache(self):
+ if self.in_memory_cache is not None:
+ self.in_memory_cache.flush_cache()
+ if self.redis_cache is not None:
+ self.redis_cache.flush_cache()
+
+ def delete_cache(self, key):
+ """
+ Delete a key from the cache
+ """
+ if self.in_memory_cache is not None:
+ self.in_memory_cache.delete_cache(key)
+ if self.redis_cache is not None:
+ self.redis_cache.delete_cache(key)
+
+ async def async_delete_cache(self, key: str):
+ """
+ Delete a key from the cache
+ """
+ if self.in_memory_cache is not None:
+ self.in_memory_cache.delete_cache(key)
+ if self.redis_cache is not None:
+ await self.redis_cache.async_delete_cache(key)
+
+ async def async_get_ttl(self, key: str) -> Optional[int]:
+ """
+ Get the remaining TTL of a key in in-memory cache or redis
+ """
+ ttl = await self.in_memory_cache.async_get_ttl(key)
+ if ttl is None and self.redis_cache is not None:
+ ttl = await self.redis_cache.async_get_ttl(key)
+ return ttl
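A minimal, local-only usage sketch for `DualCache`: with no `RedisCache` supplied, reads and writes touch only the in-memory layer. Keys and values are illustrative.

```python
import asyncio
from litellm.caching.dual_cache import DualCache
from litellm.caching.in_memory_cache import InMemoryCache

dual = DualCache(in_memory_cache=InMemoryCache(), redis_cache=None)

dual.set_cache("router:deployment-1", {"rpm": 10}, ttl=30)
print(dual.get_cache("router:deployment-1"))              # {'rpm': 10}

async def main():
    await dual.async_set_cache("counter", 0)
    print(await dual.async_increment_cache("counter", 1.0))  # 1.0

asyncio.run(main())
```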
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/in_memory_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/in_memory_cache.py
new file mode 100644
index 00000000..5e09fe84
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/in_memory_cache.py
@@ -0,0 +1,202 @@
+"""
+In-Memory Cache implementation
+
+Has 4 methods:
+ - set_cache
+ - get_cache
+ - async_set_cache
+ - async_get_cache
+"""
+
+import json
+import sys
+import time
+from typing import Any, List, Optional
+
+from pydantic import BaseModel
+
+from ..constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
+from .base_cache import BaseCache
+
+
+class InMemoryCache(BaseCache):
+ def __init__(
+ self,
+ max_size_in_memory: Optional[int] = 200,
+ default_ttl: Optional[
+ int
+ ] = 600, # default ttl is 10 minutes; at most, litellm's rate limiting logic needs objects to stay in memory for 1 minute
+ max_size_per_item: Optional[int] = 1024, # 1MB = 1024KB
+ ):
+ """
+ max_size_in_memory [int]: Maximum number of items held in the cache; bounded to prevent memory leaks. Defaults to 200 items.
+ """
+ self.max_size_in_memory = (
+ max_size_in_memory or 200
+ ) # set an upper bound of 200 items in-memory
+ self.default_ttl = default_ttl or 600
+ self.max_size_per_item = (
+ max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
+ ) # 1MB = 1024KB
+
+ # in-memory cache
+ self.cache_dict: dict = {}
+ self.ttl_dict: dict = {}
+
+ def check_value_size(self, value: Any):
+ """
+ Check if value size exceeds max_size_per_item (1MB)
+ Returns True if value size is acceptable, False otherwise
+ """
+ try:
+ # Fast path for common primitive types that are typically small
+ if (
+ isinstance(value, (bool, int, float, str))
+ and len(str(value)) < self.max_size_per_item * 512
+ ): # Conservative estimate
+ return True
+
+ # Direct size check for bytes objects
+ if isinstance(value, bytes):
+ return sys.getsizeof(value) / 1024 <= self.max_size_per_item
+
+ # Handle special types without full conversion when possible
+ if hasattr(value, "__sizeof__"): # Use __sizeof__ if available
+ size = value.__sizeof__() / 1024
+ return size <= self.max_size_per_item
+
+ # Fallback for complex types
+ if isinstance(value, BaseModel) and hasattr(
+ value, "model_dump"
+ ): # Pydantic v2
+ value = value.model_dump()
+ elif hasattr(value, "isoformat"): # datetime objects
+ return True # datetime strings are always small
+
+ # Only convert to JSON if absolutely necessary
+ if not isinstance(value, (str, bytes)):
+ value = json.dumps(value, default=str)
+
+ return sys.getsizeof(value) / 1024 <= self.max_size_per_item
+
+ except Exception:
+ return False
+
+ def evict_cache(self):
+ """
+ Eviction policy:
+ - check if any items in ttl_dict are expired -> remove them from ttl_dict and cache_dict
+
+ This guarantees the following:
+ - 1. When an item's ttl is not set: the item remains in memory for at least the default TTL (10 minutes)
+ - 2. When a ttl is set: the item remains in memory for at least that amount of time
+ - 3. the size of the in-memory cache stays bounded
+ """
+ for key in list(self.ttl_dict.keys()):
+ if time.time() > self.ttl_dict[key]:
+ self.cache_dict.pop(key, None)
+ self.ttl_dict.pop(key, None)
+
+ # de-reference the removed item
+ # https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
+ # One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
+ # This can occur when an object is referenced by another object, but the reference is never removed.
+
+ def set_cache(self, key, value, **kwargs):
+ if len(self.cache_dict) >= self.max_size_in_memory:
+ # only evict when cache is full
+ self.evict_cache()
+ if not self.check_value_size(value):
+ return
+
+ self.cache_dict[key] = value
+ if "ttl" in kwargs and kwargs["ttl"] is not None:
+ self.ttl_dict[key] = time.time() + kwargs["ttl"]
+ else:
+ self.ttl_dict[key] = time.time() + self.default_ttl
+
+ async def async_set_cache(self, key, value, **kwargs):
+ self.set_cache(key=key, value=value, **kwargs)
+
+ async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
+ for cache_key, cache_value in cache_list:
+ if ttl is not None:
+ self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
+ else:
+ self.set_cache(key=cache_key, value=cache_value)
+
+ async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
+ """
+ Add value to set
+ """
+ # get the value
+ init_value = self.get_cache(key=key) or set()
+ for val in value:
+ init_value.add(val)
+ self.set_cache(key, init_value, ttl=ttl)
+ return value
+
+ def get_cache(self, key, **kwargs):
+ if key in self.cache_dict:
+ if key in self.ttl_dict:
+ if time.time() > self.ttl_dict[key]:
+ self.cache_dict.pop(key, None)
+ return None
+ original_cached_response = self.cache_dict[key]
+ try:
+ cached_response = json.loads(original_cached_response)
+ except Exception:
+ cached_response = original_cached_response
+ return cached_response
+ return None
+
+ def batch_get_cache(self, keys: list, **kwargs):
+ return_val = []
+ for k in keys:
+ val = self.get_cache(key=k, **kwargs)
+ return_val.append(val)
+ return return_val
+
+ def increment_cache(self, key, value: int, **kwargs) -> int:
+ # get the value
+ init_value = self.get_cache(key=key) or 0
+ value = init_value + value
+ self.set_cache(key, value, **kwargs)
+ return value
+
+ async def async_get_cache(self, key, **kwargs):
+ return self.get_cache(key=key, **kwargs)
+
+ async def async_batch_get_cache(self, keys: list, **kwargs):
+ return_val = []
+ for k in keys:
+ val = self.get_cache(key=k, **kwargs)
+ return_val.append(val)
+ return return_val
+
+ async def async_increment(self, key, value: float, **kwargs) -> float:
+ # get the value
+ init_value = await self.async_get_cache(key=key) or 0
+ value = init_value + value
+ await self.async_set_cache(key, value, **kwargs)
+
+ return value
+
+ def flush_cache(self):
+ self.cache_dict.clear()
+ self.ttl_dict.clear()
+
+ async def disconnect(self):
+ pass
+
+ def delete_cache(self, key):
+ self.cache_dict.pop(key, None)
+ self.ttl_dict.pop(key, None)
+
+ async def async_get_ttl(self, key: str) -> Optional[int]:
+ """
+ Get the TTL entry stored for a key in the in-memory cache (the absolute expiry timestamp recorded by set_cache)
+ """
+ return self.ttl_dict.get(key, None)
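A minimal sketch of the lazy TTL behaviour: an entry whose deadline has passed is evicted on the next read. Key, value, and TTLs are illustrative.

```python
import time
from litellm.caching.in_memory_cache import InMemoryCache

cache = InMemoryCache(max_size_in_memory=100, default_ttl=600)
cache.set_cache("short-lived", "value", ttl=1)
print(cache.get_cache("short-lived"))   # "value"
time.sleep(1.1)
print(cache.get_cache("short-lived"))   # None, expired and evicted on read
```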
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/llm_caching_handler.py b/.venv/lib/python3.12/site-packages/litellm/caching/llm_caching_handler.py
new file mode 100644
index 00000000..429634b7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/llm_caching_handler.py
@@ -0,0 +1,40 @@
+"""
+Add the event loop to the cache key, to prevent event loop closed errors.
+"""
+
+import asyncio
+
+from .in_memory_cache import InMemoryCache
+
+
+class LLMClientCache(InMemoryCache):
+
+ def update_cache_key_with_event_loop(self, key):
+ """
+ Add the event loop to the cache key, to prevent event loop closed errors.
+ If none, use the key as is.
+ """
+ try:
+ event_loop = asyncio.get_event_loop()
+ stringified_event_loop = str(id(event_loop))
+ return f"{key}-{stringified_event_loop}"
+ except Exception: # handle no current event loop
+ return key
+
+ def set_cache(self, key, value, **kwargs):
+ key = self.update_cache_key_with_event_loop(key)
+ return super().set_cache(key, value, **kwargs)
+
+ async def async_set_cache(self, key, value, **kwargs):
+ key = self.update_cache_key_with_event_loop(key)
+ return await super().async_set_cache(key, value, **kwargs)
+
+ def get_cache(self, key, **kwargs):
+ key = self.update_cache_key_with_event_loop(key)
+
+ return super().get_cache(key, **kwargs)
+
+ async def async_get_cache(self, key, **kwargs):
+ key = self.update_cache_key_with_event_loop(key)
+
+ return await super().async_get_cache(key, **kwargs)
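A minimal sketch of the event-loop keying, assuming it runs under `asyncio.run`: the derived key includes the id of the current loop, so a client cached on one loop is not handed back to code running on another.

```python
import asyncio
from litellm.caching.llm_caching_handler import LLMClientCache

cache = LLMClientCache()

async def main():
    await cache.async_set_cache("openai-client", object())
    # Same loop -> same derived key -> cache hit.
    print(await cache.async_get_cache("openai-client") is not None)  # True

asyncio.run(main())
```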
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/qdrant_semantic_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/qdrant_semantic_cache.py
new file mode 100644
index 00000000..bdfd3770
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/qdrant_semantic_cache.py
@@ -0,0 +1,430 @@
+"""
+Qdrant Semantic Cache implementation
+
+Has 4 methods:
+ - set_cache
+ - get_cache
+ - async_set_cache
+ - async_get_cache
+"""
+
+import ast
+import asyncio
+import json
+from typing import Any
+
+import litellm
+from litellm._logging import print_verbose
+
+from .base_cache import BaseCache
+
+
+class QdrantSemanticCache(BaseCache):
+ def __init__( # noqa: PLR0915
+ self,
+ qdrant_api_base=None,
+ qdrant_api_key=None,
+ collection_name=None,
+ similarity_threshold=None,
+ quantization_config=None,
+ embedding_model="text-embedding-ada-002",
+ host_type=None,
+ ):
+ import os
+
+ from litellm.llms.custom_httpx.http_handler import (
+ _get_httpx_client,
+ get_async_httpx_client,
+ httpxSpecialProvider,
+ )
+ from litellm.secret_managers.main import get_secret_str
+
+ if collection_name is None:
+ raise Exception("collection_name must be provided, passed None")
+
+ self.collection_name = collection_name
+ print_verbose(
+ f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
+ )
+
+ if similarity_threshold is None:
+ raise Exception("similarity_threshold must be provided, passed None")
+ self.similarity_threshold = similarity_threshold
+ self.embedding_model = embedding_model
+ headers = {}
+
+ # check if defined as os.environ/ variable
+ if qdrant_api_base:
+ if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
+ "os.environ/"
+ ):
+ qdrant_api_base = get_secret_str(qdrant_api_base)
+ if qdrant_api_key:
+ if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
+ "os.environ/"
+ ):
+ qdrant_api_key = get_secret_str(qdrant_api_key)
+
+ qdrant_api_base = (
+ qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
+ )
+ qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
+ headers = {"Content-Type": "application/json"}
+ if qdrant_api_key:
+ headers["api-key"] = qdrant_api_key
+
+ if qdrant_api_base is None:
+ raise ValueError("Qdrant url must be provided")
+
+ self.qdrant_api_base = qdrant_api_base
+ self.qdrant_api_key = qdrant_api_key
+ print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")
+
+ self.headers = headers
+
+ self.sync_client = _get_httpx_client()
+ self.async_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.Caching
+ )
+
+ if quantization_config is None:
+ print_verbose(
+ "Quantization config is not provided. Default binary quantization will be used."
+ )
+ collection_exists = self.sync_client.get(
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
+ headers=self.headers,
+ )
+ if collection_exists.status_code != 200:
+ raise ValueError(
+ f"Error from qdrant checking if /collections exist {collection_exists.text}"
+ )
+
+ if collection_exists.json()["result"]["exists"]:
+ collection_details = self.sync_client.get(
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
+ headers=self.headers,
+ )
+ self.collection_info = collection_details.json()
+ print_verbose(
+ f"Collection already exists.\nCollection details:{self.collection_info}"
+ )
+ else:
+ if quantization_config is None or quantization_config == "binary":
+ quantization_params = {
+ "binary": {
+ "always_ram": False,
+ }
+ }
+ elif quantization_config == "scalar":
+ quantization_params = {
+ "scalar": {"type": "int8", "quantile": 0.99, "always_ram": False}
+ }
+ elif quantization_config == "product":
+ quantization_params = {
+ "product": {"compression": "x16", "always_ram": False}
+ }
+ else:
+ raise Exception(
+ "Quantization config must be one of 'scalar', 'binary' or 'product'"
+ )
+
+ new_collection_status = self.sync_client.put(
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
+ json={
+ "vectors": {"size": 1536, "distance": "Cosine"},
+ "quantization_config": quantization_params,
+ },
+ headers=self.headers,
+ )
+ if new_collection_status.json()["result"]:
+ collection_details = self.sync_client.get(
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
+ headers=self.headers,
+ )
+ self.collection_info = collection_details.json()
+ print_verbose(
+ f"New collection created.\nCollection details:{self.collection_info}"
+ )
+ else:
+ raise Exception("Error while creating new collection")
+
+ def _get_cache_logic(self, cached_response: Any):
+ if cached_response is None:
+ return cached_response
+ try:
+ cached_response = json.loads(
+ cached_response
+ ) # Convert string to dictionary
+ except Exception:
+ cached_response = ast.literal_eval(cached_response)
+ return cached_response
+
+ def set_cache(self, key, value, **kwargs):
+ print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
+ import uuid
+
+ # get the prompt
+ messages = kwargs["messages"]
+ prompt = ""
+ for message in messages:
+ prompt += message["content"]
+
+ # create an embedding for prompt
+ embedding_response = litellm.embedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ )
+
+ # get the embedding
+ embedding = embedding_response["data"][0]["embedding"]
+
+ value = str(value)
+ assert isinstance(value, str)
+
+ data = {
+ "points": [
+ {
+ "id": str(uuid.uuid4()),
+ "vector": embedding,
+ "payload": {
+ "text": prompt,
+ "response": value,
+ },
+ },
+ ]
+ }
+ self.sync_client.put(
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
+ headers=self.headers,
+ json=data,
+ )
+ return
+
+ def get_cache(self, key, **kwargs):
+ print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")
+
+ # get the messages
+ messages = kwargs["messages"]
+ prompt = ""
+ for message in messages:
+ prompt += message["content"]
+
+ # convert to embedding
+ embedding_response = litellm.embedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ )
+
+ # get the embedding
+ embedding = embedding_response["data"][0]["embedding"]
+
+ data = {
+ "vector": embedding,
+ "params": {
+ "quantization": {
+ "ignore": False,
+ "rescore": True,
+ "oversampling": 3.0,
+ }
+ },
+ "limit": 1,
+ "with_payload": True,
+ }
+
+ search_response = self.sync_client.post(
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
+ headers=self.headers,
+ json=data,
+ )
+ results = search_response.json()["result"]
+
+ if results is None:
+ return None
+ if isinstance(results, list):
+ if len(results) == 0:
+ return None
+
+ similarity = results[0]["score"]
+ cached_prompt = results[0]["payload"]["text"]
+
+ # check similarity, if more than self.similarity_threshold, return results
+ print_verbose(
+ f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+ )
+ if similarity >= self.similarity_threshold:
+ # cache hit !
+ cached_value = results[0]["payload"]["response"]
+ print_verbose(
+ f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+ )
+ return self._get_cache_logic(cached_response=cached_value)
+ else:
+ # cache miss !
+ return None
+ pass
+
+ async def async_set_cache(self, key, value, **kwargs):
+ import uuid
+
+ from litellm.proxy.proxy_server import llm_model_list, llm_router
+
+ print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
+
+ # get the prompt
+ messages = kwargs["messages"]
+ prompt = ""
+ for message in messages:
+ prompt += message["content"]
+ # create an embedding for prompt
+ router_model_names = (
+ [m["model_name"] for m in llm_model_list]
+ if llm_model_list is not None
+ else []
+ )
+ if llm_router is not None and self.embedding_model in router_model_names:
+ user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
+ embedding_response = await llm_router.aembedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ metadata={
+ "user_api_key": user_api_key,
+ "semantic-cache-embedding": True,
+ "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
+ },
+ )
+ else:
+ # convert to embedding
+ embedding_response = await litellm.aembedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ )
+
+ # get the embedding
+ embedding = embedding_response["data"][0]["embedding"]
+
+ value = str(value)
+ assert isinstance(value, str)
+
+ data = {
+ "points": [
+ {
+ "id": str(uuid.uuid4()),
+ "vector": embedding,
+ "payload": {
+ "text": prompt,
+ "response": value,
+ },
+ },
+ ]
+ }
+
+ await self.async_client.put(
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
+ headers=self.headers,
+ json=data,
+ )
+ return
+
+ async def async_get_cache(self, key, **kwargs):
+ print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
+ from litellm.proxy.proxy_server import llm_model_list, llm_router
+
+ # get the messages
+ messages = kwargs["messages"]
+ prompt = ""
+ for message in messages:
+ prompt += message["content"]
+
+ router_model_names = (
+ [m["model_name"] for m in llm_model_list]
+ if llm_model_list is not None
+ else []
+ )
+ if llm_router is not None and self.embedding_model in router_model_names:
+ user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
+ embedding_response = await llm_router.aembedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ metadata={
+ "user_api_key": user_api_key,
+ "semantic-cache-embedding": True,
+ "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
+ },
+ )
+ else:
+ # convert to embedding
+ embedding_response = await litellm.aembedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ )
+
+ # get the embedding
+ embedding = embedding_response["data"][0]["embedding"]
+
+ data = {
+ "vector": embedding,
+ "params": {
+ "quantization": {
+ "ignore": False,
+ "rescore": True,
+ "oversampling": 3.0,
+ }
+ },
+ "limit": 1,
+ "with_payload": True,
+ }
+
+ search_response = await self.async_client.post(
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
+ headers=self.headers,
+ json=data,
+ )
+
+ results = search_response.json()["result"]
+
+ if results is None:
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+ return None
+ if isinstance(results, list):
+ if len(results) == 0:
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+ return None
+
+ similarity = results[0]["score"]
+ cached_prompt = results[0]["payload"]["text"]
+
+ # check similarity, if more than self.similarity_threshold, return results
+ print_verbose(
+ f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+ )
+
+ # update kwargs["metadata"] with similarity, don't rewrite the original metadata
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
+
+ if similarity >= self.similarity_threshold:
+ # cache hit !
+ cached_value = results[0]["payload"]["response"]
+ print_verbose(
+ f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+ )
+ return self._get_cache_logic(cached_response=cached_value)
+ else:
+ # cache miss !
+ return None
+ pass
+
+ async def _collection_info(self):
+ return self.collection_info
+
+ async def async_set_cache_pipeline(self, cache_list, **kwargs):
+ tasks = []
+ for val in cache_list:
+ tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
+ await asyncio.gather(*tasks)
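A minimal construction sketch for `QdrantSemanticCache`, assuming a reachable Qdrant instance and credentials for the embedding model; the URL, collection name, and threshold are illustrative.

```python
from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

semantic_cache = QdrantSemanticCache(
    qdrant_api_base="http://localhost:6333",
    qdrant_api_key=None,                      # or "os.environ/QDRANT_API_KEY"
    collection_name="litellm-semantic-cache",
    similarity_threshold=0.8,                 # cosine score needed for a hit
    quantization_config="binary",
    embedding_model="text-embedding-ada-002",
)
# get_cache/set_cache embed kwargs["messages"] and search the collection;
# a result with score >= similarity_threshold is treated as a cache hit.
```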
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/redis_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/redis_cache.py
new file mode 100644
index 00000000..0571ac9f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/redis_cache.py
@@ -0,0 +1,1047 @@
+"""
+Redis Cache implementation
+
+Has 4 primary methods:
+ - set_cache
+ - get_cache
+ - async_set_cache
+ - async_get_cache
+"""
+
+import ast
+import asyncio
+import inspect
+import json
+import time
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+
+import litellm
+from litellm._logging import print_verbose, verbose_logger
+from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
+from litellm.types.caching import RedisPipelineIncrementOperation
+from litellm.types.services import ServiceTypes
+
+from .base_cache import BaseCache
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+ from redis.asyncio import Redis, RedisCluster
+ from redis.asyncio.client import Pipeline
+ from redis.asyncio.cluster import ClusterPipeline
+
+ pipeline = Pipeline
+ cluster_pipeline = ClusterPipeline
+ async_redis_client = Redis
+ async_redis_cluster_client = RedisCluster
+ Span = _Span
+else:
+ pipeline = Any
+ cluster_pipeline = Any
+ async_redis_client = Any
+ async_redis_cluster_client = Any
+ Span = Any
+
+
+class RedisCache(BaseCache):
+ # if users don't provide one, use the default litellm cache settings
+
+ def __init__(
+ self,
+ host=None,
+ port=None,
+ password=None,
+ redis_flush_size: Optional[int] = 100,
+ namespace: Optional[str] = None,
+ startup_nodes: Optional[List] = None, # for redis-cluster
+ socket_timeout: Optional[float] = 5.0, # default 5 second timeout
+ **kwargs,
+ ):
+
+ from litellm._service_logger import ServiceLogging
+
+ from .._redis import get_redis_client, get_redis_connection_pool
+
+ redis_kwargs = {}
+ if host is not None:
+ redis_kwargs["host"] = host
+ if port is not None:
+ redis_kwargs["port"] = port
+ if password is not None:
+ redis_kwargs["password"] = password
+ if startup_nodes is not None:
+ redis_kwargs["startup_nodes"] = startup_nodes
+ if socket_timeout is not None:
+ redis_kwargs["socket_timeout"] = socket_timeout
+
+ ### HEALTH MONITORING OBJECT ###
+ if kwargs.get("service_logger_obj", None) is not None and isinstance(
+ kwargs["service_logger_obj"], ServiceLogging
+ ):
+ self.service_logger_obj = kwargs.pop("service_logger_obj")
+ else:
+ self.service_logger_obj = ServiceLogging()
+
+ redis_kwargs.update(kwargs)
+ self.redis_client = get_redis_client(**redis_kwargs)
+ self.redis_async_client: Optional[async_redis_client] = None
+ self.redis_kwargs = redis_kwargs
+ self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
+
+ # redis namespaces
+ self.namespace = namespace
+ # for high traffic, we store the redis results in memory and then batch write to redis
+ self.redis_batch_writing_buffer: list = []
+ if redis_flush_size is None:
+ self.redis_flush_size: int = 100
+ else:
+ self.redis_flush_size = redis_flush_size
+ self.redis_version = "Unknown"
+ try:
+ if not inspect.iscoroutinefunction(self.redis_client):
+ self.redis_version = self.redis_client.info()["redis_version"] # type: ignore
+ except Exception:
+ pass
+
+ ### ASYNC HEALTH PING ###
+ try:
+ _ = asyncio.get_running_loop().create_task(self.ping())
+ except Exception as e:
+ if "no running event loop" in str(e):
+ verbose_logger.debug(
+ "Ignoring async redis ping. No running event loop."
+ )
+ else:
+ verbose_logger.error(
+ "Error connecting to Async Redis client - {}".format(str(e)),
+ extra={"error": str(e)},
+ )
+
+ ### SYNC HEALTH PING ###
+ try:
+ if hasattr(self.redis_client, "ping"):
+ self.redis_client.ping() # type: ignore
+ except Exception as e:
+ verbose_logger.error(
+ "Error connecting to Sync Redis client", extra={"error": str(e)}
+ )
+
+ if litellm.default_redis_ttl is not None:
+ super().__init__(default_ttl=int(litellm.default_redis_ttl))
+ else:
+ super().__init__() # defaults to 60s
+
+ def init_async_client(
+ self,
+ ) -> Union[async_redis_client, async_redis_cluster_client]:
+ from .._redis import get_redis_async_client
+
+ if self.redis_async_client is None:
+ self.redis_async_client = get_redis_async_client(
+ connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
+ )
+ return self.redis_async_client
+
+ def check_and_fix_namespace(self, key: str) -> str:
+ """
+ Make sure each key starts with the given namespace
+ """
+ if self.namespace is not None and not key.startswith(self.namespace):
+ key = self.namespace + ":" + key
+
+ return key
+
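A minimal sketch of namespacing and round-tripping values, assuming a local Redis on the default port; the namespace, key, and value are illustrative.

```python
from litellm.caching.redis_cache import RedisCache

cache = RedisCache(host="localhost", port=6379, namespace="litellm")
print(cache.check_and_fix_namespace("chat:abc123"))   # "litellm:chat:abc123"
cache.set_cache("chat:abc123", {"answer": 42}, ttl=120)
print(cache.get_cache("chat:abc123"))                 # {'answer': 42}
```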
+ def set_cache(self, key, value, **kwargs):
+ ttl = self.get_ttl(**kwargs)
+ print_verbose(
+ f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
+ )
+ key = self.check_and_fix_namespace(key=key)
+ try:
+ start_time = time.time()
+ self.redis_client.set(name=key, value=str(value), ex=ttl)
+ end_time = time.time()
+ _duration = end_time - start_time
+ self.service_logger_obj.service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="set_cache",
+ start_time=start_time,
+ end_time=end_time,
+ )
+ except Exception as e:
+ # NON blocking - notify users Redis is throwing an exception
+ print_verbose(
+ f"litellm.caching.caching: set() - Got exception from REDIS : {str(e)}"
+ )
+
+ def increment_cache(
+ self, key, value: int, ttl: Optional[float] = None, **kwargs
+ ) -> int:
+ _redis_client = self.redis_client
+ start_time = time.time()
+ set_ttl = self.get_ttl(ttl=ttl)
+ try:
+ start_time = time.time()
+ result: int = _redis_client.incr(name=key, amount=value) # type: ignore
+ end_time = time.time()
+ _duration = end_time - start_time
+ self.service_logger_obj.service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="increment_cache",
+ start_time=start_time,
+ end_time=end_time,
+ )
+
+ if set_ttl is not None:
+ # check if key already has ttl, if not -> set ttl
+ start_time = time.time()
+ current_ttl = _redis_client.ttl(key)
+ end_time = time.time()
+ _duration = end_time - start_time
+ self.service_logger_obj.service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="increment_cache_ttl",
+ start_time=start_time,
+ end_time=end_time,
+ )
+ if current_ttl == -1:
+ # Key has no expiration
+ start_time = time.time()
+ _redis_client.expire(key, set_ttl) # type: ignore
+ end_time = time.time()
+ _duration = end_time - start_time
+ self.service_logger_obj.service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="increment_cache_expire",
+ start_time=start_time,
+ end_time=end_time,
+ )
+ return result
+ except Exception as e:
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ verbose_logger.error(
+ "LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
+ str(e),
+ value,
+ )
+ raise e
+
+ async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
+ from redis.asyncio import Redis
+
+ start_time = time.time()
+ try:
+ keys = []
+ _redis_client: Redis = self.init_async_client() # type: ignore
+
+ async for key in _redis_client.scan_iter(match=pattern + "*", count=count):
+ keys.append(key)
+ if len(keys) >= count:
+ break
+
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="async_scan_iter",
+ start_time=start_time,
+ end_time=end_time,
+ )
+ ) # DO NOT SLOW DOWN CALL B/C OF THIS
+ return keys
+ except Exception as e:
+ # NON blocking - notify users Redis is throwing an exception
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_failure_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ error=e,
+ call_type="async_scan_iter",
+ start_time=start_time,
+ end_time=end_time,
+ )
+ )
+ raise e
+
+ async def async_set_cache(self, key, value, **kwargs):
+ from redis.asyncio import Redis
+
+ start_time = time.time()
+ try:
+ _redis_client: Redis = self.init_async_client() # type: ignore
+ except Exception as e:
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_failure_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ error=e,
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+ call_type="async_set_cache",
+ )
+ )
+ verbose_logger.error(
+ "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
+ str(e),
+ value,
+ )
+ raise e
+
+ key = self.check_and_fix_namespace(key=key)
+ ttl = self.get_ttl(**kwargs)
+ print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
+
+ try:
+ if not hasattr(_redis_client, "set"):
+ raise Exception("Redis client cannot set cache. Attribute not found.")
+ await _redis_client.set(name=key, value=json.dumps(value), ex=ttl)
+ print_verbose(
+ f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
+ )
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="async_set_cache",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+ event_metadata={"key": key},
+ )
+ )
+ except Exception as e:
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_failure_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ error=e,
+ call_type="async_set_cache",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+ event_metadata={"key": key},
+ )
+ )
+ verbose_logger.error(
+ "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
+ str(e),
+ value,
+ )
+
+ async def _pipeline_helper(
+ self,
+ pipe: Union[pipeline, cluster_pipeline],
+ cache_list: List[Tuple[Any, Any]],
+ ttl: Optional[float],
+ ) -> List:
+ """
+ Helper function for executing a pipeline of set operations on Redis
+ """
+ ttl = self.get_ttl(ttl=ttl)
+ # Iterate through each key-value pair in the cache_list and set them in the pipeline.
+ for cache_key, cache_value in cache_list:
+ cache_key = self.check_and_fix_namespace(key=cache_key)
+ print_verbose(
+ f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
+ )
+ json_cache_value = json.dumps(cache_value)
+ # Set the value with a TTL if it's provided.
+ _td: Optional[timedelta] = None
+ if ttl is not None:
+ _td = timedelta(seconds=ttl)
+ pipe.set( # type: ignore
+ name=cache_key,
+ value=json_cache_value,
+ ex=_td,
+ )
+ # Execute the pipeline and return the results.
+ results = await pipe.execute()
+ return results
+
+ async def async_set_cache_pipeline(
+ self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
+ ):
+ """
+ Use Redis Pipelines for bulk write operations
+ """
+ # don't waste a network request if there's nothing to set
+ if len(cache_list) == 0:
+ return
+
+ _redis_client = self.init_async_client()
+ start_time = time.time()
+
+ print_verbose(
+ f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
+ )
+ cache_value: Any = None
+ try:
+ async with _redis_client.pipeline(transaction=False) as pipe:
+ results = await self._pipeline_helper(pipe, cache_list, ttl)
+
+ print_verbose(f"pipeline results: {results}")
+ # Optionally, you could process 'results' to make sure that all set operations were successful.
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="async_set_cache_pipeline",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+ )
+ )
+ return None
+ except Exception as e:
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_failure_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ error=e,
+ call_type="async_set_cache_pipeline",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+ )
+ )
+
+ verbose_logger.error(
+ "LiteLLM Redis Caching: async set_cache_pipeline() - Got exception from REDIS %s, Writing value=%s",
+ str(e),
+ cache_value,
+ )
+
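A minimal sketch of the bulk write path, assuming a reachable local Redis: several keys are written in one pipeline round trip and read back with a single `mget`. Keys, values, and TTL are illustrative.

```python
import asyncio
from litellm.caching.redis_cache import RedisCache

cache = RedisCache(host="localhost", port=6379)

async def main():
    await cache.async_set_cache_pipeline(
        cache_list=[("k1", {"v": 1}), ("k2", {"v": 2})], ttl=60
    )
    print(await cache.async_batch_get_cache(["k1", "k2"]))
    # e.g. {'k1': {'v': 1}, 'k2': {'v': 2}}

asyncio.run(main())
```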
+ async def _set_cache_sadd_helper(
+ self,
+ redis_client: async_redis_client,
+ key: str,
+ value: List,
+ ttl: Optional[float],
+ ) -> None:
+ """Helper function for async_set_cache_sadd. Separated for testing."""
+ ttl = self.get_ttl(ttl=ttl)
+ try:
+ await redis_client.sadd(key, *value) # type: ignore
+ if ttl is not None:
+ _td = timedelta(seconds=ttl)
+ await redis_client.expire(key, _td)
+ except Exception:
+ raise
+
+ async def async_set_cache_sadd(
+ self, key, value: List, ttl: Optional[float], **kwargs
+ ):
+ from redis.asyncio import Redis
+
+ start_time = time.time()
+ try:
+ _redis_client: Redis = self.init_async_client() # type: ignore
+ except Exception as e:
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_failure_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ error=e,
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+ call_type="async_set_cache_sadd",
+ )
+ )
+ # NON blocking - notify users Redis is throwing an exception
+ verbose_logger.error(
+ "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
+ str(e),
+ value,
+ )
+ raise e
+
+ key = self.check_and_fix_namespace(key=key)
+ print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
+ try:
+ await self._set_cache_sadd_helper(
+ redis_client=_redis_client, key=key, value=value, ttl=ttl
+ )
+ print_verbose(
+ f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
+ )
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="async_set_cache_sadd",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+ )
+ )
+ except Exception as e:
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_failure_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ error=e,
+ call_type="async_set_cache_sadd",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+ )
+ )
+ # NON blocking - notify users Redis is throwing an exception
+ verbose_logger.error(
+ "LiteLLM Redis Caching: async set_cache_sadd() - Got exception from REDIS %s, Writing value=%s",
+ str(e),
+ value,
+ )
+
+ async def batch_cache_write(self, key, value, **kwargs):
+ print_verbose(
+ f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}",
+ )
+ key = self.check_and_fix_namespace(key=key)
+ self.redis_batch_writing_buffer.append((key, value))
+ if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
+ await self.flush_cache_buffer() # logging done in here
+
+ async def async_increment(
+ self,
+ key,
+ value: float,
+ ttl: Optional[int] = None,
+ parent_otel_span: Optional[Span] = None,
+ ) -> float:
+ from redis.asyncio import Redis
+
+ _redis_client: Redis = self.init_async_client() # type: ignore
+ start_time = time.time()
+ _used_ttl = self.get_ttl(ttl=ttl)
+ key = self.check_and_fix_namespace(key=key)
+ try:
+ result = await _redis_client.incrbyfloat(name=key, amount=value)
+ if _used_ttl is not None:
+ # check if key already has ttl, if not -> set ttl
+ current_ttl = await _redis_client.ttl(key)
+ if current_ttl == -1:
+ # Key has no expiration
+ await _redis_client.expire(key, _used_ttl)
+
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="async_increment",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=parent_otel_span,
+ )
+ )
+ return result
+ except Exception as e:
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_failure_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ error=e,
+ call_type="async_increment",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=parent_otel_span,
+ )
+ )
+ verbose_logger.error(
+ "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
+ str(e),
+ value,
+ )
+ raise e
+
+ async def flush_cache_buffer(self):
+ print_verbose(
+ f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}"
+ )
+ await self.async_set_cache_pipeline(self.redis_batch_writing_buffer)
+ self.redis_batch_writing_buffer = []
+
+ def _get_cache_logic(self, cached_response: Any):
+ """
+ Common 'get_cache_logic' across sync + async redis client implementations
+ """
+ if cached_response is None:
+ return cached_response
+ # cached_response is a bytes object (e.g. b'{...}'); decode it and convert it back to a dict
+ cached_response = cached_response.decode("utf-8") # Convert bytes to string
+ try:
+ cached_response = json.loads(
+ cached_response
+ ) # Convert string to dictionary
+ except Exception:
+ cached_response = ast.literal_eval(cached_response)
+ return cached_response
+
+ def get_cache(self, key, parent_otel_span: Optional[Span] = None, **kwargs):
+ try:
+ key = self.check_and_fix_namespace(key=key)
+ print_verbose(f"Get Redis Cache: key: {key}")
+ start_time = time.time()
+ cached_response = self.redis_client.get(key)
+ end_time = time.time()
+ _duration = end_time - start_time
+ self.service_logger_obj.service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="get_cache",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=parent_otel_span,
+ )
+ print_verbose(
+ f"Got Redis Cache: key: {key}, cached_response {cached_response}"
+ )
+ return self._get_cache_logic(cached_response=cached_response)
+ except Exception as e:
+ # NON blocking - notify users Redis is throwing an exception
+ verbose_logger.error(
+ "litellm.caching.caching: get() - Got exception from REDIS: ", e
+ )
+
+ def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+ """
+ Wrapper to call `mget` on the redis client
+
+ We use a wrapper so RedisCluster can override this method
+ """
+ return self.redis_client.mget(keys=keys) # type: ignore
+
+ async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+ """
+ Wrapper to call `mget` on the redis client
+
+ We use a wrapper so RedisCluster can override this method
+ """
+ async_redis_client = self.init_async_client()
+ return await async_redis_client.mget(keys=keys) # type: ignore
+
+ def batch_get_cache(
+ self,
+ key_list: Union[List[str], List[Optional[str]]],
+ parent_otel_span: Optional[Span] = None,
+ ) -> dict:
+ """
+ Use Redis for bulk read operations
+
+ Args:
+ key_list: List of keys to get from Redis
+ parent_otel_span: Optional parent OpenTelemetry span
+
+ Returns:
+ dict: A dictionary mapping keys to their cached values
+ """
+ key_value_dict = {}
+ _key_list = [key for key in key_list if key is not None]
+
+ try:
+ _keys = []
+ for cache_key in _key_list:
+ cache_key = self.check_and_fix_namespace(key=cache_key or "")
+ _keys.append(cache_key)
+ start_time = time.time()
+ results: List = self._run_redis_mget_operation(keys=_keys)
+ end_time = time.time()
+ _duration = end_time - start_time
+ self.service_logger_obj.service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="batch_get_cache",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=parent_otel_span,
+ )
+
+ # Associate the results back with their keys.
+ # 'results' is a list of values corresponding to the order of keys in '_key_list'.
+ key_value_dict = dict(zip(_key_list, results))
+
+ decoded_results = {}
+ for k, v in key_value_dict.items():
+ if isinstance(k, bytes):
+ k = k.decode("utf-8")
+ v = self._get_cache_logic(v)
+ decoded_results[k] = v
+
+ return decoded_results
+ except Exception as e:
+ verbose_logger.error(f"Error occurred in batch get cache - {str(e)}")
+ return key_value_dict
+
+ async def async_get_cache(
+ self, key, parent_otel_span: Optional[Span] = None, **kwargs
+ ):
+ from redis.asyncio import Redis
+
+ _redis_client: Redis = self.init_async_client() # type: ignore
+ key = self.check_and_fix_namespace(key=key)
+ start_time = time.time()
+
+ try:
+ print_verbose(f"Get Async Redis Cache: key: {key}")
+ cached_response = await _redis_client.get(key)
+ print_verbose(
+ f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
+ )
+ response = self._get_cache_logic(cached_response=cached_response)
+
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="async_get_cache",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=parent_otel_span,
+ event_metadata={"key": key},
+ )
+ )
+ return response
+ except Exception as e:
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_failure_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ error=e,
+ call_type="async_get_cache",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=parent_otel_span,
+ event_metadata={"key": key},
+ )
+ )
+ print_verbose(
+ f"litellm.caching.caching: async get() - Got exception from REDIS: {str(e)}"
+ )
+
+ async def async_batch_get_cache(
+ self,
+ key_list: Union[List[str], List[Optional[str]]],
+ parent_otel_span: Optional[Span] = None,
+ ) -> dict:
+ """
+ Use Redis for bulk read operations
+
+ Args:
+ key_list: List of keys to get from Redis
+ parent_otel_span: Optional parent OpenTelemetry span
+
+ Returns:
+ dict: A dictionary mapping keys to their cached values
+
+ `.mget` does not support None keys. This will filter out None keys.
+ """
+ # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
+ key_value_dict = {}
+ start_time = time.time()
+ _key_list = [key for key in key_list if key is not None]
+ try:
+ _keys = []
+ for cache_key in _key_list:
+ cache_key = self.check_and_fix_namespace(key=cache_key)
+ _keys.append(cache_key)
+ results = await self._async_run_redis_mget_operation(keys=_keys)
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="async_batch_get_cache",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=parent_otel_span,
+ )
+ )
+
+ # Associate the results back with their keys.
+ # 'results' is a list of values corresponding to the order of keys in 'key_list'.
+ key_value_dict = dict(zip(_key_list, results))
+
+ decoded_results = {}
+ for k, v in key_value_dict.items():
+ if isinstance(k, bytes):
+ k = k.decode("utf-8")
+ v = self._get_cache_logic(v)
+ decoded_results[k] = v
+
+ return decoded_results
+ except Exception as e:
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_failure_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ error=e,
+ call_type="async_batch_get_cache",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=parent_otel_span,
+ )
+ )
+ verbose_logger.error(f"Error occurred in async batch get cache - {str(e)}")
+ return key_value_dict
+
+ def sync_ping(self) -> bool:
+ """
+ Tests if the sync redis client is correctly setup.
+ """
+ print_verbose("Pinging Sync Redis Cache")
+ start_time = time.time()
+ try:
+ response: bool = self.redis_client.ping() # type: ignore
+ print_verbose(f"Redis Cache PING: {response}")
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ self.service_logger_obj.service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="sync_ping",
+ start_time=start_time,
+ end_time=end_time,
+ )
+ return response
+ except Exception as e:
+ # NON blocking - notify users Redis is throwing an exception
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ self.service_logger_obj.service_failure_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ error=e,
+ call_type="sync_ping",
+ )
+ verbose_logger.error(
+ f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
+ )
+ raise e
+
+ async def ping(self) -> bool:
+ # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ping`
+ _redis_client: Any = self.init_async_client()
+ start_time = time.time()
+ print_verbose("Pinging Async Redis Cache")
+ try:
+ response = await _redis_client.ping()
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="async_ping",
+ )
+ )
+ return response
+ except Exception as e:
+ # NON blocking - notify users Redis is throwing an exception
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_failure_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ error=e,
+ call_type="async_ping",
+ )
+ )
+ verbose_logger.error(
+ f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
+ )
+ raise e
+
+ async def delete_cache_keys(self, keys):
+ # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
+ _redis_client: Any = self.init_async_client()
+ # keys is a list, unpack it so it gets passed as individual elements to delete
+ await _redis_client.delete(*keys)
+
+ def client_list(self) -> List:
+ client_list: List = self.redis_client.client_list() # type: ignore
+ return client_list
+
+ def info(self):
+ info = self.redis_client.info()
+ return info
+
+ def flush_cache(self):
+ self.redis_client.flushall()
+
+ def flushall(self):
+ self.redis_client.flushall()
+
+ async def disconnect(self):
+ await self.async_redis_conn_pool.disconnect(inuse_connections=True)
+
+ async def async_delete_cache(self, key: str):
+ # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
+ _redis_client: Any = self.init_async_client()
+ # keys is str
+ await _redis_client.delete(key)
+
+ def delete_cache(self, key):
+ self.redis_client.delete(key)
+
+ async def _pipeline_increment_helper(
+ self,
+ pipe: pipeline,
+ increment_list: List[RedisPipelineIncrementOperation],
+ ) -> Optional[List[float]]:
+ """Helper function for pipeline increment operations"""
+ # Iterate through each increment operation and add commands to pipeline
+ for increment_op in increment_list:
+ cache_key = self.check_and_fix_namespace(key=increment_op["key"])
+ print_verbose(
+ f"Increment ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {increment_op['increment_value']}\nttl={increment_op['ttl']}"
+ )
+ pipe.incrbyfloat(cache_key, increment_op["increment_value"])
+ if increment_op["ttl"] is not None:
+ _td = timedelta(seconds=increment_op["ttl"])
+ pipe.expire(cache_key, _td)
+ # Execute the pipeline and return results
+ results = await pipe.execute()
+ print_verbose(f"Increment ASYNC Redis Cache PIPELINE: results: {results}")
+ return results
+
+ async def async_increment_pipeline(
+ self, increment_list: List[RedisPipelineIncrementOperation], **kwargs
+ ) -> Optional[List[float]]:
+ """
+ Use Redis Pipelines for bulk increment operations
+ Args:
+ increment_list: List of RedisPipelineIncrementOperation dicts containing:
+ - key: str
+ - increment_value: float
+            - ttl: Optional[int]
+ """
+ # don't waste a network request if there's nothing to increment
+ if len(increment_list) == 0:
+ return None
+
+ from redis.asyncio import Redis
+
+ _redis_client: Redis = self.init_async_client() # type: ignore
+ start_time = time.time()
+
+ print_verbose(
+ f"Increment Async Redis Cache Pipeline: increment list: {increment_list}"
+ )
+
+ try:
+ async with _redis_client.pipeline(transaction=False) as pipe:
+ results = await self._pipeline_increment_helper(pipe, increment_list)
+
+ print_verbose(f"pipeline increment results: {results}")
+
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ call_type="async_increment_pipeline",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+ )
+ )
+ return results
+ except Exception as e:
+ ## LOGGING ##
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_failure_hook(
+ service=ServiceTypes.REDIS,
+ duration=_duration,
+ error=e,
+ call_type="async_increment_pipeline",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+ )
+ )
+ verbose_logger.error(
+ "LiteLLM Redis Caching: async increment_pipeline() - Got exception from REDIS %s",
+ str(e),
+ )
+ raise e
+
+ async def async_get_ttl(self, key: str) -> Optional[int]:
+ """
+ Get the remaining TTL of a key in Redis
+
+ Args:
+ key (str): The key to get TTL for
+
+ Returns:
+ Optional[int]: The remaining TTL in seconds, or None if key doesn't exist
+
+ Redis ref: https://redis.io/docs/latest/commands/ttl/
+ """
+ try:
+ # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ttl`
+ _redis_client: Any = self.init_async_client()
+ ttl = await _redis_client.ttl(key)
+            if ttl <= -1:  # -1 = key exists but has no TTL, -2 = key does not exist
+ return None
+ return ttl
+ except Exception as e:
+ verbose_logger.debug(f"Redis TTL Error: {e}")
+ return None
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/redis_cluster_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/redis_cluster_cache.py
new file mode 100644
index 00000000..2e7d1de1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/redis_cluster_cache.py
@@ -0,0 +1,59 @@
+"""
+Redis Cluster Cache implementation
+
+Key differences:
+- The Redis client NEEDS to be re-used across requests; re-creating it adds ~3000ms of latency.
+"""
+
+from typing import TYPE_CHECKING, Any, List, Optional
+
+from litellm.caching.redis_cache import RedisCache
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+ from redis.asyncio import Redis, RedisCluster
+ from redis.asyncio.client import Pipeline
+
+ pipeline = Pipeline
+ async_redis_client = Redis
+ Span = _Span
+else:
+ pipeline = Any
+ async_redis_client = Any
+ Span = Any
+
+
+class RedisClusterCache(RedisCache):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
+ self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None
+
+ def init_async_client(self):
+ from redis.asyncio import RedisCluster
+
+ from .._redis import get_redis_async_client
+
+ if self.redis_async_redis_cluster_client:
+ return self.redis_async_redis_cluster_client
+
+ _redis_client = get_redis_async_client(
+ connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
+ )
+ if isinstance(_redis_client, RedisCluster):
+ self.redis_async_redis_cluster_client = _redis_client
+
+ return _redis_client
+
+ def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+ """
+ Overrides `_run_redis_mget_operation` in redis_cache.py
+ """
+ return self.redis_client.mget_nonatomic(keys=keys) # type: ignore
+
+ async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
+ """
+ Overrides `_async_run_redis_mget_operation` in redis_cache.py
+ """
+ async_redis_cluster_client = self.init_async_client()
+ return await async_redis_cluster_client.mget_nonatomic(keys=keys) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py
new file mode 100644
index 00000000..b609286a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py
@@ -0,0 +1,337 @@
+"""
+Redis Semantic Cache implementation
+
+Has 4 methods:
+ - set_cache
+ - get_cache
+ - async_set_cache
+ - async_get_cache
+"""
+
+import ast
+import asyncio
+import json
+from typing import Any
+
+import litellm
+from litellm._logging import print_verbose
+
+from .base_cache import BaseCache
+
+
+class RedisSemanticCache(BaseCache):
+ def __init__(
+ self,
+ host=None,
+ port=None,
+ password=None,
+ redis_url=None,
+ similarity_threshold=None,
+ use_async=False,
+ embedding_model="text-embedding-ada-002",
+ **kwargs,
+ ):
+ from redisvl.index import SearchIndex
+
+ print_verbose(
+ "redis semantic-cache initializing INDEX - litellm_semantic_cache_index"
+ )
+ if similarity_threshold is None:
+ raise Exception("similarity_threshold must be provided, passed None")
+ self.similarity_threshold = similarity_threshold
+ self.embedding_model = embedding_model
+ schema = {
+ "index": {
+ "name": "litellm_semantic_cache_index",
+ "prefix": "litellm",
+ "storage_type": "hash",
+ },
+ "fields": {
+ "text": [{"name": "response"}],
+ "vector": [
+ {
+ "name": "litellm_embedding",
+ "dims": 1536,
+ "distance_metric": "cosine",
+ "algorithm": "flat",
+ "datatype": "float32",
+ }
+ ],
+ },
+ }
+ if redis_url is None:
+ # if no url passed, check if host, port and password are passed, if not raise an Exception
+ if host is None or port is None or password is None:
+ # try checking env for host, port and password
+ import os
+
+ host = os.getenv("REDIS_HOST")
+ port = os.getenv("REDIS_PORT")
+ password = os.getenv("REDIS_PASSWORD")
+ if host is None or port is None or password is None:
+ raise Exception("Redis host, port, and password must be provided")
+
+ redis_url = "redis://:" + password + "@" + host + ":" + port
+ print_verbose(f"redis semantic-cache redis_url: {redis_url}")
+ if use_async is False:
+ self.index = SearchIndex.from_dict(schema)
+ self.index.connect(redis_url=redis_url)
+ try:
+ self.index.create(overwrite=False) # don't overwrite existing index
+ except Exception as e:
+ print_verbose(f"Got exception creating semantic cache index: {str(e)}")
+ elif use_async is True:
+ schema["index"]["name"] = "litellm_semantic_cache_index_async"
+ self.index = SearchIndex.from_dict(schema)
+ self.index.connect(redis_url=redis_url, use_async=True)
+
+ def _get_cache_logic(self, cached_response: Any):
+ """
+ Common 'get_cache_logic' across sync + async redis client implementations
+ """
+ if cached_response is None:
+ return cached_response
+
+ # check if cached_response is bytes
+ if isinstance(cached_response, bytes):
+ cached_response = cached_response.decode("utf-8")
+
+ try:
+ cached_response = json.loads(
+ cached_response
+ ) # Convert string to dictionary
+ except Exception:
+ cached_response = ast.literal_eval(cached_response)
+ return cached_response
+
+ def set_cache(self, key, value, **kwargs):
+ import numpy as np
+
+ print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}")
+
+ # get the prompt
+ messages = kwargs["messages"]
+ prompt = "".join(message["content"] for message in messages)
+
+ # create an embedding for prompt
+ embedding_response = litellm.embedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ )
+
+ # get the embedding
+ embedding = embedding_response["data"][0]["embedding"]
+
+ # make the embedding a numpy array, convert to bytes
+ embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
+ value = str(value)
+ assert isinstance(value, str)
+
+ new_data = [
+ {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
+ ]
+
+ # Add more data
+ self.index.load(new_data)
+
+ return
+
+ def get_cache(self, key, **kwargs):
+ print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
+ from redisvl.query import VectorQuery
+
+ # query
+ # get the messages
+ messages = kwargs["messages"]
+ prompt = "".join(message["content"] for message in messages)
+
+ # convert to embedding
+ embedding_response = litellm.embedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ )
+
+ # get the embedding
+ embedding = embedding_response["data"][0]["embedding"]
+
+ query = VectorQuery(
+ vector=embedding,
+ vector_field_name="litellm_embedding",
+ return_fields=["response", "prompt", "vector_distance"],
+ num_results=1,
+ )
+
+ results = self.index.query(query)
+ if results is None:
+ return None
+ if isinstance(results, list):
+ if len(results) == 0:
+ return None
+
+ vector_distance = results[0]["vector_distance"]
+ vector_distance = float(vector_distance)
+ similarity = 1 - vector_distance
+ cached_prompt = results[0]["prompt"]
+
+ # check similarity, if more than self.similarity_threshold, return results
+ print_verbose(
+ f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+ )
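+        # e.g. (illustrative) vector_distance = 0.15 -> similarity = 0.85, which with a
+        # similarity_threshold of 0.8 counts as a cache hit.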
+ if similarity > self.similarity_threshold:
+ # cache hit !
+ cached_value = results[0]["response"]
+ print_verbose(
+ f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+ )
+ return self._get_cache_logic(cached_response=cached_value)
+ else:
+ # cache miss !
+ return None
+
+ async def async_set_cache(self, key, value, **kwargs):
+ import numpy as np
+
+ from litellm.proxy.proxy_server import llm_model_list, llm_router
+
+ try:
+ await self.index.acreate(overwrite=False) # don't overwrite existing index
+ except Exception as e:
+ print_verbose(f"Got exception creating semantic cache index: {str(e)}")
+ print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")
+
+ # get the prompt
+ messages = kwargs["messages"]
+ prompt = "".join(message["content"] for message in messages)
+ # create an embedding for prompt
+ router_model_names = (
+ [m["model_name"] for m in llm_model_list]
+ if llm_model_list is not None
+ else []
+ )
+ if llm_router is not None and self.embedding_model in router_model_names:
+ user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
+ embedding_response = await llm_router.aembedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ metadata={
+ "user_api_key": user_api_key,
+ "semantic-cache-embedding": True,
+ "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
+ },
+ )
+ else:
+ # convert to embedding
+ embedding_response = await litellm.aembedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ )
+
+ # get the embedding
+ embedding = embedding_response["data"][0]["embedding"]
+
+ # make the embedding a numpy array, convert to bytes
+ embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
+ value = str(value)
+ assert isinstance(value, str)
+
+ new_data = [
+ {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
+ ]
+
+ # Add more data
+ await self.index.aload(new_data)
+ return
+
+ async def async_get_cache(self, key, **kwargs):
+ print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
+ from redisvl.query import VectorQuery
+
+ from litellm.proxy.proxy_server import llm_model_list, llm_router
+
+ # query
+ # get the messages
+ messages = kwargs["messages"]
+ prompt = "".join(message["content"] for message in messages)
+
+ router_model_names = (
+ [m["model_name"] for m in llm_model_list]
+ if llm_model_list is not None
+ else []
+ )
+ if llm_router is not None and self.embedding_model in router_model_names:
+ user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
+ embedding_response = await llm_router.aembedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ metadata={
+ "user_api_key": user_api_key,
+ "semantic-cache-embedding": True,
+ "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
+ },
+ )
+ else:
+ # convert to embedding
+ embedding_response = await litellm.aembedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ )
+
+ # get the embedding
+ embedding = embedding_response["data"][0]["embedding"]
+
+ query = VectorQuery(
+ vector=embedding,
+ vector_field_name="litellm_embedding",
+ return_fields=["response", "prompt", "vector_distance"],
+ )
+ results = await self.index.aquery(query)
+ if results is None:
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+ return None
+ if isinstance(results, list):
+ if len(results) == 0:
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+ return None
+
+ vector_distance = results[0]["vector_distance"]
+ vector_distance = float(vector_distance)
+ similarity = 1 - vector_distance
+ cached_prompt = results[0]["prompt"]
+
+ # check similarity, if more than self.similarity_threshold, return results
+ print_verbose(
+ f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+ )
+
+ # update kwargs["metadata"] with similarity, don't rewrite the original metadata
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
+
+ if similarity > self.similarity_threshold:
+ # cache hit !
+ cached_value = results[0]["response"]
+ print_verbose(
+ f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+ )
+ return self._get_cache_logic(cached_response=cached_value)
+ else:
+ # cache miss !
+ return None
+
+ async def _index_info(self):
+ return await self.index.ainfo()
+
+ async def async_set_cache_pipeline(self, cache_list, **kwargs):
+ tasks = []
+ for val in cache_list:
+ tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
+ await asyncio.gather(*tasks)
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/s3_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/s3_cache.py
new file mode 100644
index 00000000..301591c6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/s3_cache.py
@@ -0,0 +1,159 @@
+"""
+S3 Cache implementation
+WARNING: DO NOT USE THIS IN PRODUCTION - This is not ASYNC
+
+Has 4 methods:
+ - set_cache
+ - get_cache
+ - async_set_cache
+ - async_get_cache
+"""
+
+import ast
+import asyncio
+import json
+from typing import Optional
+
+from litellm._logging import print_verbose, verbose_logger
+
+from .base_cache import BaseCache
+
+
+class S3Cache(BaseCache):
+ def __init__(
+ self,
+ s3_bucket_name,
+ s3_region_name=None,
+ s3_api_version=None,
+ s3_use_ssl: Optional[bool] = True,
+ s3_verify=None,
+ s3_endpoint_url=None,
+ s3_aws_access_key_id=None,
+ s3_aws_secret_access_key=None,
+ s3_aws_session_token=None,
+ s3_config=None,
+ s3_path=None,
+ **kwargs,
+ ):
+ import boto3
+
+ self.bucket_name = s3_bucket_name
+ self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else ""
+ # Create an S3 client with custom endpoint URL
+
+ self.s3_client = boto3.client(
+ "s3",
+ region_name=s3_region_name,
+ endpoint_url=s3_endpoint_url,
+ api_version=s3_api_version,
+ use_ssl=s3_use_ssl,
+ verify=s3_verify,
+ aws_access_key_id=s3_aws_access_key_id,
+ aws_secret_access_key=s3_aws_secret_access_key,
+ aws_session_token=s3_aws_session_token,
+ config=s3_config,
+ **kwargs,
+ )
+
+ def set_cache(self, key, value, **kwargs):
+ try:
+ print_verbose(f"LiteLLM SET Cache - S3. Key={key}. Value={value}")
+ ttl = kwargs.get("ttl", None)
+ # Convert value to JSON before storing in S3
+ serialized_value = json.dumps(value)
+ key = self.key_prefix + key
+
+ if ttl is not None:
+ cache_control = f"immutable, max-age={ttl}, s-maxage={ttl}"
+ import datetime
+
+ # Calculate expiration time
+                expiration_time = datetime.datetime.now() + datetime.timedelta(seconds=ttl)
+
+ # Upload the data to S3 with the calculated expiration time
+ self.s3_client.put_object(
+ Bucket=self.bucket_name,
+ Key=key,
+ Body=serialized_value,
+ Expires=expiration_time,
+ CacheControl=cache_control,
+ ContentType="application/json",
+ ContentLanguage="en",
+ ContentDisposition=f'inline; filename="{key}.json"',
+ )
+ else:
+ cache_control = "immutable, max-age=31536000, s-maxage=31536000"
+ # Upload the data to S3 without specifying Expires
+ self.s3_client.put_object(
+ Bucket=self.bucket_name,
+ Key=key,
+ Body=serialized_value,
+ CacheControl=cache_control,
+ ContentType="application/json",
+ ContentLanguage="en",
+ ContentDisposition=f'inline; filename="{key}.json"',
+ )
+ except Exception as e:
+ # NON blocking - notify users S3 is throwing an exception
+ print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")
+
+ async def async_set_cache(self, key, value, **kwargs):
+ self.set_cache(key=key, value=value, **kwargs)
+
+ def get_cache(self, key, **kwargs):
+ import botocore
+
+ try:
+ key = self.key_prefix + key
+
+ print_verbose(f"Get S3 Cache: key: {key}")
+ # Download the data from S3
+ cached_response = self.s3_client.get_object(
+ Bucket=self.bucket_name, Key=key
+ )
+
+ if cached_response is not None:
+                # cached_response body is bytes; decode it and convert to a dict / ModelResponse
+ cached_response = (
+ cached_response["Body"].read().decode("utf-8")
+ ) # Convert bytes to string
+ try:
+ cached_response = json.loads(
+ cached_response
+ ) # Convert string to dictionary
+ except Exception:
+ cached_response = ast.literal_eval(cached_response)
+ if type(cached_response) is not dict:
+ cached_response = dict(cached_response)
+ verbose_logger.debug(
+ f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
+ )
+
+ return cached_response
+ except botocore.exceptions.ClientError as e: # type: ignore
+ if e.response["Error"]["Code"] == "NoSuchKey":
+ verbose_logger.debug(
+ f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
+ )
+ return None
+
+ except Exception as e:
+ # NON blocking - notify users S3 is throwing an exception
+ verbose_logger.error(
+ f"S3 Caching: get_cache() - Got exception from S3: {e}"
+ )
+
+ async def async_get_cache(self, key, **kwargs):
+ return self.get_cache(key=key, **kwargs)
+
+ def flush_cache(self):
+ pass
+
+ async def disconnect(self):
+ pass
+
+ async def async_set_cache_pipeline(self, cache_list, **kwargs):
+ tasks = []
+ for val in cache_list:
+ tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
+ await asyncio.gather(*tasks)