aboutsummaryrefslogtreecommitdiff
"""
Redis Semantic Cache implementation

Has 4 methods:
    - set_cache
    - get_cache
    - async_set_cache
    - async_get_cache
"""

import ast
import asyncio
import json
from typing import Any

import litellm
from litellm._logging import print_verbose

from .base_cache import BaseCache


class RedisSemanticCache(BaseCache):
    """
    Semantic cache that stores responses in a RedisVL search index.

    Each entry holds the stringified response, the prompt text and the prompt's
    embedding. A lookup embeds the incoming prompt, fetches the nearest stored
    entry and returns its response only when `1 - vector_distance` is strictly
    greater than `similarity_threshold`.

    The `key` argument of the cache methods is accepted for interface
    compatibility but is not used; entries are matched by embedding only.
    """

    def __init__(
        self,
        host=None,
        port=None,
        password=None,
        redis_url=None,
        similarity_threshold=None,
        use_async=False,
        embedding_model="text-embedding-ada-002",
        **kwargs,
    ):
        """
        Build the search index handle and connect it to Redis.

        Args:
            host: Redis host; falls back to the REDIS_HOST env var when None.
            port: Redis port (str or int); falls back to REDIS_PORT when None.
            password: Redis password; falls back to REDIS_PASSWORD when None.
            redis_url: Full connection URL. When given, host/port/password
                are ignored.
            similarity_threshold: Required. A cached entry is returned only
                when its similarity is strictly greater than this value.
            use_async: Truthy -> connect an async index (created lazily in
                `async_set_cache`). Falsy -> connect a sync index and create
                it right away without overwriting an existing one.
            embedding_model: Model name passed to the litellm embedding calls.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: if `similarity_threshold` is None, or if no
                `redis_url` is given and host/port/password cannot be
                resolved from the arguments or the environment.
        """
        from redisvl.index import SearchIndex

        print_verbose(
            "redis semantic-cache initializing INDEX - litellm_semantic_cache_index"
        )
        if similarity_threshold is None:
            raise ValueError("similarity_threshold must be provided, passed None")
        self.similarity_threshold = similarity_threshold
        self.embedding_model = embedding_model
        schema = {
            "index": {
                "name": "litellm_semantic_cache_index",
                "prefix": "litellm",
                "storage_type": "hash",
            },
            "fields": {
                "text": [{"name": "response"}],
                "vector": [
                    {
                        "name": "litellm_embedding",
                        # NOTE(review): vector size is fixed, so embedding_model
                        # presumably has to produce 1536-dim vectors - confirm
                        # before switching models.
                        "dims": 1536,
                        "distance_metric": "cosine",
                        "algorithm": "flat",
                        "datatype": "float32",
                    }
                ],
            },
        }
        if redis_url is None:
            import os

            # Fall back to the environment only for the pieces that were not
            # passed explicitly, so an explicit argument is never overwritten.
            if host is None:
                host = os.getenv("REDIS_HOST")
            if port is None:
                port = os.getenv("REDIS_PORT")
            if password is None:
                password = os.getenv("REDIS_PASSWORD")
            if host is None or port is None or password is None:
                raise ValueError("Redis host, port, and password must be provided")

            # f-string so an int port (e.g. port=6379) does not raise TypeError.
            redis_url = f"redis://:{password}@{host}:{port}"

        # NOTE(review): this logs the full URL, password included, whenever
        # verbose logging is enabled.
        print_verbose(f"redis semantic-cache redis_url: {redis_url}")

        if use_async:
            schema["index"]["name"] = "litellm_semantic_cache_index_async"
            self.index = SearchIndex.from_dict(schema)
            self.index.connect(redis_url=redis_url, use_async=True)
        else:
            self.index = SearchIndex.from_dict(schema)
            self.index.connect(redis_url=redis_url)
            try:
                self.index.create(overwrite=False)  # don't overwrite existing index
            except Exception as e:
                # best-effort: a creation failure is logged, not raised
                print_verbose(f"Got exception creating semantic cache index: {str(e)}")

    # ------------------------------------------------------------------ #
    # shared helpers
    # ------------------------------------------------------------------ #
    def _get_cache_logic(self, cached_response: Any):
        """
        Common 'get_cache_logic' across sync + async redis client implementations

        Decodes bytes as UTF-8, then parses the text as JSON. Values are stored
        with `str(value)`, so text that is not valid JSON is parsed with
        `ast.literal_eval` instead; if that also fails its exception propagates.
        Returns None unchanged.
        """
        if cached_response is None:
            return cached_response

        # check if cached_response is bytes
        if isinstance(cached_response, bytes):
            cached_response = cached_response.decode("utf-8")

        try:
            cached_response = json.loads(
                cached_response
            )  # Convert string to dictionary
        except Exception:
            cached_response = ast.literal_eval(cached_response)
        return cached_response

    @staticmethod
    def _prompt_from_messages(messages) -> str:
        """Concatenate every message's "content" (expected to be str) into one prompt."""
        return "".join(message["content"] for message in messages)

    @staticmethod
    def _build_record(value, prompt, embedding) -> dict:
        """Build the hash entry stored in the index for one cached response."""
        import numpy as np

        # the vector field is declared float32, so store raw float32 bytes
        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
        return {
            "response": str(value),
            "prompt": prompt,
            "litellm_embedding": embedding_bytes,
        }

    @staticmethod
    def _build_query(embedding):
        """Build a nearest-neighbour query that returns only the closest entry."""
        from redisvl.query import VectorQuery

        return VectorQuery(
            vector=embedding,
            vector_field_name="litellm_embedding",
            return_fields=["response", "prompt", "vector_distance"],
            num_results=1,
        )

    @staticmethod
    def _closest_match(results):
        """Return `(match, similarity)` for the first result, or `(None, 0.0)` if there is none."""
        if not results:
            return None, 0.0
        match = results[0]
        return match, 1 - float(match["vector_distance"])

    @staticmethod
    def _record_similarity(kwargs, similarity) -> None:
        """Write `similarity` into `kwargs["metadata"]` in place (does not replace an existing dict)."""
        metadata = kwargs.get("metadata")
        if metadata is None:
            # missing or explicitly None: start a fresh dict instead of crashing
            metadata = kwargs["metadata"] = {}
        metadata["semantic-similarity"] = similarity

    def _resolve_match(self, match, similarity, prompt):
        """Return the parsed cached response when `similarity` beats the threshold, else None."""
        if match is None:
            return None

        cached_prompt = match["prompt"]

        # check similarity, if more than self.similarity_threshold, return results
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )
        if similarity > self.similarity_threshold:
            # cache hit !
            cached_value = match["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)

        # cache miss !
        return None

    def _embed(self, prompt):
        """Return the embedding of `prompt` via `litellm.embedding` (caching disabled for the call)."""
        embedding_response = litellm.embedding(
            model=self.embedding_model,
            input=prompt,
            cache={"no-store": True, "no-cache": True},
        )
        return embedding_response["data"][0]["embedding"]

    async def _aembed(self, prompt, kwargs):
        """
        Return the embedding of `prompt` asynchronously.

        Uses the proxy's `llm_router` when `self.embedding_model` is one of the
        router's model names, forwarding `user_api_key` and `trace_id` from
        `kwargs["metadata"]`; otherwise calls `litellm.aembedding` directly.
        """
        # imported here rather than at module level, as in the original code
        from litellm.proxy.proxy_server import llm_model_list, llm_router

        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            # metadata may be absent or explicitly None
            metadata = kwargs.get("metadata") or {}
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
                metadata={
                    "user_api_key": metadata.get("user_api_key", ""),
                    "semantic-cache-embedding": True,
                    "trace_id": metadata.get("trace_id", None),
                },
            )
        else:
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )
        return embedding_response["data"][0]["embedding"]

    # ------------------------------------------------------------------ #
    # sync API
    # ------------------------------------------------------------------ #
    def set_cache(self, key, value, **kwargs):
        """
        Store `value` under the embedding of the prompt in `kwargs["messages"]`.

        `key` is unused. Raises KeyError when "messages" is not in kwargs.
        """
        print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}")

        prompt = self._prompt_from_messages(kwargs["messages"])
        embedding = self._embed(prompt)

        # Add more data
        self.index.load([self._build_record(value, prompt, embedding)])
        return

    def get_cache(self, key, **kwargs):
        """
        Return the cached response closest to the prompt in `kwargs["messages"]`.

        Returns None when the index has no entry or the closest entry's
        similarity does not exceed `similarity_threshold`. `key` is unused.
        """
        print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")

        prompt = self._prompt_from_messages(kwargs["messages"])
        embedding = self._embed(prompt)

        results = self.index.query(self._build_query(embedding))
        match, similarity = self._closest_match(results)
        return self._resolve_match(match, similarity, prompt)

    # ------------------------------------------------------------------ #
    # async API
    # ------------------------------------------------------------------ #
    async def async_set_cache(self, key, value, **kwargs):
        """
        Async counterpart of `set_cache`.

        Tries to create the async index first (it is not created in
        `__init__`); a failure there is logged and ignored. `key` is unused.
        """
        try:
            await self.index.acreate(overwrite=False)  # don't overwrite existing index
        except Exception as e:
            print_verbose(f"Got exception creating semantic cache index: {str(e)}")
        print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")

        prompt = self._prompt_from_messages(kwargs["messages"])
        embedding = await self._aembed(prompt, kwargs)

        # Add more data
        await self.index.aload([self._build_record(value, prompt, embedding)])
        return

    async def async_get_cache(self, key, **kwargs):
        """
        Async counterpart of `get_cache`.

        Also writes the computed similarity (0.0 when nothing was found) to
        `kwargs["metadata"]["semantic-similarity"]`. `key` is unused.
        """
        print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")

        prompt = self._prompt_from_messages(kwargs["messages"])
        embedding = await self._aembed(prompt, kwargs)

        results = await self.index.aquery(self._build_query(embedding))
        match, similarity = self._closest_match(results)

        # update kwargs["metadata"] with similarity, don't rewrite the original metadata
        self._record_similarity(kwargs, similarity)

        return self._resolve_match(match, similarity, prompt)

    async def _index_info(self):
        """Return whatever the async index reports from `ainfo()`."""
        return await self.index.ainfo()

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        """
        Store several entries concurrently.

        Each item of `cache_list` is indexed as `(key, value)`; every item is
        written with `async_set_cache` using the same `kwargs`.
        """
        tasks = [
            self.async_set_cache(val[0], val[1], **kwargs) for val in cache_list
        ]
        await asyncio.gather(*tasks)