about summary refs log tree commit diff
path: root/R2R/r2r/providers/embeddings/sentence_transformer
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/providers/embeddings/sentence_transformer')
-rwxr-xr-xR2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py160
1 files changed, 160 insertions, 0 deletions
diff --git a/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
new file mode 100755
index 00000000..3316cb60
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
@@ -0,0 +1,160 @@
+import logging
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class SentenceTransformerEmbeddingProvider(EmbeddingProvider):
+    def __init__(
+        self,
+        config: EmbeddingConfig,
+    ):
+        super().__init__(config)
+        logger.info(
+            "Initializing `SentenceTransformerEmbeddingProvider` with separate models for search and rerank."
+        )
+        provider = config.provider
+        if not provider:
+            raise ValueError(
+                "Must set provider in order to initialize SentenceTransformerEmbeddingProvider."
+            )
+        if provider != "sentence-transformers":
+            raise ValueError(
+                "SentenceTransformerEmbeddingProvider must be initialized with provider `sentence-transformers`."
+            )
+        try:
+            from sentence_transformers import CrossEncoder, SentenceTransformer
+
+            self.SentenceTransformer = SentenceTransformer
+            # TODO - Modify this to be configurable, as `bge-reranker-large` is a `SentenceTransformer` model
+            self.CrossEncoder = CrossEncoder
+        except ImportError as e:
+            raise ValueError(
+                "Must download sentence-transformers library to run `SentenceTransformerEmbeddingProvider`."
+            ) from e
+
+        # Initialize separate models for search and rerank
+        self.do_search = False
+        self.do_rerank = False
+
+        self.search_encoder = self._init_model(
+            config, EmbeddingProvider.PipeStage.BASE
+        )
+        self.rerank_encoder = self._init_model(
+            config, EmbeddingProvider.PipeStage.RERANK
+        )
+
+    def _init_model(self, config: EmbeddingConfig, stage: str):
+        stage_name = stage.name.lower()
+        model = config.dict().get(f"{stage_name}_model", None)
+        dimension = config.dict().get(f"{stage_name}_dimension", None)
+
+        transformer_type = config.dict().get(
+            f"{stage_name}_transformer_type", "SentenceTransformer"
+        )
+
+        if stage == EmbeddingProvider.PipeStage.BASE:
+            self.do_search = True
+            # Check if a model is set for the stage
+            if not (model and dimension and transformer_type):
+                raise ValueError(
+                    f"Must set {stage.name.lower()}_model and {stage.name.lower()}_dimension for {stage} stage in order to initialize SentenceTransformerEmbeddingProvider."
+                )
+
+        if stage == EmbeddingProvider.PipeStage.RERANK:
+            # Check if a model is set for the stage
+            if not (model and dimension and transformer_type):
+                return None
+
+            self.do_rerank = True
+            if transformer_type == "SentenceTransformer":
+                raise ValueError(
+                    f"`SentenceTransformer` models are not yet supported for {stage} stage in SentenceTransformerEmbeddingProvider."
+                )
+
+        # Save the model_key and dimension into instance variables
+        setattr(self, f"{stage_name}_model", model)
+        setattr(self, f"{stage_name}_dimension", dimension)
+        setattr(self, f"{stage_name}_transformer_type", transformer_type)
+
+        # Initialize the model
+        encoder = (
+            self.SentenceTransformer(
+                model, truncate_dim=dimension, trust_remote_code=True
+            )
+            if transformer_type == "SentenceTransformer"
+            else self.CrossEncoder(model, trust_remote_code=True)
+        )
+        return encoder
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError("`get_embedding` only supports `SEARCH` stage.")
+        if not self.do_search:
+            raise ValueError(
+                "`get_embedding` can only be called for the search stage if a search model is set."
+            )
+        encoder = self.search_encoder
+        return encoder.encode([text]).tolist()[0]
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError("`get_embeddings` only supports `SEARCH` stage.")
+        if not self.do_search:
+            raise ValueError(
+                "`get_embeddings` can only be called for the search stage if a search model is set."
+            )
+        encoder = (
+            self.search_encoder
+            if stage == EmbeddingProvider.PipeStage.BASE
+            else self.rerank_encoder
+        )
+        return encoder.encode(texts).tolist()
+
+    def rerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+        limit: int = 10,
+    ) -> list[VectorSearchResult]:
+        if stage != EmbeddingProvider.PipeStage.RERANK:
+            raise ValueError("`rerank` only supports `RERANK` stage.")
+        if not self.do_rerank:
+            return results[:limit]
+
+        from copy import copy
+
+        texts = copy([doc.metadata["text"] for doc in results])
+        # Use the rank method from the rerank_encoder, which is a CrossEncoder model
+        reranked_scores = self.rerank_encoder.rank(
+            query, texts, return_documents=False, top_k=limit
+        )
+        # Map the reranked scores back to the original documents
+        reranked_results = []
+        for score in reranked_scores:
+            corpus_id = score["corpus_id"]
+            new_result = results[corpus_id]
+            new_result.score = float(score["score"])
+            reranked_results.append(new_result)
+
+        # Sort the documents by the new scores in descending order
+        reranked_results.sort(key=lambda doc: doc.score, reverse=True)
+        return reranked_results
+
+    def tokenize_string(
+        self,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[int]:
+        raise ValueError(
+            "SentenceTransformerEmbeddingProvider does not support tokenize_string."
+        )