about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two versions of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/providers/embedding.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/embedding.py197
1 files changed, 197 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
new file mode 100644
index 00000000..d1f9f9d6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
@@ -0,0 +1,197 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import VectorQuantizationSettings
+
+from ..abstractions import (
+    ChunkSearchResult,
+    EmbeddingPurpose,
+    default_embedding_prefixes,
+)
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
class EmbeddingConfig(ProviderConfig):
    """Configuration for an embedding provider.

    Declarative settings (presumably a pydantic model inherited from
    ``ProviderConfig`` — TODO confirm) covering the base embedding model,
    optional reranking, batching, request concurrency, and retry/backoff
    behavior. Validated against ``supported_providers`` by
    ``validate_config``.
    """

    # Backend identifier; must be one of `supported_providers`.
    provider: str
    # Name of the embedding model to use.
    base_model: str
    # Output dimensionality of the base model. Why float is accepted is not
    # visible here — presumably for configs that store it as a JSON number;
    # TODO confirm.
    base_dimension: int | float
    # Optional reranking model and endpoint; reranking is skipped when unset.
    rerank_model: Optional[str] = None
    rerank_url: Optional[str] = None
    # Number of texts sent per embedding request.
    batch_size: int = 1
    # Per-purpose prompt prefixes, keyed by purpose name (consumed by
    # EmbeddingProvider.set_prefixes).
    prefixes: Optional[dict[str, str]] = None
    add_title_as_prefix: bool = True
    # Upper bound on in-flight async requests (used as a semaphore limit by
    # EmbeddingProvider).
    concurrent_request_limit: int = 256
    # Retry policy: jittered exponential backoff starting at
    # `initial_backoff` seconds, doubling up to `max_backoff`.
    max_retries: int = 3
    initial_backoff: float = 1
    max_backoff: float = 64.0
    # Shared default instance is safe under pydantic (defaults are copied
    # per instance) — NOTE(review): verify ProviderConfig is a pydantic model.
    quantization_settings: VectorQuantizationSettings = (
        VectorQuantizationSettings()
    )

    ## deprecated — retained only so older configs still parse.
    rerank_dimension: Optional[int] = None
    rerank_transformer_type: Optional[str] = None

    def validate_config(self) -> None:
        """Raise ValueError if `provider` is not a supported backend."""
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        # Closed set of embedding backends this config accepts.
        return ["litellm", "openai", "ollama"]
