aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/providers/embeddings/sentence_transformer
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/providers/embeddings/sentence_transformer')
-rwxr-xr-xR2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py160
1 files changed, 160 insertions, 0 deletions
diff --git a/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
new file mode 100755
index 00000000..3316cb60
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
@@ -0,0 +1,160 @@
+import logging
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
class SentenceTransformerEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by the `sentence-transformers` library.

    Holds two independent models:
      * a `SentenceTransformer` encoder for the search (BASE) stage, and
      * a `CrossEncoder` for the RERANK stage (optional — reranking is
        skipped gracefully when no rerank model is configured).
    """

    def __init__(
        self,
        config: EmbeddingConfig,
    ):
        """Validate the config and load the search / rerank models.

        Raises:
            ValueError: if the provider is missing or not
                `sentence-transformers`, if the library is not installed,
                or if the required search-stage settings are absent.
        """
        super().__init__(config)
        logger.info(
            "Initializing `SentenceTransformerEmbeddingProvider` with separate models for search and rerank."
        )
        provider = config.provider
        if not provider:
            raise ValueError(
                "Must set provider in order to initialize SentenceTransformerEmbeddingProvider."
            )
        if provider != "sentence-transformers":
            raise ValueError(
                "SentenceTransformerEmbeddingProvider must be initialized with provider `sentence-transformers`."
            )
        try:
            from sentence_transformers import CrossEncoder, SentenceTransformer

            # Stash the classes so _init_model can construct either kind.
            self.SentenceTransformer = SentenceTransformer
            # TODO - Modify this to be configurable, as `bge-reranker-large` is a `SentenceTransformer` model
            self.CrossEncoder = CrossEncoder
        except ImportError as e:
            raise ValueError(
                "Must download sentence-transformers library to run `SentenceTransformerEmbeddingProvider`."
            ) from e

        # Capability flags, flipped on by _init_model as each stage is
        # successfully configured.
        self.do_search = False
        self.do_rerank = False

        self.search_encoder = self._init_model(
            config, EmbeddingProvider.PipeStage.BASE
        )
        self.rerank_encoder = self._init_model(
            config, EmbeddingProvider.PipeStage.RERANK
        )

    def _init_model(
        self, config: EmbeddingConfig, stage: "EmbeddingProvider.PipeStage"
    ):
        """Build the encoder for *stage* from `<stage>_model` / `<stage>_dimension`
        / `<stage>_transformer_type` config keys.

        Returns the constructed model, or None when the (optional) rerank
        stage is not configured.

        Raises:
            ValueError: if the mandatory BASE-stage settings are missing, or
                if a `SentenceTransformer` is requested for the RERANK stage
                (only `CrossEncoder` is supported there).
        """
        stage_name = stage.name.lower()
        config_dict = config.dict()
        model = config_dict.get(f"{stage_name}_model", None)
        dimension = config_dict.get(f"{stage_name}_dimension", None)
        transformer_type = config_dict.get(
            f"{stage_name}_transformer_type", "SentenceTransformer"
        )

        if stage == EmbeddingProvider.PipeStage.BASE:
            self.do_search = True
            # The search stage is mandatory: all three settings must be present.
            if not (model and dimension and transformer_type):
                raise ValueError(
                    f"Must set {stage.name.lower()}_model and {stage.name.lower()}_dimension for {stage} stage in order to initialize SentenceTransformerEmbeddingProvider."
                )

        if stage == EmbeddingProvider.PipeStage.RERANK:
            # Reranking is optional — silently disabled when unconfigured.
            if not (model and dimension and transformer_type):
                return None

            self.do_rerank = True
            if transformer_type == "SentenceTransformer":
                raise ValueError(
                    f"`SentenceTransformer` models are not yet supported for {stage} stage in SentenceTransformerEmbeddingProvider."
                )

        # Expose the resolved settings as instance attributes, e.g.
        # `self.base_model`, `self.rerank_dimension`, ...
        setattr(self, f"{stage_name}_model", model)
        setattr(self, f"{stage_name}_dimension", dimension)
        setattr(self, f"{stage_name}_transformer_type", transformer_type)

        # Instantiate the appropriate model class for this stage.
        if transformer_type == "SentenceTransformer":
            return self.SentenceTransformer(
                model, truncate_dim=dimension, trust_remote_code=True
            )
        return self.CrossEncoder(model, trust_remote_code=True)

    def get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
    ) -> list[float]:
        """Embed a single string with the search encoder.

        Raises:
            ValueError: if called for a non-BASE stage or when no search
                model is configured.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError("`get_embedding` only supports `SEARCH` stage.")
        if not self.do_search:
            raise ValueError(
                "`get_embedding` can only be called for the search stage if a search model is set."
            )
        return self.search_encoder.encode([text]).tolist()[0]

    def get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
    ) -> list[list[float]]:
        """Embed a batch of strings with the search encoder.

        Raises:
            ValueError: if called for a non-BASE stage or when no search
                model is configured.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError("`get_embeddings` only supports `SEARCH` stage.")
        if not self.do_search:
            raise ValueError(
                "`get_embeddings` can only be called for the search stage if a search model is set."
            )
        # Stage is guaranteed BASE here, so the search encoder is the only
        # valid choice (the previous rerank-encoder fallback was dead code).
        return self.search_encoder.encode(texts).tolist()

    def rerank(
        self,
        query: str,
        results: list[VectorSearchResult],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
        limit: int = 10,
    ) -> list[VectorSearchResult]:
        """Re-score *results* against *query* with the cross-encoder.

        Falls back to simply truncating `results` to `limit` when no rerank
        model is configured. Note: scores are written onto the original
        result objects in place.

        Raises:
            ValueError: if called for a non-RERANK stage.
        """
        if stage != EmbeddingProvider.PipeStage.RERANK:
            raise ValueError("`rerank` only supports `RERANK` stage.")
        if not self.do_rerank:
            return results[:limit]

        # The comprehension already yields a fresh list, so no extra copy
        # is needed (the previous `copy(...)` was a no-op).
        texts = [doc.metadata["text"] for doc in results]
        # Use the rank method from the rerank_encoder, which is a CrossEncoder model
        reranked_scores = self.rerank_encoder.rank(
            query, texts, return_documents=False, top_k=limit
        )
        # Map the reranked scores back to the original documents via corpus_id.
        reranked_results = []
        for score in reranked_scores:
            corpus_id = score["corpus_id"]
            new_result = results[corpus_id]
            new_result.score = float(score["score"])
            reranked_results.append(new_result)

        # Sort the documents by the new scores in descending order
        reranked_results.sort(key=lambda doc: doc.score, reverse=True)
        return reranked_results

    def tokenize_string(
        self,
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
    ) -> list[int]:
        """Not supported by this provider; always raises ValueError."""
        raise ValueError(
            "SentenceTransformerEmbeddingProvider does not support tokenize_string."
        )