path: root/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py
author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  gn-ai-master.tar.gz
two version of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py  337
1 file changed, 337 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py
new file mode 100644
index 00000000..b609286a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py
@@ -0,0 +1,337 @@
+"""
+Redis Semantic Cache implementation
+
+Has 4 methods:
+ - set_cache
+ - get_cache
+ - async_set_cache
+ - async_get_cache
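+
+Illustrative usage (a sketch only; the Redis URL, similarity threshold, and
+messages below are assumptions, and a running Redis Stack plus embedding
+provider credentials are required):
+
+    cache = RedisSemanticCache(
+        redis_url="redis://:mypassword@localhost:6379",
+        similarity_threshold=0.8,
+    )
+    cache.set_cache(
+        key=None,
+        value='{"content": "cached response"}',
+        messages=[{"role": "user", "content": "hello"}],
+    )
+    cached = cache.get_cache(
+        key=None,
+        messages=[{"role": "user", "content": "hello"}],
+    )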
+"""
+
+import ast
+import asyncio
+import json
+from typing import Any
+
+import litellm
+from litellm._logging import print_verbose
+
+from .base_cache import BaseCache
+
+
+class RedisSemanticCache(BaseCache):
+ def __init__(
+ self,
+ host=None,
+ port=None,
+ password=None,
+ redis_url=None,
+ similarity_threshold=None,
+ use_async=False,
+ embedding_model="text-embedding-ada-002",
+ **kwargs,
+ ):
+ from redisvl.index import SearchIndex
+
+ print_verbose(
+ "redis semantic-cache initializing INDEX - litellm_semantic_cache_index"
+ )
+ if similarity_threshold is None:
+ raise Exception("similarity_threshold must be provided, passed None")
+ self.similarity_threshold = similarity_threshold
+ self.embedding_model = embedding_model
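+        # redisvl schema: a Redis hash index named "litellm_semantic_cache_index"
+        # with a text field for the cached response and a 1536-dim float32 FLAT
+        # cosine vector field for the prompt embedding (1536 matches the output
+        # size of text-embedding-ada-002)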
+ schema = {
+ "index": {
+ "name": "litellm_semantic_cache_index",
+ "prefix": "litellm",
+ "storage_type": "hash",
+ },
+ "fields": {
+ "text": [{"name": "response"}],
+ "vector": [
+ {
+ "name": "litellm_embedding",
+ "dims": 1536,
+ "distance_metric": "cosine",
+ "algorithm": "flat",
+ "datatype": "float32",
+ }
+ ],
+ },
+ }
+ if redis_url is None:
+            # no URL passed: fall back to host/port/password args, then env vars
+ if host is None or port is None or password is None:
+ # try checking env for host, port and password
+ import os
+
+ host = os.getenv("REDIS_HOST")
+ port = os.getenv("REDIS_PORT")
+ password = os.getenv("REDIS_PASSWORD")
+ if host is None or port is None or password is None:
+ raise Exception("Redis host, port, and password must be provided")
+
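+            # build a connection URL of the form redis://:<password>@<host>:<port>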
+ redis_url = "redis://:" + password + "@" + host + ":" + port
+ print_verbose(f"redis semantic-cache redis_url: {redis_url}")
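+        # sync mode creates the index eagerly; async mode uses a separate
+        # "_async" index name and defers creation to async_set_cache (acreate)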
+ if use_async is False:
+ self.index = SearchIndex.from_dict(schema)
+ self.index.connect(redis_url=redis_url)
+ try:
+ self.index.create(overwrite=False) # don't overwrite existing index
+ except Exception as e:
+ print_verbose(f"Got exception creating semantic cache index: {str(e)}")
+ elif use_async is True:
+ schema["index"]["name"] = "litellm_semantic_cache_index_async"
+ self.index = SearchIndex.from_dict(schema)
+ self.index.connect(redis_url=redis_url, use_async=True)
+
+ def _get_cache_logic(self, cached_response: Any):
+ """
+ Common 'get_cache_logic' across sync + async redis client implementations
+ """
+ if cached_response is None:
+ return cached_response
+
+ # check if cached_response is bytes
+ if isinstance(cached_response, bytes):
+ cached_response = cached_response.decode("utf-8")
+
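+        # cached values are stored via str(value), so they may be Python reprs
+        # (single-quoted) rather than valid JSON; fall back to ast.literal_eval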
+ try:
+ cached_response = json.loads(
+ cached_response
+ ) # Convert string to dictionary
+ except Exception:
+ cached_response = ast.literal_eval(cached_response)
+ return cached_response
+
+ def set_cache(self, key, value, **kwargs):
+ import numpy as np
+
+ print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}")
+
+ # get the prompt
+ messages = kwargs["messages"]
+ prompt = "".join(message["content"] for message in messages)
+
+ # create an embedding for prompt
+ embedding_response = litellm.embedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ )
+
+ # get the embedding
+ embedding = embedding_response["data"][0]["embedding"]
+
+ # make the embedding a numpy array, convert to bytes
+ embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
+ value = str(value)
+ assert isinstance(value, str)
+
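+        # each cache entry is a Redis hash holding the raw response string, the
+        # flattened prompt, and the prompt embedding as float32 bytes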
+ new_data = [
+ {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
+ ]
+
+        # load the new entry into the semantic cache index
+ self.index.load(new_data)
+
+ return
+
+ def get_cache(self, key, **kwargs):
+ print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
+ from redisvl.query import VectorQuery
+
+        # reconstruct the prompt from the request messages
+ messages = kwargs["messages"]
+ prompt = "".join(message["content"] for message in messages)
+
+ # convert to embedding
+ embedding_response = litellm.embedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ )
+
+ # get the embedding
+ embedding = embedding_response["data"][0]["embedding"]
+
+ query = VectorQuery(
+ vector=embedding,
+ vector_field_name="litellm_embedding",
+ return_fields=["response", "prompt", "vector_distance"],
+ num_results=1,
+ )
+
+ results = self.index.query(query)
+ if results is None:
+ return None
+ if isinstance(results, list):
+ if len(results) == 0:
+ return None
+
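+        # redisvl returns cosine *distance* (1 - cosine similarity), so
+        # similarity = 1 - distance, e.g. a distance of 0.1 means similarity 0.9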
+ vector_distance = results[0]["vector_distance"]
+ vector_distance = float(vector_distance)
+ similarity = 1 - vector_distance
+ cached_prompt = results[0]["prompt"]
+
+        # if similarity exceeds self.similarity_threshold, treat it as a cache hit
+ print_verbose(
+ f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+ )
+ if similarity > self.similarity_threshold:
+ # cache hit !
+ cached_value = results[0]["response"]
+ print_verbose(
+ f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+ )
+ return self._get_cache_logic(cached_response=cached_value)
+ else:
+ # cache miss !
+ return None
+
+ async def async_set_cache(self, key, value, **kwargs):
+ import numpy as np
+
+ from litellm.proxy.proxy_server import llm_model_list, llm_router
+
+ try:
+ await self.index.acreate(overwrite=False) # don't overwrite existing index
+ except Exception as e:
+ print_verbose(f"Got exception creating semantic cache index: {str(e)}")
+ print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")
+
+ # get the prompt
+ messages = kwargs["messages"]
+ prompt = "".join(message["content"] for message in messages)
+ # create an embedding for prompt
+ router_model_names = (
+ [m["model_name"] for m in llm_model_list]
+ if llm_model_list is not None
+ else []
+ )
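+        # if the proxy router serves the embedding model, embed through the
+        # router so user_api_key / trace_id metadata is propagated; otherwise
+        # call litellm.aembedding directly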
+ if llm_router is not None and self.embedding_model in router_model_names:
+ user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
+ embedding_response = await llm_router.aembedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ metadata={
+ "user_api_key": user_api_key,
+ "semantic-cache-embedding": True,
+ "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
+ },
+ )
+ else:
+ # convert to embedding
+ embedding_response = await litellm.aembedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ )
+
+ # get the embedding
+ embedding = embedding_response["data"][0]["embedding"]
+
+ # make the embedding a numpy array, convert to bytes
+ embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
+ value = str(value)
+ assert isinstance(value, str)
+
+ new_data = [
+ {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
+ ]
+
+        # load the new entry into the semantic cache index
+ await self.index.aload(new_data)
+ return
+
+ async def async_get_cache(self, key, **kwargs):
+ print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
+ from redisvl.query import VectorQuery
+
+ from litellm.proxy.proxy_server import llm_model_list, llm_router
+
+        # reconstruct the prompt from the request messages
+ messages = kwargs["messages"]
+ prompt = "".join(message["content"] for message in messages)
+
+ router_model_names = (
+ [m["model_name"] for m in llm_model_list]
+ if llm_model_list is not None
+ else []
+ )
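+        # same routing logic as async_set_cache: prefer the proxy router when
+        # it serves the embedding model, else call litellm.aembedding directly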
+ if llm_router is not None and self.embedding_model in router_model_names:
+ user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
+ embedding_response = await llm_router.aembedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ metadata={
+ "user_api_key": user_api_key,
+ "semantic-cache-embedding": True,
+ "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
+ },
+ )
+ else:
+ # convert to embedding
+ embedding_response = await litellm.aembedding(
+ model=self.embedding_model,
+ input=prompt,
+ cache={"no-store": True, "no-cache": True},
+ )
+
+ # get the embedding
+ embedding = embedding_response["data"][0]["embedding"]
+
+        query = VectorQuery(
+            vector=embedding,
+            vector_field_name="litellm_embedding",
+            return_fields=["response", "prompt", "vector_distance"],
+            num_results=1,
+        )
+ results = await self.index.aquery(query)
+ if results is None:
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+ return None
+ if isinstance(results, list):
+ if len(results) == 0:
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+ return None
+
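+        # convert the returned cosine distance to a similarity score, mirroring
+        # get_cache; the score is also recorded in kwargs["metadata"] below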
+ vector_distance = results[0]["vector_distance"]
+ vector_distance = float(vector_distance)
+ similarity = 1 - vector_distance
+ cached_prompt = results[0]["prompt"]
+
+        # if similarity exceeds self.similarity_threshold, treat it as a cache hit
+ print_verbose(
+ f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+ )
+
+ # update kwargs["metadata"] with similarity, don't rewrite the original metadata
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
+
+ if similarity > self.similarity_threshold:
+ # cache hit !
+ cached_value = results[0]["response"]
+ print_verbose(
+ f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+ )
+ return self._get_cache_logic(cached_response=cached_value)
+ else:
+ # cache miss !
+ return None
+
+ async def _index_info(self):
+ return await self.index.ainfo()
+
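+    # fan the (key, value) pairs out to async_set_cache concurrently; the same
+    # kwargs (e.g. messages) are passed to every entry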
+ async def async_set_cache_pipeline(self, cache_list, **kwargs):
+ tasks = []
+ for val in cache_list:
+ tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
+ await asyncio.gather(*tasks)