path: root/.venv/lib/python3.12/site-packages/litellm/caching/caching.py
author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/caching/caching.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two versions of R2R are here (HEAD -> master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/caching/caching.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/caching.py  797
1 file changed, 797 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/caching.py b/.venv/lib/python3.12/site-packages/litellm/caching/caching.py
new file mode 100644
index 00000000..415c49ed
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/caching.py
@@ -0,0 +1,797 @@
+# +-----------------------------------------------+
+# | |
+# | Give Feedback / Get Help |
+# | https://github.com/BerriAI/litellm/issues/new |
+# | |
+# +-----------------------------------------------+
+#
+# Thank you users! We ❤️ you! - Krrish & Ishaan
+
+import ast
+import hashlib
+import json
+import time
+import traceback
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
+from litellm.types.caching import *
+from litellm.types.utils import all_litellm_params
+
+from .base_cache import BaseCache
+from .disk_cache import DiskCache
+from .dual_cache import DualCache # noqa
+from .in_memory_cache import InMemoryCache
+from .qdrant_semantic_cache import QdrantSemanticCache
+from .redis_cache import RedisCache
+from .redis_cluster_cache import RedisClusterCache
+from .redis_semantic_cache import RedisSemanticCache
+from .s3_cache import S3Cache
+
+
+def print_verbose(print_statement):
+ try:
+ verbose_logger.debug(print_statement)
+ if litellm.set_verbose:
+ print(print_statement) # noqa
+ except Exception:
+ pass
+
+
+class CacheMode(str, Enum):
+ default_on = "default_on"
+ default_off = "default_off"
+
+
+#### LiteLLM.Completion / Embedding Cache ####
+class Cache:
+ def __init__(
+ self,
+ type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
+ mode: Optional[
+ CacheMode
+ ] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in
+ host: Optional[str] = None,
+ port: Optional[str] = None,
+ password: Optional[str] = None,
+ namespace: Optional[str] = None,
+ ttl: Optional[float] = None,
+ default_in_memory_ttl: Optional[float] = None,
+ default_in_redis_ttl: Optional[float] = None,
+ similarity_threshold: Optional[float] = None,
+ supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
+ "completion",
+ "acompletion",
+ "embedding",
+ "aembedding",
+ "atranscription",
+ "transcription",
+ "atext_completion",
+ "text_completion",
+ "arerank",
+ "rerank",
+ ],
+ # s3 Bucket, boto3 configuration
+ s3_bucket_name: Optional[str] = None,
+ s3_region_name: Optional[str] = None,
+ s3_api_version: Optional[str] = None,
+ s3_use_ssl: Optional[bool] = True,
+ s3_verify: Optional[Union[bool, str]] = None,
+ s3_endpoint_url: Optional[str] = None,
+ s3_aws_access_key_id: Optional[str] = None,
+ s3_aws_secret_access_key: Optional[str] = None,
+ s3_aws_session_token: Optional[str] = None,
+ s3_config: Optional[Any] = None,
+ s3_path: Optional[str] = None,
+ redis_semantic_cache_use_async=False,
+ redis_semantic_cache_embedding_model="text-embedding-ada-002",
+ redis_flush_size: Optional[int] = None,
+ redis_startup_nodes: Optional[List] = None,
+ disk_cache_dir=None,
+ qdrant_api_base: Optional[str] = None,
+ qdrant_api_key: Optional[str] = None,
+ qdrant_collection_name: Optional[str] = None,
+ qdrant_quantization_config: Optional[str] = None,
+ qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
+ **kwargs,
+ ):
+ """
+ Initializes the cache based on the given type.
+
+ Args:
+ type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
+
+ # Redis Cache Args
+ host (str, optional): The host address for the Redis cache. Required if type is "redis".
+ port (int, optional): The port number for the Redis cache. Required if type is "redis".
+ password (str, optional): The password for the Redis cache. Required if type is "redis".
+ namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
+ ttl (float, optional): The ttl for the Redis cache
+ redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
+ redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
+
+ # Qdrant Cache Args
+ qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
+ qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
+ qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
+ similarity_threshold (float, optional): The similarity threshold for semantic caching. Required if type is "redis-semantic" or "qdrant-semantic".
+
+ # Disk Cache Args
+ disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
+
+ # S3 Cache Args
+ s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
+ s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
+ s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
+ s3_use_ssl (bool, optional): Whether to use SSL for the s3 cache. Defaults to True.
+ s3_verify (bool or str, optional): Whether to verify SSL certificates for the s3 cache (or a path to a CA bundle). Defaults to None.
+ s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
+ s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
+ s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
+ s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
+ s3_config (dict, optional): The config for the s3 cache. Defaults to None.
+
+ # Common Cache Args
+ supported_call_types (list, optional): List of call types to cache for. Defaults to caching being enabled for all supported call types.
+ **kwargs: Additional keyword arguments for redis.Redis() cache
+
+ Raises:
+ ValueError: If an invalid cache type is provided.
+
+ Returns:
+ None. Cache is set as a litellm param
+ """
+ if type == LiteLLMCacheType.REDIS:
+ if redis_startup_nodes:
+ self.cache: BaseCache = RedisClusterCache(
+ host=host,
+ port=port,
+ password=password,
+ redis_flush_size=redis_flush_size,
+ startup_nodes=redis_startup_nodes,
+ **kwargs,
+ )
+ else:
+ self.cache = RedisCache(
+ host=host,
+ port=port,
+ password=password,
+ redis_flush_size=redis_flush_size,
+ **kwargs,
+ )
+ elif type == LiteLLMCacheType.REDIS_SEMANTIC:
+ self.cache = RedisSemanticCache(
+ host=host,
+ port=port,
+ password=password,
+ similarity_threshold=similarity_threshold,
+ use_async=redis_semantic_cache_use_async,
+ embedding_model=redis_semantic_cache_embedding_model,
+ **kwargs,
+ )
+ elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
+ self.cache = QdrantSemanticCache(
+ qdrant_api_base=qdrant_api_base,
+ qdrant_api_key=qdrant_api_key,
+ collection_name=qdrant_collection_name,
+ similarity_threshold=similarity_threshold,
+ quantization_config=qdrant_quantization_config,
+ embedding_model=qdrant_semantic_cache_embedding_model,
+ )
+ elif type == LiteLLMCacheType.LOCAL:
+ self.cache = InMemoryCache()
+ elif type == LiteLLMCacheType.S3:
+ self.cache = S3Cache(
+ s3_bucket_name=s3_bucket_name,
+ s3_region_name=s3_region_name,
+ s3_api_version=s3_api_version,
+ s3_use_ssl=s3_use_ssl,
+ s3_verify=s3_verify,
+ s3_endpoint_url=s3_endpoint_url,
+ s3_aws_access_key_id=s3_aws_access_key_id,
+ s3_aws_secret_access_key=s3_aws_secret_access_key,
+ s3_aws_session_token=s3_aws_session_token,
+ s3_config=s3_config,
+ s3_path=s3_path,
+ **kwargs,
+ )
+ elif type == LiteLLMCacheType.DISK:
+ self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
+ if "cache" not in litellm.input_callback:
+ litellm.input_callback.append("cache")
+ if "cache" not in litellm.success_callback:
+ litellm.logging_callback_manager.add_litellm_success_callback("cache")
+ if "cache" not in litellm._async_success_callback:
+ litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
+ self.supported_call_types = supported_call_types # defaults to caching for all supported call types
+ self.type = type
+ self.namespace = namespace
+ self.redis_flush_size = redis_flush_size
+ self.ttl = ttl
+ self.mode: CacheMode = mode or CacheMode.default_on
+
+ if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
+ self.ttl = default_in_memory_ttl
+
+ if (
+ self.type == LiteLLMCacheType.REDIS
+ or self.type == LiteLLMCacheType.REDIS_SEMANTIC
+ ) and default_in_redis_ttl is not None:
+ self.ttl = default_in_redis_ttl
+
+ if self.namespace is not None and isinstance(self.cache, RedisCache):
+ self.cache.namespace = self.namespace
+
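+ # Example (illustrative sketch, not part of the library): wiring this class up
+ # as the global LiteLLM cache. The host, port and password values below are
+ # placeholders.
+ #
+ #   import litellm
+ #   from litellm.caching.caching import Cache
+ #   from litellm.types.caching import LiteLLMCacheType
+ #
+ #   litellm.cache = Cache(
+ #       type=LiteLLMCacheType.REDIS,
+ #       host="localhost",
+ #       port="6379",
+ #       password="my-redis-password",  # placeholder
+ #       ttl=600,                       # entries expire after 10 minutes
+ #   )
+ #   # subsequent litellm.completion() / litellm.embedding() calls are now cached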
+ def get_cache_key(self, **kwargs) -> str:
+ """
+ Get the cache key for the given arguments.
+
+ Args:
+ **kwargs: kwargs to litellm.completion() or embedding()
+
+ Returns:
+ str: The cache key generated from the arguments, or None if no cache key could be generated.
+ """
+ cache_key = ""
+ # verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
+
+ preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
+ if preset_cache_key is not None:
+ verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
+ return preset_cache_key
+
+ combined_kwargs = ModelParamHelper._get_all_llm_api_params()
+ litellm_param_kwargs = all_litellm_params
+ for param in kwargs:
+ if param in combined_kwargs:
+ param_value: Optional[str] = self._get_param_value(param, kwargs)
+ if param_value is not None:
+ cache_key += f"{str(param)}: {str(param_value)}"
+ elif (
+ param not in litellm_param_kwargs
+ ): # check if user passed in optional param - e.g. top_k
+ if (
+ litellm.enable_caching_on_provider_specific_optional_params is True
+ ): # feature flagged for now
+ if kwargs[param] is None:
+ continue # ignore None params
+ param_value = kwargs[param]
+ cache_key += f"{str(param)}: {str(param_value)}"
+
+ verbose_logger.debug("\nCreated cache key: %s", cache_key)
+ hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
+ hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
+ self._set_preset_cache_key_in_kwargs(
+ preset_cache_key=hashed_cache_key, **kwargs
+ )
+ return hashed_cache_key
+
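+ # Rough sketch of what get_cache_key returns (illustrative only): the selected
+ # kwargs are concatenated as "param: value" pairs, hashed with SHA-256, and
+ # optionally prefixed with a namespace.
+ #
+ #   cache = Cache()  # local in-memory cache
+ #   key = cache.get_cache_key(
+ #       model="gpt-3.5-turbo",
+ #       messages=[{"role": "user", "content": "hi"}],
+ #   )
+ #   # key is a 64-character hex digest; the exact value depends on the inputs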
+ def _get_param_value(
+ self,
+ param: str,
+ kwargs: dict,
+ ) -> Optional[str]:
+ """
+ Get the value for the given param from kwargs
+ """
+ if param == "model":
+ return self._get_model_param_value(kwargs)
+ elif param == "file":
+ return self._get_file_param_value(kwargs)
+ return kwargs[param]
+
+ def _get_model_param_value(self, kwargs: dict) -> str:
+ """
+ Handles getting the value for the 'model' param from kwargs
+
+ 1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
+ 2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
+ 3. Else use the `model` passed in kwargs
+ """
+ metadata: Dict = kwargs.get("metadata", {}) or {}
+ litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
+ metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
+ model_group: Optional[str] = metadata.get(
+ "model_group"
+ ) or metadata_in_litellm_params.get("model_group")
+ caching_group = self._get_caching_group(metadata, model_group)
+ return caching_group or model_group or kwargs["model"]
+
+ def _get_caching_group(
+ self, metadata: dict, model_group: Optional[str]
+ ) -> Optional[str]:
+ caching_groups: Optional[List] = metadata.get("caching_groups", [])
+ if caching_groups:
+ for group in caching_groups:
+ if model_group in group:
+ return str(group)
+ return None
+
+ def _get_file_param_value(self, kwargs: dict) -> str:
+ """
+ Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
+ """
+ file = kwargs.get("file")
+ metadata = kwargs.get("metadata", {})
+ litellm_params = kwargs.get("litellm_params", {})
+ return (
+ metadata.get("file_checksum")
+ or getattr(file, "name", None)
+ or metadata.get("file_name")
+ or litellm_params.get("file_name")
+ )
+
+ def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
+ """
+ Get the preset cache key from kwargs["litellm_params"]
+
+ We reuse the preset cache key for two reasons:
+
+ 1. optional params like max_tokens get transformed per provider (e.g. bedrock -> max_new_tokens), which would otherwise change the key
+ 2. it avoids doing duplicate / repeated work
+ """
+ if kwargs:
+ if "litellm_params" in kwargs:
+ return kwargs["litellm_params"].get("preset_cache_key", None)
+ return None
+
+ def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
+ """
+ Set the calculated cache key in kwargs
+
+ This is used to avoid doing duplicate / repeated work
+
+ Placed in kwargs["litellm_params"]
+ """
+ if kwargs:
+ if "litellm_params" in kwargs:
+ kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
+
+ @staticmethod
+ def _get_hashed_cache_key(cache_key: str) -> str:
+ """
+ Get the hashed cache key for the given cache key.
+
+ Use hashlib to create a sha256 hash of the cache key
+
+ Args:
+ cache_key (str): The cache key to hash.
+
+ Returns:
+ str: The hashed cache key.
+ """
+ hash_object = hashlib.sha256(cache_key.encode())
+ # Hexadecimal representation of the hash
+ hash_hex = hash_object.hexdigest()
+ verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
+ return hash_hex
+
+ def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
+ """
+ If a redis namespace is provided, add it to the cache key
+
+ Args:
+ hash_hex (str): The hashed cache key.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The final hashed cache key with the redis namespace.
+ """
+ dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
+ namespace = (
+ dynamic_cache_control.get("namespace")
+ or kwargs.get("metadata", {}).get("redis_namespace")
+ or self.namespace
+ )
+ if namespace:
+ hash_hex = f"{namespace}:{hash_hex}"
+ verbose_logger.debug("Final hashed key: %s", hash_hex)
+ return hash_hex
+
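+ # Namespace resolution sketch (illustrative): a per-request namespace passed in
+ # kwargs["cache"] wins over metadata["redis_namespace"], which wins over the
+ # namespace configured on the Cache instance.
+ #
+ #   cache = Cache(namespace="team-a")
+ #   msgs = [{"role": "user", "content": "hi"}]
+ #   cache.get_cache_key(model="gpt-3.5-turbo", messages=msgs)
+ #   # -> "team-a:<sha256 hex digest>"
+ #   cache.get_cache_key(model="gpt-3.5-turbo", messages=msgs, cache={"namespace": "team-b"})
+ #   # -> "team-b:<sha256 hex digest>"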
+ def generate_streaming_content(self, content):
+ chunk_size = 5 # Adjust the chunk size as needed
+ for i in range(0, len(content), chunk_size):
+ yield {
+ "choices": [
+ {
+ "delta": {
+ "role": "assistant",
+ "content": content[i : i + chunk_size],
+ }
+ }
+ ]
+ }
+ time.sleep(0.02)
+
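+ # Illustrative sketch: replaying a cached completion as a pseudo-stream. Each
+ # yielded dict mimics a streaming chunk carrying a 5-character content slice.
+ #
+ #   cache = Cache()
+ #   for chunk in cache.generate_streaming_content("Hello from the cache"):
+ #       print(chunk["choices"][0]["delta"]["content"], end="")
+ #   # prints "Hello from the cache" in 5-character increments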
+ def _get_cache_logic(
+ self,
+ cached_result: Optional[Any],
+ max_age: Optional[float],
+ ):
+ """
+ Common get cache logic across sync + async implementations
+ """
+ # Check if a timestamp was stored with the cached response
+ if (
+ cached_result is not None
+ and isinstance(cached_result, dict)
+ and "timestamp" in cached_result
+ ):
+ timestamp = cached_result["timestamp"]
+ current_time = time.time()
+
+ # Calculate age of the cached response
+ response_age = current_time - timestamp
+
+ # Check if the cached response is older than the max-age
+ if max_age is not None and response_age > max_age:
+ return None # Cached response is too old
+
+ # If the response is fresh, or there's no max-age requirement, return the cached response
+ # cached_response may be a serialized string (e.g. b"{}"); convert it to a dict
+ cached_response = cached_result.get("response")
+ try:
+ if isinstance(cached_response, dict):
+ pass
+ else:
+ cached_response = json.loads(
+ cached_response # type: ignore
+ ) # Convert string to dictionary
+ except Exception:
+ cached_response = ast.literal_eval(cached_response) # type: ignore
+ return cached_response
+ return cached_result
+
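+ # Shape of the stored entry consumed by _get_cache_logic (illustrative):
+ #
+ #   cached_result = {
+ #       "timestamp": 1711680000.0,             # time.time() at write
+ #       "response": '{"id": "chatcmpl-123"}',  # JSON string or plain dict
+ #   }
+ #
+ # With max_age=60, the "response" is returned only if the entry was written
+ # within the last 60 seconds; older entries are treated as a cache miss (None).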
+ def get_cache(self, **kwargs):
+ """
+ Retrieves the cached result for the given arguments.
+
+ Args:
+ *args: args to litellm.completion() or embedding()
+ **kwargs: kwargs to litellm.completion() or embedding()
+
+ Returns:
+ The cached result if it exists, otherwise None.
+ """
+ try: # never block execution
+ if self.should_use_cache(**kwargs) is not True:
+ return
+ messages = kwargs.get("messages", [])
+ if "cache_key" in kwargs:
+ cache_key = kwargs["cache_key"]
+ else:
+ cache_key = self.get_cache_key(**kwargs)
+ if cache_key is not None:
+ cache_control_args: DynamicCacheControl = kwargs.get("cache", {})
+ max_age = (
+ cache_control_args.get("s-maxage")
+ or cache_control_args.get("s-max-age")
+ or float("inf")
+ )
+ cached_result = self.cache.get_cache(cache_key, messages=messages)
+ return self._get_cache_logic(
+ cached_result=cached_result, max_age=max_age
+ )
+ except Exception:
+ print_verbose(f"An exception occurred: {traceback.format_exc()}")
+ return None
+
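+ # Illustrative sketch: bounding the freshness of a cache hit via the
+ # per-request "cache" dict ("s-maxage" / "s-max-age" mirror HTTP Cache-Control
+ # semantics as interpreted above).
+ #
+ #   result = litellm.cache.get_cache(
+ #       model="gpt-3.5-turbo",
+ #       messages=[{"role": "user", "content": "hi"}],
+ #       cache={"s-maxage": 60},  # only accept entries younger than 60 seconds
+ #   )
+ #   # result is the cached response, or None on a miss / stale entry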
+ async def async_get_cache(self, **kwargs):
+ """
+ Async get cache implementation.
+
+ Used for embedding calls in async wrapper
+ """
+
+ try: # never block execution
+ if self.should_use_cache(**kwargs) is not True:
+ return
+
+ kwargs.get("messages", [])
+ if "cache_key" in kwargs:
+ cache_key = kwargs["cache_key"]
+ else:
+ cache_key = self.get_cache_key(**kwargs)
+ if cache_key is not None:
+ cache_control_args = kwargs.get("cache", {})
+ max_age = cache_control_args.get(
+ "s-max-age", cache_control_args.get("s-maxage", float("inf"))
+ )
+ cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
+ return self._get_cache_logic(
+ cached_result=cached_result, max_age=max_age
+ )
+ except Exception:
+ print_verbose(f"An exception occurred: {traceback.format_exc()}")
+ return None
+
+ def _add_cache_logic(self, result, **kwargs):
+ """
+ Common implementation across sync + async add_cache functions
+ """
+ try:
+ if "cache_key" in kwargs:
+ cache_key = kwargs["cache_key"]
+ else:
+ cache_key = self.get_cache_key(**kwargs)
+ if cache_key is not None:
+ if isinstance(result, BaseModel):
+ result = result.model_dump_json()
+
+ ## DEFAULT TTL ##
+ if self.ttl is not None:
+ kwargs["ttl"] = self.ttl
+ ## Get Cache-Controls ##
+ _cache_kwargs = kwargs.get("cache", None)
+ if isinstance(_cache_kwargs, dict):
+ for k, v in _cache_kwargs.items():
+ if k == "ttl":
+ kwargs["ttl"] = v
+
+ cached_data = {"timestamp": time.time(), "response": result}
+ return cache_key, cached_data, kwargs
+ else:
+ raise Exception("cache key is None")
+ except Exception as e:
+ raise e
+
+ def add_cache(self, result, **kwargs):
+ """
+ Adds a result to the cache.
+
+ Args:
+ *args: args to litellm.completion() or embedding()
+ **kwargs: kwargs to litellm.completion() or embedding()
+
+ Returns:
+ None
+ """
+ try:
+ if self.should_use_cache(**kwargs) is not True:
+ return
+ cache_key, cached_data, kwargs = self._add_cache_logic(
+ result=result, **kwargs
+ )
+ self.cache.set_cache(cache_key, cached_data, **kwargs)
+ except Exception as e:
+ verbose_logger.exception(f"LiteLLM Cache: Exception in add_cache: {str(e)}")
+
+ async def async_add_cache(self, result, **kwargs):
+ """
+ Async implementation of add_cache
+ """
+ try:
+ if self.should_use_cache(**kwargs) is not True:
+ return
+ if self.type == "redis" and self.redis_flush_size is not None:
+ # high traffic - fill in results in memory and then flush
+ await self.batch_cache_write(result, **kwargs)
+ else:
+ cache_key, cached_data, kwargs = self._add_cache_logic(
+ result=result, **kwargs
+ )
+
+ await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+ except Exception as e:
+ verbose_logger.exception(f"LiteLLM Cache: Exception in add_cache: {str(e)}")
+
+ async def async_add_cache_pipeline(self, result, **kwargs):
+ """
+ Async implementation of add_cache for Embedding calls
+
+ Does a bulk write, to prevent using too many clients
+ """
+ try:
+ if self.should_use_cache(**kwargs) is not True:
+ return
+
+ # set default ttl if not set
+ if self.ttl is not None:
+ kwargs["ttl"] = self.ttl
+
+ cache_list = []
+ for idx, i in enumerate(kwargs["input"]):
+ preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
+ kwargs["cache_key"] = preset_cache_key
+ embedding_response = result.data[idx]
+ cache_key, cached_data, kwargs = self._add_cache_logic(
+ result=embedding_response,
+ **kwargs,
+ )
+ cache_list.append((cache_key, cached_data))
+
+ await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
+ # if async_set_cache_pipeline:
+ # await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
+ # else:
+ # tasks = []
+ # for val in cache_list:
+ # tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
+ # await asyncio.gather(*tasks)
+ except Exception as e:
+ verbose_logger.exception(f"LiteLLM Cache: Exception in add_cache: {str(e)}")
+
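+ # Illustrative sketch: caching a batched embedding response writes one entry
+ # per input string, so later requests for any single input can hit the cache.
+ # The model name and inputs below are placeholders.
+ #
+ #   response = await litellm.aembedding(
+ #       model="text-embedding-ada-002",
+ #       input=["hello", "world"],
+ #   )
+ #   await litellm.cache.async_add_cache_pipeline(
+ #       response,
+ #       model="text-embedding-ada-002",
+ #       input=["hello", "world"],
+ #   )
+ #   # two (cache_key, cached_data) pairs are written in a single pipeline call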
+ def should_use_cache(self, **kwargs):
+ """
+ Returns true if we should use the cache for LLM API calls
+
+ If cache is default_on then this is True
+ If cache is default_off then this is only true when user has opted in to use cache
+ """
+ if self.mode == CacheMode.default_on:
+ return True
+
+ # when mode == default_off -> Cache is opt in only
+ _cache = kwargs.get("cache", None)
+ verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
+ if _cache and isinstance(_cache, dict):
+ if _cache.get("use-cache", False) is True:
+ return True
+ return False
+
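+ # Illustrative sketch of opt-in caching when mode=CacheMode.default_off:
+ #
+ #   litellm.cache = Cache(mode=CacheMode.default_off)
+ #   msgs = [{"role": "user", "content": "hi"}]
+ #
+ #   litellm.completion(model="gpt-3.5-turbo", messages=msgs)
+ #   # -> not cached: caching is opt-in in this mode
+ #
+ #   litellm.completion(model="gpt-3.5-turbo", messages=msgs, cache={"use-cache": True})
+ #   # -> cached: the per-request cache dict opts this call in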
+ async def batch_cache_write(self, result, **kwargs):
+ cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
+ await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
+
+ async def ping(self):
+ cache_ping = getattr(self.cache, "ping")
+ if cache_ping:
+ return await cache_ping()
+ return None
+
+ async def delete_cache_keys(self, keys):
+ cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys")
+ if cache_delete_cache_keys:
+ return await cache_delete_cache_keys(keys)
+ return None
+
+ async def disconnect(self):
+ if hasattr(self.cache, "disconnect"):
+ await self.cache.disconnect()
+
+ def _supports_async(self) -> bool:
+ """
+ Internal method to check if the cache type supports async get/set operations
+
+ Only the S3 cache does NOT support async operations.
+
+ """
+ if self.type and self.type == LiteLLMCacheType.S3:
+ return False
+ return True
+
+
+def enable_cache(
+ type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
+ host: Optional[str] = None,
+ port: Optional[str] = None,
+ password: Optional[str] = None,
+ supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
+ "completion",
+ "acompletion",
+ "embedding",
+ "aembedding",
+ "atranscription",
+ "transcription",
+ "atext_completion",
+ "text_completion",
+ "arerank",
+ "rerank",
+ ],
+ **kwargs,
+):
+ """
+ Enable cache with the specified configuration.
+
+ Args:
+ type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
+ host (Optional[str]): The host address of the cache server. Defaults to None.
+ port (Optional[str]): The port number of the cache server. Defaults to None.
+ password (Optional[str]): The password for the cache server. Defaults to None.
+ supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
+ The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ None
+
+ Raises:
+ None
+ """
+ print_verbose("LiteLLM: Enabling Cache")
+ if "cache" not in litellm.input_callback:
+ litellm.input_callback.append("cache")
+ if "cache" not in litellm.success_callback:
+ litellm.logging_callback_manager.add_litellm_success_callback("cache")
+ if "cache" not in litellm._async_success_callback:
+ litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
+
+ if litellm.cache is None:
+ litellm.cache = Cache(
+ type=type,
+ host=host,
+ port=port,
+ password=password,
+ supported_call_types=supported_call_types,
+ **kwargs,
+ )
+ print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
+ print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
+
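+ # Illustrative usage sketch (host and port are placeholders):
+ #
+ #   import litellm
+ #   from litellm.caching.caching import enable_cache
+ #   from litellm.types.caching import LiteLLMCacheType
+ #
+ #   enable_cache(type=LiteLLMCacheType.REDIS, host="localhost", port="6379")
+ #   # litellm.cache is now set and the "cache" callbacks are registered.
+ #   # Calling enable_cache() again leaves an existing cache untouched; use
+ #   # update_cache() below to reconfigure it.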
+
+def update_cache(
+ type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
+ host: Optional[str] = None,
+ port: Optional[str] = None,
+ password: Optional[str] = None,
+ supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
+ "completion",
+ "acompletion",
+ "embedding",
+ "aembedding",
+ "atranscription",
+ "transcription",
+ "atext_completion",
+ "text_completion",
+ "arerank",
+ "rerank",
+ ],
+ **kwargs,
+):
+ """
+ Update the cache for LiteLLM.
+
+ Args:
+ type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
+ host (Optional[str]): The host of the cache. Defaults to None.
+ port (Optional[str]): The port of the cache. Defaults to None.
+ password (Optional[str]): The password for the cache. Defaults to None.
+ supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
+ The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
+ **kwargs: Additional keyword arguments for the cache.
+
+ Returns:
+ None
+
+ """
+ print_verbose("LiteLLM: Updating Cache")
+ litellm.cache = Cache(
+ type=type,
+ host=host,
+ port=port,
+ password=password,
+ supported_call_types=supported_call_types,
+ **kwargs,
+ )
+ print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
+ print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
+
+
+def disable_cache():
+ """
+ Disable the cache used by LiteLLM.
+
+ This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.
+
+ Parameters:
+ None
+
+ Returns:
+ None
+ """
+ from contextlib import suppress
+
+ print_verbose("LiteLLM: Disabling Cache")
+ with suppress(ValueError):
+ litellm.input_callback.remove("cache")
+ litellm.success_callback.remove("cache")
+ litellm._async_success_callback.remove("cache")
+
+ litellm.cache = None
+ print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
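+ # End-to-end sketch of the module-level helpers (illustrative only; the disk
+ # cache directory is a placeholder):
+ #
+ #   enable_cache()                                    # local in-memory cache
+ #   update_cache(type=LiteLLMCacheType.DISK,
+ #                disk_cache_dir="/tmp/litellm-cache") # switch to a disk cache
+ #   disable_cache()                                   # unregister callbacks, litellm.cache = None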