+
+
class EmbeddingProvider(Provider):
    """Abstract base class for embedding providers.

    Supplies concurrency limiting (an asyncio semaphore sized by
    ``config.concurrent_request_limit``) and retry with jittered exponential
    backoff around the provider-specific hooks ``_execute_task`` /
    ``_execute_task_sync``. Concrete subclasses implement those hooks plus
    ``rerank`` / ``arerank``.
    """

    class Step(Enum):
        """Pipeline stage a request belongs to."""

        BASE = 1  # initial embedding computation
        RERANK = 2  # reranking of search results

    def __init__(self, config: EmbeddingConfig):
        """Validate and store the config; set up the request semaphore.

        Raises:
            ValueError: If ``config`` is not an ``EmbeddingConfig``.
        """
        if not isinstance(config, EmbeddingConfig):
            raise ValueError(
                "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
            )
        logger.info(f"Initializing EmbeddingProvider with config {config}.")

        super().__init__(config)
        self.config: EmbeddingConfig = config
        # Caps the number of in-flight async embedding requests.
        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
        self.current_requests = 0

    async def _execute_with_backoff_async(self, task: dict[str, Any]):
        """Run ``_execute_task`` with retries and jittered exponential backoff.

        Authentication failures are re-raised immediately (retrying cannot
        help); any other exception is retried up to ``config.max_retries``
        times, after which the last exception propagates.

        Raises:
            RuntimeError: If ``config.max_retries`` is less than 1, so that a
                misconfiguration fails loudly instead of silently returning
                ``None`` (the previous behavior).
        """
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                async with self.semaphore:
                    return await self._execute_task(task)
            except AuthenticationError:
                # Credentials are wrong; backing off will not fix this.
                raise
            except Exception as e:
                # Lazy %-style args skip formatting when WARNING is disabled.
                logger.warning(
                    "Request failed (attempt %d): %s", retries + 1, e
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                # Full jitter: sleep a random duration in [0, backoff].
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)
        # Unreachable for max_retries >= 1 (the loop either returns or
        # re-raises); guards against silently returning None otherwise.
        raise RuntimeError(
            "EmbeddingConfig.max_retries must be >= 1; no attempt was made."
        )

    def _execute_with_backoff_sync(self, task: dict[str, Any]):
        """Synchronous twin of ``_execute_with_backoff_async``.

        Same retry/backoff policy, but calls ``_execute_task_sync`` and does
        not go through the async semaphore (sync callers are not
        concurrency-limited here).

        Raises:
            RuntimeError: If ``config.max_retries`` is less than 1 (see the
                async variant).
        """
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                return self._execute_task_sync(task)
            except AuthenticationError:
                raise
            except Exception as e:
                logger.warning(
                    "Request failed (attempt %d): %s", retries + 1, e
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)
        raise RuntimeError(
            "EmbeddingConfig.max_retries must be >= 1; no attempt was made."
        )

    @abstractmethod
    async def _execute_task(self, task: dict[str, Any]):
        """Perform one embedding request (async). Implemented by subclasses."""
        pass

    @abstractmethod
    def _execute_task_sync(self, task: dict[str, Any]):
        """Perform one embedding request (sync). Implemented by subclasses."""
        pass

    async def async_get_embedding(
        self,
        text: str,
        stage: Step = Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
    ):
        """Embed a single text asynchronously, with retry/backoff."""
        task = {
            "text": text,
            "stage": stage,
            "purpose": purpose,
        }
        return await self._execute_with_backoff_async(task)

    def get_embedding(
        self,
        text: str,
        stage: Step = Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
    ):
        """Embed a single text synchronously, with retry/backoff."""
        task = {
            "text": text,
            "stage": stage,
            "purpose": purpose,
        }
        return self._execute_with_backoff_sync(task)

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: Step = Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
    ):
        """Embed a batch of texts asynchronously, with retry/backoff."""
        task = {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
        }
        return await self._execute_with_backoff_async(task)

    def get_embeddings(
        self,
        texts: list[str],
        stage: Step = Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
    ) -> list[list[float]]:
        """Embed a batch of texts synchronously, with retry/backoff."""
        task = {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
        }
        return self._execute_with_backoff_sync(task)

    @abstractmethod
    def rerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: Step = Step.RERANK,
        limit: int = 10,
    ):
        """Rerank ``results`` against ``query`` (sync). Subclass hook."""
        pass

    @abstractmethod
    async def arerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: Step = Step.RERANK,
        limit: int = 10,
    ):
        """Rerank ``results`` against ``query`` (async). Subclass hook."""
        pass

    def set_prefixes(self, config_prefixes: dict[str, str], base_model: str):
        """Build ``self.prefixes`` from config values plus model defaults.

        Config-supplied entries win; defaults for ``base_model`` (from
        ``default_embedding_prefixes``) fill the gaps.

        NOTE(review): config entries are keyed by ``EmbeddingPurpose`` while
        default entries are inserted under their original key ``t`` — if
        ``default_embedding_prefixes`` uses plain strings, the two key types
        diverge and the ``t not in self.prefixes`` dedup test never matches a
        config entry. Verify the key type of ``default_embedding_prefixes``.
        """
        self.prefixes = {}

        # ValueError here (from EmbeddingPurpose) surfaces bad config keys.
        for t, p in config_prefixes.items():
            purpose = EmbeddingPurpose(t.lower())
            self.prefixes[purpose] = p

        if base_model in default_embedding_prefixes:
            for t, p in default_embedding_prefixes[base_model].items():
                if t not in self.prefixes:
                    self.prefixes[t] = p