diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py | 337 |
1 files changed, 337 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py new file mode 100644 index 00000000..b609286a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py @@ -0,0 +1,337 @@ +""" +Redis Semantic Cache implementation + +Has 4 methods: + - set_cache + - get_cache + - async_set_cache + - async_get_cache +""" + +import ast +import asyncio +import json +from typing import Any + +import litellm +from litellm._logging import print_verbose + +from .base_cache import BaseCache + + +class RedisSemanticCache(BaseCache): + def __init__( + self, + host=None, + port=None, + password=None, + redis_url=None, + similarity_threshold=None, + use_async=False, + embedding_model="text-embedding-ada-002", + **kwargs, + ): + from redisvl.index import SearchIndex + + print_verbose( + "redis semantic-cache initializing INDEX - litellm_semantic_cache_index" + ) + if similarity_threshold is None: + raise Exception("similarity_threshold must be provided, passed None") + self.similarity_threshold = similarity_threshold + self.embedding_model = embedding_model + schema = { + "index": { + "name": "litellm_semantic_cache_index", + "prefix": "litellm", + "storage_type": "hash", + }, + "fields": { + "text": [{"name": "response"}], + "vector": [ + { + "name": "litellm_embedding", + "dims": 1536, + "distance_metric": "cosine", + "algorithm": "flat", + "datatype": "float32", + } + ], + }, + } + if redis_url is None: + # if no url passed, check if host, port and password are passed, if not raise an Exception + if host is None or port is None or password is None: + # try checking env for host, port and password + import os + + host = os.getenv("REDIS_HOST") + port = os.getenv("REDIS_PORT") + password = os.getenv("REDIS_PASSWORD") + if host is None or port is None or password is None: + raise Exception("Redis host, port, and password must be provided") + + redis_url = "redis://:" + password + "@" + host + ":" + port + print_verbose(f"redis semantic-cache redis_url: {redis_url}") + if use_async is False: + self.index = SearchIndex.from_dict(schema) + self.index.connect(redis_url=redis_url) + try: + self.index.create(overwrite=False) # don't overwrite existing index + except Exception as e: + print_verbose(f"Got exception creating semantic cache index: {str(e)}") + elif use_async is True: + schema["index"]["name"] = "litellm_semantic_cache_index_async" + self.index = SearchIndex.from_dict(schema) + self.index.connect(redis_url=redis_url, use_async=True) + + # + def _get_cache_logic(self, cached_response: Any): + """ + Common 'get_cache_logic' across sync + async redis client implementations + """ + if cached_response is None: + return cached_response + + # check if cached_response is bytes + if isinstance(cached_response, bytes): + cached_response = cached_response.decode("utf-8") + + try: + cached_response = json.loads( + cached_response + ) # Convert string to dictionary + except Exception: + cached_response = ast.literal_eval(cached_response) + return cached_response + + def set_cache(self, key, value, **kwargs): + import numpy as np + + print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}") + + # get the prompt + messages = kwargs["messages"] + prompt = "".join(message["content"] for message in messages) + + # create an embedding for prompt + embedding_response = litellm.embedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + # make the embedding a numpy array, convert to bytes + embedding_bytes = np.array(embedding, dtype=np.float32).tobytes() + value = str(value) + assert isinstance(value, str) + + new_data = [ + {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes} + ] + + # Add more data + self.index.load(new_data) + + return + + def get_cache(self, key, **kwargs): + print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}") + from redisvl.query import VectorQuery + + # query + # get the messages + messages = kwargs["messages"] + prompt = "".join(message["content"] for message in messages) + + # convert to embedding + embedding_response = litellm.embedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + query = VectorQuery( + vector=embedding, + vector_field_name="litellm_embedding", + return_fields=["response", "prompt", "vector_distance"], + num_results=1, + ) + + results = self.index.query(query) + if results is None: + return None + if isinstance(results, list): + if len(results) == 0: + return None + + vector_distance = results[0]["vector_distance"] + vector_distance = float(vector_distance) + similarity = 1 - vector_distance + cached_prompt = results[0]["prompt"] + + # check similarity, if more than self.similarity_threshold, return results + print_verbose( + f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" + ) + if similarity > self.similarity_threshold: + # cache hit ! + cached_value = results[0]["response"] + print_verbose( + f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" + ) + return self._get_cache_logic(cached_response=cached_value) + else: + # cache miss ! + return None + + pass + + async def async_set_cache(self, key, value, **kwargs): + import numpy as np + + from litellm.proxy.proxy_server import llm_model_list, llm_router + + try: + await self.index.acreate(overwrite=False) # don't overwrite existing index + except Exception as e: + print_verbose(f"Got exception creating semantic cache index: {str(e)}") + print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}") + + # get the prompt + messages = kwargs["messages"] + prompt = "".join(message["content"] for message in messages) + # create an embedding for prompt + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if llm_router is not None and self.embedding_model in router_model_names: + user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") + embedding_response = await llm_router.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + metadata={ + "user_api_key": user_api_key, + "semantic-cache-embedding": True, + "trace_id": kwargs.get("metadata", {}).get("trace_id", None), + }, + ) + else: + # convert to embedding + embedding_response = await litellm.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + # make the embedding a numpy array, convert to bytes + embedding_bytes = np.array(embedding, dtype=np.float32).tobytes() + value = str(value) + assert isinstance(value, str) + + new_data = [ + {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes} + ] + + # Add more data + await self.index.aload(new_data) + return + + async def async_get_cache(self, key, **kwargs): + print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}") + from redisvl.query import VectorQuery + + from litellm.proxy.proxy_server import llm_model_list, llm_router + + # query + # get the messages + messages = kwargs["messages"] + prompt = "".join(message["content"] for message in messages) + + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if llm_router is not None and self.embedding_model in router_model_names: + user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") + embedding_response = await llm_router.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + metadata={ + "user_api_key": user_api_key, + "semantic-cache-embedding": True, + "trace_id": kwargs.get("metadata", {}).get("trace_id", None), + }, + ) + else: + # convert to embedding + embedding_response = await litellm.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + + # get the embedding + embedding = embedding_response["data"][0]["embedding"] + + query = VectorQuery( + vector=embedding, + vector_field_name="litellm_embedding", + return_fields=["response", "prompt", "vector_distance"], + ) + results = await self.index.aquery(query) + if results is None: + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 + return None + if isinstance(results, list): + if len(results) == 0: + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 + return None + + vector_distance = results[0]["vector_distance"] + vector_distance = float(vector_distance) + similarity = 1 - vector_distance + cached_prompt = results[0]["prompt"] + + # check similarity, if more than self.similarity_threshold, return results + print_verbose( + f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" + ) + + # update kwargs["metadata"] with similarity, don't rewrite the original metadata + kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity + + if similarity > self.similarity_threshold: + # cache hit ! + cached_value = results[0]["response"] + print_verbose( + f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" + ) + return self._get_cache_logic(cached_response=cached_value) + else: + # cache miss ! + return None + pass + + async def _index_info(self): + return await self.index.ainfo() + + async def async_set_cache_pipeline(self, cache_list, **kwargs): + tasks = [] + for val in cache_list: + tasks.append(self.async_set_cache(val[0], val[1], **kwargs)) + await asyncio.gather(*tasks) |