about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py337
1 files changed, 337 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py
new file mode 100644
index 00000000..b609286a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/redis_semantic_cache.py
@@ -0,0 +1,337 @@
+"""
+Redis Semantic Cache implementation
+
+Has 4 methods:
+    - set_cache
+    - get_cache
+    - async_set_cache
+    - async_get_cache
+"""
+
+import ast
+import asyncio
+import json
+from typing import Any
+
+import litellm
+from litellm._logging import print_verbose
+
+from .base_cache import BaseCache
+
+
+class RedisSemanticCache(BaseCache):
+    def __init__(
+        self,
+        host=None,
+        port=None,
+        password=None,
+        redis_url=None,
+        similarity_threshold=None,
+        use_async=False,
+        embedding_model="text-embedding-ada-002",
+        **kwargs,
+    ):
+        from redisvl.index import SearchIndex
+
+        print_verbose(
+            "redis semantic-cache initializing INDEX - litellm_semantic_cache_index"
+        )
+        if similarity_threshold is None:
+            raise Exception("similarity_threshold must be provided, passed None")
+        self.similarity_threshold = similarity_threshold
+        self.embedding_model = embedding_model
+        schema = {
+            "index": {
+                "name": "litellm_semantic_cache_index",
+                "prefix": "litellm",
+                "storage_type": "hash",
+            },
+            "fields": {
+                "text": [{"name": "response"}],
+                "vector": [
+                    {
+                        "name": "litellm_embedding",
+                        "dims": 1536,
+                        "distance_metric": "cosine",
+                        "algorithm": "flat",
+                        "datatype": "float32",
+                    }
+                ],
+            },
+        }
+        if redis_url is None:
+            # if no url passed, check if host, port and password are passed, if not raise an Exception
+            if host is None or port is None or password is None:
+                # try checking env for host, port and password
+                import os
+
+                host = os.getenv("REDIS_HOST")
+                port = os.getenv("REDIS_PORT")
+                password = os.getenv("REDIS_PASSWORD")
+                if host is None or port is None or password is None:
+                    raise Exception("Redis host, port, and password must be provided")
+
+            redis_url = "redis://:" + password + "@" + host + ":" + port
+        print_verbose(f"redis semantic-cache redis_url: {redis_url}")
+        if use_async is False:
+            self.index = SearchIndex.from_dict(schema)
+            self.index.connect(redis_url=redis_url)
+            try:
+                self.index.create(overwrite=False)  # don't overwrite existing index
+            except Exception as e:
+                print_verbose(f"Got exception creating semantic cache index: {str(e)}")
+        elif use_async is True:
+            schema["index"]["name"] = "litellm_semantic_cache_index_async"
+            self.index = SearchIndex.from_dict(schema)
+            self.index.connect(redis_url=redis_url, use_async=True)
+
+    #
+    def _get_cache_logic(self, cached_response: Any):
+        """
+        Common 'get_cache_logic' across sync + async redis client implementations
+        """
+        if cached_response is None:
+            return cached_response
+
+        # check if cached_response is bytes
+        if isinstance(cached_response, bytes):
+            cached_response = cached_response.decode("utf-8")
+
+        try:
+            cached_response = json.loads(
+                cached_response
+            )  # Convert string to dictionary
+        except Exception:
+            cached_response = ast.literal_eval(cached_response)
+        return cached_response
+
+    def set_cache(self, key, value, **kwargs):
+        import numpy as np
+
+        print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}")
+
+        # get the prompt
+        messages = kwargs["messages"]
+        prompt = "".join(message["content"] for message in messages)
+
+        # create an embedding for prompt
+        embedding_response = litellm.embedding(
+            model=self.embedding_model,
+            input=prompt,
+            cache={"no-store": True, "no-cache": True},
+        )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        # make the embedding a numpy array, convert to bytes
+        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
+        value = str(value)
+        assert isinstance(value, str)
+
+        new_data = [
+            {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
+        ]
+
+        # Add more data
+        self.index.load(new_data)
+
+        return
+
+    def get_cache(self, key, **kwargs):
+        print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
+        from redisvl.query import VectorQuery
+
+        # query
+        # get the messages
+        messages = kwargs["messages"]
+        prompt = "".join(message["content"] for message in messages)
+
+        # convert to embedding
+        embedding_response = litellm.embedding(
+            model=self.embedding_model,
+            input=prompt,
+            cache={"no-store": True, "no-cache": True},
+        )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        query = VectorQuery(
+            vector=embedding,
+            vector_field_name="litellm_embedding",
+            return_fields=["response", "prompt", "vector_distance"],
+            num_results=1,
+        )
+
+        results = self.index.query(query)
+        if results is None:
+            return None
+        if isinstance(results, list):
+            if len(results) == 0:
+                return None
+
+        vector_distance = results[0]["vector_distance"]
+        vector_distance = float(vector_distance)
+        similarity = 1 - vector_distance
+        cached_prompt = results[0]["prompt"]
+
+        # check similarity, if more than self.similarity_threshold, return results
+        print_verbose(
+            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+        )
+        if similarity > self.similarity_threshold:
+            # cache hit !
+            cached_value = results[0]["response"]
+            print_verbose(
+                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+            )
+            return self._get_cache_logic(cached_response=cached_value)
+        else:
+            # cache miss !
+            return None
+
+        pass
+
+    async def async_set_cache(self, key, value, **kwargs):
+        import numpy as np
+
+        from litellm.proxy.proxy_server import llm_model_list, llm_router
+
+        try:
+            await self.index.acreate(overwrite=False)  # don't overwrite existing index
+        except Exception as e:
+            print_verbose(f"Got exception creating semantic cache index: {str(e)}")
+        print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")
+
+        # get the prompt
+        messages = kwargs["messages"]
+        prompt = "".join(message["content"] for message in messages)
+        # create an embedding for prompt
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+                metadata={
+                    "user_api_key": user_api_key,
+                    "semantic-cache-embedding": True,
+                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
+                },
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        # make the embedding a numpy array, convert to bytes
+        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
+        value = str(value)
+        assert isinstance(value, str)
+
+        new_data = [
+            {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
+        ]
+
+        # Add more data
+        await self.index.aload(new_data)
+        return
+
+    async def async_get_cache(self, key, **kwargs):
+        print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
+        from redisvl.query import VectorQuery
+
+        from litellm.proxy.proxy_server import llm_model_list, llm_router
+
+        # query
+        # get the messages
+        messages = kwargs["messages"]
+        prompt = "".join(message["content"] for message in messages)
+
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+                metadata={
+                    "user_api_key": user_api_key,
+                    "semantic-cache-embedding": True,
+                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
+                },
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        query = VectorQuery(
+            vector=embedding,
+            vector_field_name="litellm_embedding",
+            return_fields=["response", "prompt", "vector_distance"],
+        )
+        results = await self.index.aquery(query)
+        if results is None:
+            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+            return None
+        if isinstance(results, list):
+            if len(results) == 0:
+                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+                return None
+
+        vector_distance = results[0]["vector_distance"]
+        vector_distance = float(vector_distance)
+        similarity = 1 - vector_distance
+        cached_prompt = results[0]["prompt"]
+
+        # check similarity, if more than self.similarity_threshold, return results
+        print_verbose(
+            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+        )
+
+        # update kwargs["metadata"] with similarity, don't rewrite the original metadata
+        kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
+
+        if similarity > self.similarity_threshold:
+            # cache hit !
+            cached_value = results[0]["response"]
+            print_verbose(
+                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+            )
+            return self._get_cache_logic(cached_response=cached_value)
+        else:
+            # cache miss !
+            return None
+        pass
+
+    async def _index_info(self):
+        return await self.index.ainfo()
+
+    async def async_set_cache_pipeline(self, cache_list, **kwargs):
+        tasks = []
+        for val in cache_list:
+            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
+        await asyncio.gather(*tasks)