aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/providers/embedding.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/embedding.py197
1 files changed, 197 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
new file mode 100644
index 00000000..d1f9f9d6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
@@ -0,0 +1,197 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import VectorQuantizationSettings
+
+from ..abstractions import (
+ ChunkSearchResult,
+ EmbeddingPurpose,
+ default_embedding_prefixes,
+)
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class EmbeddingConfig(ProviderConfig):
+ provider: str
+ base_model: str
+ base_dimension: int | float
+ rerank_model: Optional[str] = None
+ rerank_url: Optional[str] = None
+ batch_size: int = 1
+ prefixes: Optional[dict[str, str]] = None
+ add_title_as_prefix: bool = True
+ concurrent_request_limit: int = 256
+ max_retries: int = 3
+ initial_backoff: float = 1
+ max_backoff: float = 64.0
+ quantization_settings: VectorQuantizationSettings = (
+ VectorQuantizationSettings()
+ )
+
+ ## deprecated
+ rerank_dimension: Optional[int] = None
+ rerank_transformer_type: Optional[str] = None
+
+ def validate_config(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["litellm", "openai", "ollama"]
+
+
+class EmbeddingProvider(Provider):
+ class Step(Enum):
+ BASE = 1
+ RERANK = 2
+
+ def __init__(self, config: EmbeddingConfig):
+ if not isinstance(config, EmbeddingConfig):
+ raise ValueError(
+ "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
+ )
+ logger.info(f"Initializing EmbeddingProvider with config {config}.")
+
+ super().__init__(config)
+ self.config: EmbeddingConfig = config
+ self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
+ self.current_requests = 0
+
+ async def _execute_with_backoff_async(self, task: dict[str, Any]):
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ async with self.semaphore:
+ return await self._execute_task(task)
+ except AuthenticationError:
+ raise
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ await asyncio.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ def _execute_with_backoff_sync(self, task: dict[str, Any]):
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ return self._execute_task_sync(task)
+ except AuthenticationError:
+ raise
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ time.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ @abstractmethod
+ async def _execute_task(self, task: dict[str, Any]):
+ pass
+
+ @abstractmethod
+ def _execute_task_sync(self, task: dict[str, Any]):
+ pass
+
+ async def async_get_embedding(
+ self,
+ text: str,
+ stage: Step = Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ ):
+ task = {
+ "text": text,
+ "stage": stage,
+ "purpose": purpose,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: Step = Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ ):
+ task = {
+ "text": text,
+ "stage": stage,
+ "purpose": purpose,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: Step = Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ ):
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: Step = Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ ) -> list[list[float]]:
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ @abstractmethod
+ def rerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: Step = Step.RERANK,
+ limit: int = 10,
+ ):
+ pass
+
+ @abstractmethod
+ async def arerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: Step = Step.RERANK,
+ limit: int = 10,
+ ):
+ pass
+
+ def set_prefixes(self, config_prefixes: dict[str, str], base_model: str):
+ self.prefixes = {}
+
+ for t, p in config_prefixes.items():
+ purpose = EmbeddingPurpose(t.lower())
+ self.prefixes[purpose] = p
+
+ if base_model in default_embedding_prefixes:
+ for t, p in default_embedding_prefixes[base_model].items():
+ if t not in self.prefixes:
+ self.prefixes[t] = p