Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/caching/qdrant_semantic_cache.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/caching/qdrant_semantic_cache.py  430
1 files changed, 430 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/qdrant_semantic_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/qdrant_semantic_cache.py
new file mode 100644
index 00000000..bdfd3770
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/qdrant_semantic_cache.py
@@ -0,0 +1,430 @@
+"""
+Qdrant Semantic Cache implementation
+
+Has 4 methods:
+    - set_cache
+    - get_cache
+    - async_set_cache
+    - async_get_cache
+"""
+
+import ast
+import asyncio
+import json
+from typing import Any
+
+import litellm
+from litellm._logging import print_verbose
+
+from .base_cache import BaseCache
+
+
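+# A minimal usage sketch for direct instantiation (illustrative only; assumes a
+# reachable Qdrant instance and credentials for the configured embedding model):
+#
+#   cache = QdrantSemanticCache(
+#       qdrant_api_base="http://localhost:6333",
+#       collection_name="llm-semantic-cache",
+#       similarity_threshold=0.8,
+#   )
+#   messages = [{"role": "user", "content": "What is semantic caching?"}]
+#   cache.set_cache(key="example-key", value=str(response), messages=messages)
+#   cached = cache.get_cache(key="example-key", messages=messages)
+#
+# The key argument is accepted for interface compatibility; lookups here are
+# purely semantic (vector similarity over the prompt text).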
+class QdrantSemanticCache(BaseCache):
+    def __init__(  # noqa: PLR0915
+        self,
+        qdrant_api_base=None,
+        qdrant_api_key=None,
+        collection_name=None,
+        similarity_threshold=None,
+        quantization_config=None,
+        embedding_model="text-embedding-ada-002",
+        host_type=None,
+    ):
+        import os
+
+        from litellm.llms.custom_httpx.http_handler import (
+            _get_httpx_client,
+            get_async_httpx_client,
+            httpxSpecialProvider,
+        )
+        from litellm.secret_managers.main import get_secret_str
+
+        if collection_name is None:
+            raise Exception("collection_name must be provided, got None")
+
+        self.collection_name = collection_name
+        print_verbose(
+            f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
+        )
+
+        if similarity_threshold is None:
+            raise Exception("similarity_threshold must be provided, got None")
+        self.similarity_threshold = similarity_threshold
+        self.embedding_model = embedding_model
+        headers = {}
+
+        # check if defined as an os.environ/ variable
+        if qdrant_api_base:
+            if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
+                "os.environ/"
+            ):
+                qdrant_api_base = get_secret_str(qdrant_api_base)
+        if qdrant_api_key:
+            if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
+                "os.environ/"
+            ):
+                qdrant_api_key = get_secret_str(qdrant_api_key)
+
+        qdrant_api_base = (
+            qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
+        )
+        qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
+        headers = {"Content-Type": "application/json"}
+        if qdrant_api_key:
+            headers["api-key"] = qdrant_api_key
+
+        if qdrant_api_base is None:
+            raise ValueError("Qdrant url must be provided")
+
+        self.qdrant_api_base = qdrant_api_base
+        self.qdrant_api_key = qdrant_api_key
+        print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")
+
+        self.headers = headers
+
+        self.sync_client = _get_httpx_client()
+        self.async_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.Caching
+        )
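+        # The sync client handles the collection checks below and the sync
+        # set_cache/get_cache paths; the async client backs the async cache methods.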
+
+        if quantization_config is None:
+            print_verbose(
+                "Quantization config is not provided. Default binary quantization will be used."
+            )
+        collection_exists = self.sync_client.get(
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
+            headers=self.headers,
+        )
+        if collection_exists.status_code != 200:
+            raise ValueError(
+                f"Error from Qdrant checking if the collection exists: {collection_exists.text}"
+            )
+
+        if collection_exists.json()["result"]["exists"]:
+            collection_details = self.sync_client.get(
+                url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
+                headers=self.headers,
+            )
+            self.collection_info = collection_details.json()
+            print_verbose(
+                f"Collection already exists.\nCollection details:{self.collection_info}"
+            )
+        else:
+            if quantization_config is None or quantization_config == "binary":
+                quantization_params = {
+                    "binary": {
+                        "always_ram": False,
+                    }
+                }
+            elif quantization_config == "scalar":
+                quantization_params = {
+                    "scalar": {"type": "int8", "quantile": 0.99, "always_ram": False}
+                }
+            elif quantization_config == "product":
+                quantization_params = {
+                    "product": {"compression": "x16", "always_ram": False}
+                }
+            else:
+                raise Exception(
+                    "Quantization config must be one of 'scalar', 'binary' or 'product'"
+                )
+
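+            # NOTE: the collection is created with a fixed vector size of 1536,
+            # which matches the default "text-embedding-ada-002" model; a
+            # different embedding_model may require a different vector size.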
+            new_collection_status = self.sync_client.put(
+                url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
+                json={
+                    "vectors": {"size": 1536, "distance": "Cosine"},
+                    "quantization_config": quantization_params,
+                },
+                headers=self.headers,
+            )
+            if new_collection_status.json()["result"]:
+                collection_details = self.sync_client.get(
+                    url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
+                    headers=self.headers,
+                )
+                self.collection_info = collection_details.json()
+                print_verbose(
+                    f"New collection created.\nCollection details:{self.collection_info}"
+                )
+            else:
+                raise Exception("Error while creating new collection")
+
+    def _get_cache_logic(self, cached_response: Any):
+        if cached_response is None:
+            return cached_response
+        try:
+            cached_response = json.loads(
+                cached_response
+            )  # Convert string to dictionary
+        except Exception:
+            cached_response = ast.literal_eval(cached_response)
+        return cached_response
+
+    def set_cache(self, key, value, **kwargs):
+        print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
+        import uuid
+
+        # get the prompt
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        # create an embedding for prompt
+        embedding_response = litellm.embedding(
+            model=self.embedding_model,
+            input=prompt,
+            cache={"no-store": True, "no-cache": True},
+        )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        value = str(value)
+        assert isinstance(value, str)
+
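+        # The response is stored as a string payload alongside the prompt text;
+        # _get_cache_logic converts it back into a dict on read.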
+        data = {
+            "points": [
+                {
+                    "id": str(uuid.uuid4()),
+                    "vector": embedding,
+                    "payload": {
+                        "text": prompt,
+                        "response": value,
+                    },
+                },
+            ]
+        }
+        self.sync_client.put(
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
+            headers=self.headers,
+            json=data,
+        )
+        return
+
+    def get_cache(self, key, **kwargs):
+        print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")
+
+        # get the messages
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        # convert to embedding
+        embedding_response = litellm.embedding(
+            model=self.embedding_model,
+            input=prompt,
+            cache={"no-store": True, "no-cache": True},
+        )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
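+        # Single nearest-neighbour search; quantized vectors are oversampled and
+        # rescored against the original vectors, and the payload (prompt text +
+        # cached response) is returned with the hit.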
+        data = {
+            "vector": embedding,
+            "params": {
+                "quantization": {
+                    "ignore": False,
+                    "rescore": True,
+                    "oversampling": 3.0,
+                }
+            },
+            "limit": 1,
+            "with_payload": True,
+        }
+
+        search_response = self.sync_client.post(
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
+            headers=self.headers,
+            json=data,
+        )
+        results = search_response.json()["result"]
+
+        if results is None:
+            return None
+        if isinstance(results, list):
+            if len(results) == 0:
+                return None
+
+        similarity = results[0]["score"]
+        cached_prompt = results[0]["payload"]["text"]
+
+        # check similarity; if it meets self.similarity_threshold, return the cached response
+        print_verbose(
+            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+        )
+        if similarity >= self.similarity_threshold:
+            # cache hit !
+            cached_value = results[0]["payload"]["response"]
+            print_verbose(
+                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+            )
+            return self._get_cache_logic(cached_response=cached_value)
+        else:
+            # cache miss !
+            return None
+
+    async def async_set_cache(self, key, value, **kwargs):
+        import uuid
+
+        from litellm.proxy.proxy_server import llm_model_list, llm_router
+
+        print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
+
+        # get the prompt
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+        # create an embedding for prompt
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
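+        # If the embedding model is served by the proxy's router, go through the
+        # router (propagating the caller's API key and trace id for tracking);
+        # otherwise call litellm.aembedding directly.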
+        if llm_router is not None and self.embedding_model in router_model_names:
+            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+                metadata={
+                    "user_api_key": user_api_key,
+                    "semantic-cache-embedding": True,
+                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
+                },
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        value = str(value)
+        assert isinstance(value, str)
+
+        data = {
+            "points": [
+                {
+                    "id": str(uuid.uuid4()),
+                    "vector": embedding,
+                    "payload": {
+                        "text": prompt,
+                        "response": value,
+                    },
+                },
+            ]
+        }
+
+        await self.async_client.put(
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
+            headers=self.headers,
+            json=data,
+        )
+        return
+
+    async def async_get_cache(self, key, **kwargs):
+        print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
+        from litellm.proxy.proxy_server import llm_model_list, llm_router
+
+        # get the messages
+        messages = kwargs["messages"]
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+                metadata={
+                    "user_api_key": user_api_key,
+                    "semantic-cache-embedding": True,
+                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
+                },
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+
+        # get the embedding
+        embedding = embedding_response["data"][0]["embedding"]
+
+        data = {
+            "vector": embedding,
+            "params": {
+                "quantization": {
+                    "ignore": False,
+                    "rescore": True,
+                    "oversampling": 3.0,
+                }
+            },
+            "limit": 1,
+            "with_payload": True,
+        }
+
+        search_response = await self.async_client.post(
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
+            headers=self.headers,
+            json=data,
+        )
+
+        results = search_response.json()["result"]
+
+        if results is None:
+            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+            return None
+        if isinstance(results, list):
+            if len(results) == 0:
+                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
+                return None
+
+        similarity = results[0]["score"]
+        cached_prompt = results[0]["payload"]["text"]
+
+        # check similarity; if it meets self.similarity_threshold, return the cached response
+        print_verbose(
+            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
+        )
+
+        # update kwargs["metadata"] with similarity, don't rewrite the original metadata
+        kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
+
+        if similarity >= self.similarity_threshold:
+            # cache hit !
+            cached_value = results[0]["payload"]["response"]
+            print_verbose(
+                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
+            )
+            return self._get_cache_logic(cached_response=cached_value)
+        else:
+            # cache miss !
+            return None
+
+    async def _collection_info(self):
+        return self.collection_info
+
+    async def async_set_cache_pipeline(self, cache_list, **kwargs):
+        tasks = []
+        for val in cache_list:
+            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
+        await asyncio.gather(*tasks)