about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py88
1 files changed, 88 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py b/.venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py
new file mode 100644
index 00000000..abf3203f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/caching/disk_cache.py
@@ -0,0 +1,88 @@
+import json
+from typing import TYPE_CHECKING, Any, Optional
+
+from .base_cache import BaseCache
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
+
+
+class DiskCache(BaseCache):
+    def __init__(self, disk_cache_dir: Optional[str] = None):
+        import diskcache as dc
+
+        # if users don't provider one, use the default litellm cache
+        if disk_cache_dir is None:
+            self.disk_cache = dc.Cache(".litellm_cache")
+        else:
+            self.disk_cache = dc.Cache(disk_cache_dir)
+
+    def set_cache(self, key, value, **kwargs):
+        if "ttl" in kwargs:
+            self.disk_cache.set(key, value, expire=kwargs["ttl"])
+        else:
+            self.disk_cache.set(key, value)
+
+    async def async_set_cache(self, key, value, **kwargs):
+        self.set_cache(key=key, value=value, **kwargs)
+
+    async def async_set_cache_pipeline(self, cache_list, **kwargs):
+        for cache_key, cache_value in cache_list:
+            if "ttl" in kwargs:
+                self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
+            else:
+                self.set_cache(key=cache_key, value=cache_value)
+
+    def get_cache(self, key, **kwargs):
+        original_cached_response = self.disk_cache.get(key)
+        if original_cached_response:
+            try:
+                cached_response = json.loads(original_cached_response)  # type: ignore
+            except Exception:
+                cached_response = original_cached_response
+            return cached_response
+        return None
+
+    def batch_get_cache(self, keys: list, **kwargs):
+        return_val = []
+        for k in keys:
+            val = self.get_cache(key=k, **kwargs)
+            return_val.append(val)
+        return return_val
+
+    def increment_cache(self, key, value: int, **kwargs) -> int:
+        # get the value
+        init_value = self.get_cache(key=key) or 0
+        value = init_value + value  # type: ignore
+        self.set_cache(key, value, **kwargs)
+        return value
+
+    async def async_get_cache(self, key, **kwargs):
+        return self.get_cache(key=key, **kwargs)
+
+    async def async_batch_get_cache(self, keys: list, **kwargs):
+        return_val = []
+        for k in keys:
+            val = self.get_cache(key=k, **kwargs)
+            return_val.append(val)
+        return return_val
+
+    async def async_increment(self, key, value: int, **kwargs) -> int:
+        # get the value
+        init_value = await self.async_get_cache(key=key) or 0
+        value = init_value + value  # type: ignore
+        await self.async_set_cache(key, value, **kwargs)
+        return value
+
+    def flush_cache(self):
+        self.disk_cache.clear()
+
+    async def disconnect(self):
+        pass
+
+    def delete_cache(self, key):
+        self.disk_cache.pop(key)