diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/base/providers/embedding_provider.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to 'R2R/r2r/base/providers/embedding_provider.py')
-rwxr-xr-x | R2R/r2r/base/providers/embedding_provider.py | 83 |
1 file changed, 83 insertions, 0 deletions
import logging
from abc import abstractmethod
from enum import Enum
from typing import Optional

from ..abstractions.search import VectorSearchResult
from .base_provider import Provider, ProviderConfig

logger = logging.getLogger(__name__)


class EmbeddingConfig(ProviderConfig):
    """A base embedding configuration class.

    Declares the model/dimension settings shared by all embedding
    providers.  All fields default to ``None`` (or 1 for ``batch_size``)
    so a config may be partially specified and validated later.
    """

    provider: Optional[str] = None
    base_model: Optional[str] = None
    base_dimension: Optional[int] = None
    rerank_model: Optional[str] = None
    rerank_dimension: Optional[int] = None
    rerank_transformer_type: Optional[str] = None
    batch_size: int = 1

    def validate(self) -> None:
        """Raise ``ValueError`` if ``provider`` is not a supported value."""
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[Optional[str]]:
        # ``None`` is intentionally included so an unset provider passes
        # validate(); the annotation reflects that (was ``list[str]``).
        return [None, "openai", "ollama", "sentence-transformers"]


class EmbeddingProvider(Provider):
    """An abstract class to provide a common interface for embedding providers.

    Concrete subclasses implement synchronous embedding, reranking, and
    tokenization; the async variants default to delegating to the sync
    implementations.
    """

    class PipeStage(Enum):
        # Which stage of the pipeline an embedding call serves:
        # BASE produces first-pass vectors, RERANK re-scores results.
        BASE = 1
        RERANK = 2

    def __init__(self, config: EmbeddingConfig):
        """Store *config* after checking it is an ``EmbeddingConfig``.

        Raises:
            ValueError: if *config* is not an ``EmbeddingConfig`` instance.
        """
        if not isinstance(config, EmbeddingConfig):
            raise ValueError(
                "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
            )
        logger.info(f"Initializing EmbeddingProvider with config {config}.")

        super().__init__(config)

    @abstractmethod
    def get_embedding(self, text: str, stage: PipeStage = PipeStage.BASE):
        """Return the embedding vector for a single *text*."""
        pass

    async def async_get_embedding(
        self, text: str, stage: PipeStage = PipeStage.BASE
    ):
        # Default async path simply calls the sync implementation;
        # subclasses with native async clients should override.
        return self.get_embedding(text, stage)

    @abstractmethod
    def get_embeddings(
        self, texts: list[str], stage: PipeStage = PipeStage.BASE
    ):
        """Return embedding vectors for a batch of *texts*."""
        pass

    async def async_get_embeddings(
        self, texts: list[str], stage: PipeStage = PipeStage.BASE
    ):
        # Same sync-delegation default as ``async_get_embedding``.
        return self.get_embeddings(texts, stage)

    @abstractmethod
    def rerank(
        self,
        query: str,
        results: list[VectorSearchResult],
        stage: PipeStage = PipeStage.RERANK,
        limit: int = 10,
    ):
        """Re-score *results* against *query*, returning at most *limit*."""
        pass

    @abstractmethod
    def tokenize_string(
        self, text: str, model: str, stage: PipeStage
    ) -> list[int]:
        """Tokenizes the input string."""
        pass