author    | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit    | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/caching/caching.py
parent    | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/caching/caching.py')

-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/caching/caching.py | 797
1 file changed, 797 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/caching.py b/.venv/lib/python3.12/site-packages/litellm/caching/caching.py
new file mode 100644
index 00000000..415c49ed
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/caching.py
@@ -0,0 +1,797 @@
+# +-----------------------------------------------+
+# |                                               |
+# |           Give Feedback / Get Help            |
+# | https://github.com/BerriAI/litellm/issues/new |
+# |                                               |
+# +-----------------------------------------------+
+#
+# Thank you users! We ❤️ you! - Krrish & Ishaan
+
+import ast
+import hashlib
+import json
+import time
+import traceback
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
+from litellm.types.caching import *
+from litellm.types.utils import all_litellm_params
+
+from .base_cache import BaseCache
+from .disk_cache import DiskCache
+from .dual_cache import DualCache  # noqa
+from .in_memory_cache import InMemoryCache
+from .qdrant_semantic_cache import QdrantSemanticCache
+from .redis_cache import RedisCache
+from .redis_cluster_cache import RedisClusterCache
+from .redis_semantic_cache import RedisSemanticCache
+from .s3_cache import S3Cache
+
+
+def print_verbose(print_statement):
+    try:
+        verbose_logger.debug(print_statement)
+        if litellm.set_verbose:
+            print(print_statement)  # noqa
+    except Exception:
+        pass
+
+
+class CacheMode(str, Enum):
+    default_on = "default_on"
+    default_off = "default_off"
+
+
+#### LiteLLM.Completion / Embedding Cache ####
+class Cache:
+    def __init__(
+        self,
+        type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
+        mode: Optional[
+            CacheMode
+        ] = CacheMode.default_on,  # when default_on cache is always on, when default_off cache is opt in
+        host: Optional[str] = None,
+        port: Optional[str] = None,
+        password: Optional[str] = None,
+        namespace: Optional[str] = None,
+        ttl: Optional[float] = None,
+        default_in_memory_ttl: Optional[float] = None,
+        default_in_redis_ttl: Optional[float] = None,
+        similarity_threshold: Optional[float] = None,
+        supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
+            "completion",
+            "acompletion",
+            "embedding",
+            "aembedding",
+            "atranscription",
+            "transcription",
+            "atext_completion",
+            "text_completion",
+            "arerank",
+            "rerank",
+        ],
+        # s3 Bucket, boto3 configuration
+        s3_bucket_name: Optional[str] = None,
+        s3_region_name: Optional[str] = None,
+        s3_api_version: Optional[str] = None,
+        s3_use_ssl: Optional[bool] = True,
+        s3_verify: Optional[Union[bool, str]] = None,
+        s3_endpoint_url: Optional[str] = None,
+        s3_aws_access_key_id: Optional[str] = None,
+        s3_aws_secret_access_key: Optional[str] = None,
+        s3_aws_session_token: Optional[str] = None,
+        s3_config: Optional[Any] = None,
+        s3_path: Optional[str] = None,
+        redis_semantic_cache_use_async=False,
+        redis_semantic_cache_embedding_model="text-embedding-ada-002",
+        redis_flush_size: Optional[int] = None,
+        redis_startup_nodes: Optional[List] = None,
+        disk_cache_dir=None,
+        qdrant_api_base: Optional[str] = None,
+        qdrant_api_key: Optional[str] = None,
+        qdrant_collection_name: Optional[str] = None,
+        qdrant_quantization_config: Optional[str] = None,
+        qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
+        **kwargs,
+    ):
+        """
+        Initializes the cache based on the given type.
+
+        Args:
+            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local". + + # Redis Cache Args + host (str, optional): The host address for the Redis cache. Required if type is "redis". + port (int, optional): The port number for the Redis cache. Required if type is "redis". + password (str, optional): The password for the Redis cache. Required if type is "redis". + namespace (str, optional): The namespace for the Redis cache. Required if type is "redis". + ttl (float, optional): The ttl for the Redis cache + redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used. + redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None. + + # Qdrant Cache Args + qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic". + qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster. + qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic". + similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic". + + # Disk Cache Args + disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None. + + # S3 Cache Args + s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None. + s3_region_name (str, optional): The region name for the s3 cache. Defaults to None. + s3_api_version (str, optional): The api version for the s3 cache. Defaults to None. + s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True. + s3_verify (bool, optional): The verify for the s3 cache. Defaults to None. + s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None. + s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None. + s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None. + s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None. + s3_config (dict, optional): The config for the s3 cache. Defaults to None. + + # Common Cache Args + supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types. + **kwargs: Additional keyword arguments for redis.Redis() cache + + Raises: + ValueError: If an invalid cache type is provided. + + Returns: + None. 
+        """
+        if type == LiteLLMCacheType.REDIS:
+            if redis_startup_nodes:
+                self.cache: BaseCache = RedisClusterCache(
+                    host=host,
+                    port=port,
+                    password=password,
+                    redis_flush_size=redis_flush_size,
+                    startup_nodes=redis_startup_nodes,
+                    **kwargs,
+                )
+            else:
+                self.cache = RedisCache(
+                    host=host,
+                    port=port,
+                    password=password,
+                    redis_flush_size=redis_flush_size,
+                    **kwargs,
+                )
+        elif type == LiteLLMCacheType.REDIS_SEMANTIC:
+            self.cache = RedisSemanticCache(
+                host=host,
+                port=port,
+                password=password,
+                similarity_threshold=similarity_threshold,
+                use_async=redis_semantic_cache_use_async,
+                embedding_model=redis_semantic_cache_embedding_model,
+                **kwargs,
+            )
+        elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
+            self.cache = QdrantSemanticCache(
+                qdrant_api_base=qdrant_api_base,
+                qdrant_api_key=qdrant_api_key,
+                collection_name=qdrant_collection_name,
+                similarity_threshold=similarity_threshold,
+                quantization_config=qdrant_quantization_config,
+                embedding_model=qdrant_semantic_cache_embedding_model,
+            )
+        elif type == LiteLLMCacheType.LOCAL:
+            self.cache = InMemoryCache()
+        elif type == LiteLLMCacheType.S3:
+            self.cache = S3Cache(
+                s3_bucket_name=s3_bucket_name,
+                s3_region_name=s3_region_name,
+                s3_api_version=s3_api_version,
+                s3_use_ssl=s3_use_ssl,
+                s3_verify=s3_verify,
+                s3_endpoint_url=s3_endpoint_url,
+                s3_aws_access_key_id=s3_aws_access_key_id,
+                s3_aws_secret_access_key=s3_aws_secret_access_key,
+                s3_aws_session_token=s3_aws_session_token,
+                s3_config=s3_config,
+                s3_path=s3_path,
+                **kwargs,
+            )
+        elif type == LiteLLMCacheType.DISK:
+            self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
+        if "cache" not in litellm.input_callback:
+            litellm.input_callback.append("cache")
+        if "cache" not in litellm.success_callback:
+            litellm.logging_callback_manager.add_litellm_success_callback("cache")
+        if "cache" not in litellm._async_success_callback:
+            litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
+        self.supported_call_types = supported_call_types  # default to ["completion", "acompletion", "embedding", "aembedding"]
+        self.type = type
+        self.namespace = namespace
+        self.redis_flush_size = redis_flush_size
+        self.ttl = ttl
+        self.mode: CacheMode = mode or CacheMode.default_on
+
+        if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
+            self.ttl = default_in_memory_ttl
+
+        if (
+            self.type == LiteLLMCacheType.REDIS
+            or self.type == LiteLLMCacheType.REDIS_SEMANTIC
+        ) and default_in_redis_ttl is not None:
+            self.ttl = default_in_redis_ttl
+
+        if self.namespace is not None and isinstance(self.cache, RedisCache):
+            self.cache.namespace = self.namespace
+
+    def get_cache_key(self, **kwargs) -> str:
+        """
+        Get the cache key for the given arguments.
+
+        Args:
+            **kwargs: kwargs to litellm.completion() or embedding()
+
+        Returns:
+            str: The cache key generated from the arguments, or None if no cache key could be generated.
+        """
+        cache_key = ""
+        # verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
Kwargs: %s", kwargs) + + preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs) + if preset_cache_key is not None: + verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key) + return preset_cache_key + + combined_kwargs = ModelParamHelper._get_all_llm_api_params() + litellm_param_kwargs = all_litellm_params + for param in kwargs: + if param in combined_kwargs: + param_value: Optional[str] = self._get_param_value(param, kwargs) + if param_value is not None: + cache_key += f"{str(param)}: {str(param_value)}" + elif ( + param not in litellm_param_kwargs + ): # check if user passed in optional param - e.g. top_k + if ( + litellm.enable_caching_on_provider_specific_optional_params is True + ): # feature flagged for now + if kwargs[param] is None: + continue # ignore None params + param_value = kwargs[param] + cache_key += f"{str(param)}: {str(param_value)}" + + verbose_logger.debug("\nCreated cache key: %s", cache_key) + hashed_cache_key = Cache._get_hashed_cache_key(cache_key) + hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs) + self._set_preset_cache_key_in_kwargs( + preset_cache_key=hashed_cache_key, **kwargs + ) + return hashed_cache_key + + def _get_param_value( + self, + param: str, + kwargs: dict, + ) -> Optional[str]: + """ + Get the value for the given param from kwargs + """ + if param == "model": + return self._get_model_param_value(kwargs) + elif param == "file": + return self._get_file_param_value(kwargs) + return kwargs[param] + + def _get_model_param_value(self, kwargs: dict) -> str: + """ + Handles getting the value for the 'model' param from kwargs + + 1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups + 2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router() + 3. Else use the `model` passed in kwargs + """ + metadata: Dict = kwargs.get("metadata", {}) or {} + litellm_params: Dict = kwargs.get("litellm_params", {}) or {} + metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {} + model_group: Optional[str] = metadata.get( + "model_group" + ) or metadata_in_litellm_params.get("model_group") + caching_group = self._get_caching_group(metadata, model_group) + return caching_group or model_group or kwargs["model"] + + def _get_caching_group( + self, metadata: dict, model_group: Optional[str] + ) -> Optional[str]: + caching_groups: Optional[List] = metadata.get("caching_groups", []) + if caching_groups: + for group in caching_groups: + if model_group in group: + return str(group) + return None + + def _get_file_param_value(self, kwargs: dict) -> str: + """ + Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests + """ + file = kwargs.get("file") + metadata = kwargs.get("metadata", {}) + litellm_params = kwargs.get("litellm_params", {}) + return ( + metadata.get("file_checksum") + or getattr(file, "name", None) + or metadata.get("file_name") + or litellm_params.get("file_name") + ) + + def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]: + """ + Get the preset cache key from kwargs["litellm_params"] + + We use _get_preset_cache_keys for two reasons + + 1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens + 2. 
+        """
+        if kwargs:
+            if "litellm_params" in kwargs:
+                return kwargs["litellm_params"].get("preset_cache_key", None)
+        return None
+
+    def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
+        """
+        Set the calculated cache key in kwargs
+
+        This is used to avoid doing duplicate / repeated work
+
+        Placed in kwargs["litellm_params"]
+        """
+        if kwargs:
+            if "litellm_params" in kwargs:
+                kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
+
+    @staticmethod
+    def _get_hashed_cache_key(cache_key: str) -> str:
+        """
+        Get the hashed cache key for the given cache key.
+
+        Use hashlib to create a sha256 hash of the cache key
+
+        Args:
+            cache_key (str): The cache key to hash.
+
+        Returns:
+            str: The hashed cache key.
+        """
+        hash_object = hashlib.sha256(cache_key.encode())
+        # Hexadecimal representation of the hash
+        hash_hex = hash_object.hexdigest()
+        verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
+        return hash_hex
+
+    def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
+        """
+        If a redis namespace is provided, add it to the cache key
+
+        Args:
+            hash_hex (str): The hashed cache key.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The final hashed cache key with the redis namespace.
+        """
+        dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
+        namespace = (
+            dynamic_cache_control.get("namespace")
+            or kwargs.get("metadata", {}).get("redis_namespace")
+            or self.namespace
+        )
+        if namespace:
+            hash_hex = f"{namespace}:{hash_hex}"
+        verbose_logger.debug("Final hashed key: %s", hash_hex)
+        return hash_hex
+
+    def generate_streaming_content(self, content):
+        chunk_size = 5  # Adjust the chunk size as needed
+        for i in range(0, len(content), chunk_size):
+            yield {
+                "choices": [
+                    {
+                        "delta": {
+                            "role": "assistant",
+                            "content": content[i : i + chunk_size],
+                        }
+                    }
+                ]
+            }
+            time.sleep(0.02)
+
+    def _get_cache_logic(
+        self,
+        cached_result: Optional[Any],
+        max_age: Optional[float],
+    ):
+        """
+        Common get cache logic across sync + async implementations
+        """
+        # Check if a timestamp was stored with the cached response
+        if (
+            cached_result is not None
+            and isinstance(cached_result, dict)
+            and "timestamp" in cached_result
+        ):
+            timestamp = cached_result["timestamp"]
+            current_time = time.time()
+
+            # Calculate age of the cached response
+            response_age = current_time - timestamp
+
+            # Check if the cached response is older than the max-age
+            if max_age is not None and response_age > max_age:
+                return None  # Cached response is too old
+
+            # If the response is fresh, or there's no max-age requirement, return the cached response
+            # cached_response may be stored as a serialized string (e.g. b"{...}"); convert it back to a dict
+            cached_response = cached_result.get("response")
+            try:
+                if isinstance(cached_response, dict):
+                    pass
+                else:
+                    cached_response = json.loads(
+                        cached_response  # type: ignore
+                    )  # Convert string to dictionary
+            except Exception:
+                cached_response = ast.literal_eval(cached_response)  # type: ignore
+            return cached_response
+        return cached_result
+
+    def get_cache(self, **kwargs):
+        """
+        Retrieves the cached result for the given arguments.
+
+        Args:
+            *args: args to litellm.completion() or embedding()
+            **kwargs: kwargs to litellm.completion() or embedding()
+
+        Returns:
+            The cached result if it exists, otherwise None.
+ """ + try: # never block execution + if self.should_use_cache(**kwargs) is not True: + return + messages = kwargs.get("messages", []) + if "cache_key" in kwargs: + cache_key = kwargs["cache_key"] + else: + cache_key = self.get_cache_key(**kwargs) + if cache_key is not None: + cache_control_args: DynamicCacheControl = kwargs.get("cache", {}) + max_age = ( + cache_control_args.get("s-maxage") + or cache_control_args.get("s-max-age") + or float("inf") + ) + cached_result = self.cache.get_cache(cache_key, messages=messages) + cached_result = self.cache.get_cache(cache_key, messages=messages) + return self._get_cache_logic( + cached_result=cached_result, max_age=max_age + ) + except Exception: + print_verbose(f"An exception occurred: {traceback.format_exc()}") + return None + + async def async_get_cache(self, **kwargs): + """ + Async get cache implementation. + + Used for embedding calls in async wrapper + """ + + try: # never block execution + if self.should_use_cache(**kwargs) is not True: + return + + kwargs.get("messages", []) + if "cache_key" in kwargs: + cache_key = kwargs["cache_key"] + else: + cache_key = self.get_cache_key(**kwargs) + if cache_key is not None: + cache_control_args = kwargs.get("cache", {}) + max_age = cache_control_args.get( + "s-max-age", cache_control_args.get("s-maxage", float("inf")) + ) + cached_result = await self.cache.async_get_cache(cache_key, **kwargs) + return self._get_cache_logic( + cached_result=cached_result, max_age=max_age + ) + except Exception: + print_verbose(f"An exception occurred: {traceback.format_exc()}") + return None + + def _add_cache_logic(self, result, **kwargs): + """ + Common implementation across sync + async add_cache functions + """ + try: + if "cache_key" in kwargs: + cache_key = kwargs["cache_key"] + else: + cache_key = self.get_cache_key(**kwargs) + if cache_key is not None: + if isinstance(result, BaseModel): + result = result.model_dump_json() + + ## DEFAULT TTL ## + if self.ttl is not None: + kwargs["ttl"] = self.ttl + ## Get Cache-Controls ## + _cache_kwargs = kwargs.get("cache", None) + if isinstance(_cache_kwargs, dict): + for k, v in _cache_kwargs.items(): + if k == "ttl": + kwargs["ttl"] = v + + cached_data = {"timestamp": time.time(), "response": result} + return cache_key, cached_data, kwargs + else: + raise Exception("cache key is None") + except Exception as e: + raise e + + def add_cache(self, result, **kwargs): + """ + Adds a result to the cache. 
+
+        Args:
+            *args: args to litellm.completion() or embedding()
+            **kwargs: kwargs to litellm.completion() or embedding()
+
+        Returns:
+            None
+        """
+        try:
+            if self.should_use_cache(**kwargs) is not True:
+                return
+            cache_key, cached_data, kwargs = self._add_cache_logic(
+                result=result, **kwargs
+            )
+            self.cache.set_cache(cache_key, cached_data, **kwargs)
+        except Exception as e:
+            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
+
+    async def async_add_cache(self, result, **kwargs):
+        """
+        Async implementation of add_cache
+        """
+        try:
+            if self.should_use_cache(**kwargs) is not True:
+                return
+            if self.type == "redis" and self.redis_flush_size is not None:
+                # high traffic - fill in results in memory and then flush
+                await self.batch_cache_write(result, **kwargs)
+            else:
+                cache_key, cached_data, kwargs = self._add_cache_logic(
+                    result=result, **kwargs
+                )
+
+                await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+        except Exception as e:
+            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
+
+    async def async_add_cache_pipeline(self, result, **kwargs):
+        """
+        Async implementation of add_cache for Embedding calls
+
+        Does a bulk write, to prevent using too many clients
+        """
+        try:
+            if self.should_use_cache(**kwargs) is not True:
+                return
+
+            # set default ttl if not set
+            if self.ttl is not None:
+                kwargs["ttl"] = self.ttl
+
+            cache_list = []
+            for idx, i in enumerate(kwargs["input"]):
+                preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
+                kwargs["cache_key"] = preset_cache_key
+                embedding_response = result.data[idx]
+                cache_key, cached_data, kwargs = self._add_cache_logic(
+                    result=embedding_response,
+                    **kwargs,
+                )
+                cache_list.append((cache_key, cached_data))
+
+            await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
+            # if async_set_cache_pipeline:
+            #     await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
+            # else:
+            #     tasks = []
+            #     for val in cache_list:
+            #         tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
+            #     await asyncio.gather(*tasks)
+        except Exception as e:
+            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
+
+    def should_use_cache(self, **kwargs):
+        """
+        Returns true if we should use the cache for LLM API calls
+
+        If cache is default_on then this is True
+        If cache is default_off then this is only true when user has opted in to use cache
+        """
+        if self.mode == CacheMode.default_on:
+            return True
+
+        # when mode == default_off -> Cache is opt in only
+        _cache = kwargs.get("cache", None)
+        verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
+        if _cache and isinstance(_cache, dict):
+            if _cache.get("use-cache", False) is True:
+                return True
+        return False
+
+    async def batch_cache_write(self, result, **kwargs):
+        cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
+        await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
+
+    async def ping(self):
+        # default of None avoids AttributeError when the backend lacks ping
+        cache_ping = getattr(self.cache, "ping", None)
+        if cache_ping:
+            return await cache_ping()
+        return None
+
+    async def delete_cache_keys(self, keys):
+        cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys", None)
+        if cache_delete_cache_keys:
+            return await cache_delete_cache_keys(keys)
+        return None
+
+    async def disconnect(self):
+        if hasattr(self.cache, "disconnect"):
+            await self.cache.disconnect()
+
+    def _supports_async(self) -> bool:
+        """
+        Internal method to check if the cache type supports async get/set operations
+
+        Only the S3 cache does NOT support async operations.
+        """
+        if self.type and self.type == LiteLLMCacheType.S3:
+            return False
+        return True
+
+
+def enable_cache(
+    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
+    host: Optional[str] = None,
+    port: Optional[str] = None,
+    password: Optional[str] = None,
+    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
+        "completion",
+        "acompletion",
+        "embedding",
+        "aembedding",
+        "atranscription",
+        "transcription",
+        "atext_completion",
+        "text_completion",
+        "arerank",
+        "rerank",
+    ],
+    **kwargs,
+):
+    """
+    Enable cache with the specified configuration.
+
+    Args:
+        type (Optional[LiteLLMCacheType]): The type of cache to enable. Defaults to "local".
+        host (Optional[str]): The host address of the cache server. Defaults to None.
+        port (Optional[str]): The port number of the cache server. Defaults to None.
+        password (Optional[str]): The password for the cache server. Defaults to None.
+        supported_call_types (Optional[List[CachingSupportedCallTypes]]):
+            The call types to cache for. Defaults to all supported call types.
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        None
+
+    Raises:
+        None
+    """
+    print_verbose("LiteLLM: Enabling Cache")
+    if "cache" not in litellm.input_callback:
+        litellm.input_callback.append("cache")
+    if "cache" not in litellm.success_callback:
+        litellm.logging_callback_manager.add_litellm_success_callback("cache")
+    if "cache" not in litellm._async_success_callback:
+        litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
+
+    if litellm.cache is None:
+        litellm.cache = Cache(
+            type=type,
+            host=host,
+            port=port,
+            password=password,
+            supported_call_types=supported_call_types,
+            **kwargs,
+        )
+    print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
+    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
+
+
+def update_cache(
+    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
+    host: Optional[str] = None,
+    port: Optional[str] = None,
+    password: Optional[str] = None,
+    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
+        "completion",
+        "acompletion",
+        "embedding",
+        "aembedding",
+        "atranscription",
+        "transcription",
+        "atext_completion",
+        "text_completion",
+        "arerank",
+        "rerank",
+    ],
+    **kwargs,
+):
+    """
+    Update the cache for LiteLLM.
+
+    Args:
+        type (Optional[LiteLLMCacheType]): The type of cache. Defaults to "local".
+        host (Optional[str]): The host of the cache. Defaults to None.
+        port (Optional[str]): The port of the cache. Defaults to None.
+        password (Optional[str]): The password for the cache. Defaults to None.
+        supported_call_types (Optional[List[CachingSupportedCallTypes]]):
+            The call types to cache for. Defaults to all supported call types.
+        **kwargs: Additional keyword arguments for the cache.
+
+    Returns:
+        None
+    """
+    print_verbose("LiteLLM: Updating Cache")
+    litellm.cache = Cache(
+        type=type,
+        host=host,
+        port=port,
+        password=password,
+        supported_call_types=supported_call_types,
+        **kwargs,
+    )
+    print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
+    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
+
+
+def disable_cache():
+    """
+    Disable the cache used by LiteLLM.
+
+    This function disables the cache used by the LiteLLM module.
+    It removes the cache-related callbacks from the input_callback, success_callback,
+    and _async_success_callback lists, and sets the litellm.cache attribute to None.
+
+    Parameters:
+        None
+
+    Returns:
+        None
+    """
+    from contextlib import suppress
+
+    print_verbose("LiteLLM: Disabling Cache")
+    with suppress(ValueError):
+        litellm.input_callback.remove("cache")
+        litellm.success_callback.remove("cache")
+        litellm._async_success_callback.remove("cache")
+
+    litellm.cache = None
+    print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
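
For orientation, here is a minimal usage sketch of the module added above. It relies only on the surface visible in this diff (`litellm.Cache`, the `cache` callbacks registered in `__init__`, and `LiteLLMCacheType` from `litellm.types.caching`); the model name and prompt are placeholders, not part of the commit:

```python
import litellm
from litellm.types.caching import LiteLLMCacheType

# Attach an in-memory cache with a 60s default TTL. Constructing Cache also
# registers the "cache" input/success callbacks, as __init__ shows above.
litellm.cache = litellm.Cache(type=LiteLLMCacheType.LOCAL, ttl=60)

messages = [{"role": "user", "content": "What is 2 + 2?"}]  # placeholder prompt

# First call hits the provider and writes the response into the cache.
first = litellm.completion(model="gpt-4o-mini", messages=messages)

# Identical kwargs produce an identical cache key (see get_cache_key above),
# so this call should be served from the cache instead of the provider.
second = litellm.completion(model="gpt-4o-mini", messages=messages)
```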
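The key scheme in `get_cache_key`, `_get_hashed_cache_key`, and `_add_namespace_to_cache_key` reduces to: concatenate `"param: value"` pairs for every cacheable kwarg, SHA-256 the concatenation, and prefix an optional namespace. A standalone sketch of just that reduction, for illustration only (it deliberately ignores preset keys, caching groups, the `model`/`file` special cases, and the provider-specific-param feature flag):

```python
import hashlib
from typing import Optional


def sketch_cache_key(namespace: Optional[str] = None, **kwargs) -> str:
    """Simplified rendition of Cache.get_cache_key; not litellm's exact output."""
    cache_key = ""
    for param, value in kwargs.items():
        if value is not None:
            # Same "param: value" concatenation used in get_cache_key above.
            cache_key += f"{param}: {value}"
    hash_hex = hashlib.sha256(cache_key.encode()).hexdigest()
    # _add_namespace_to_cache_key prefixes "namespace:" when one is set.
    return f"{namespace}:{hash_hex}" if namespace else hash_hex


# Identical inputs map to identical keys; a namespace isolates tenants.
a = sketch_cache_key(model="gpt-4o-mini", messages="What is 2 + 2?")
b = sketch_cache_key(model="gpt-4o-mini", messages="What is 2 + 2?")
c = sketch_cache_key(namespace="team-a", model="gpt-4o-mini", messages="What is 2 + 2?")
assert a == b and c == f"team-a:{a}"
```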
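Several methods above read per-request controls from a `cache` dict in kwargs: `_add_cache_logic` honors `ttl` on writes, `get_cache` honors `s-maxage` / `s-max-age` on reads, `_add_namespace_to_cache_key` honors `namespace`, and `should_use_cache` honors `use-cache` when the cache was constructed with `mode="default_off"`. A hedged example of passing those controls; whether a given litellm version forwards the dict end-to-end should be verified against its docs, and the model is again a placeholder:

```python
import litellm

response = litellm.completion(
    model="gpt-4o-mini",  # placeholder
    messages=[{"role": "user", "content": "Summarize RFC 2616 in one line."}],
    cache={
        "ttl": 300,       # _add_cache_logic: store this result for 5 minutes
        "s-maxage": 60,   # get_cache: reject cached entries older than 60s
        "namespace": "docs-team",  # _add_namespace_to_cache_key: key prefix
        # "use-cache": True,  # should_use_cache: opt in when mode="default_off"
    },
)
```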