"""Embedding provider backed by the `sentence-transformers` library."""
import logging

from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult

# Module-level logger, named after this module so log output can be filtered per module.
logger = logging.getLogger(__name__)


class SentenceTransformerEmbeddingProvider(EmbeddingProvider):
    """Embedding provider built on the `sentence-transformers` library.

    Two independent models are managed:

    * a *search* encoder (``PipeStage.BASE``), which is mandatory and is used
      by ``get_embedding`` / ``get_embeddings``;
    * a *rerank* encoder (``PipeStage.RERANK``), which is optional and is used
      by ``rerank``. When it is not configured, ``rerank`` simply truncates
      the incoming results.

    Model names, dimensions and transformer types are read from the config
    under the keys ``<stage>_model``, ``<stage>_dimension`` and
    ``<stage>_transformer_type``, where ``<stage>`` is the lower-cased name of
    the pipe stage.
    """

    def __init__(
        self,
        config: EmbeddingConfig,
    ):
        """Validate the config, import `sentence-transformers`, load the models.

        Args:
            config: Embedding configuration; ``config.provider`` must be
                ``"sentence-transformers"``.

        Raises:
            ValueError: If the provider is unset or different from
                ``"sentence-transformers"``, if the `sentence-transformers`
                package cannot be imported, or if the stage configuration is
                rejected by ``_init_model``.
        """
        super().__init__(config)
        logger.info(
            "Initializing `SentenceTransformerEmbeddingProvider` with separate models for search and rerank."
        )
        provider = config.provider
        if not provider:
            raise ValueError(
                "Must set provider in order to initialize SentenceTransformerEmbeddingProvider."
            )
        if provider != "sentence-transformers":
            raise ValueError(
                "SentenceTransformerEmbeddingProvider must be initialized with provider `sentence-transformers`."
            )

        # Imported lazily so that the dependency is only required when this
        # provider is actually instantiated.
        try:
            from sentence_transformers import CrossEncoder, SentenceTransformer
        except ImportError as e:
            raise ValueError(
                "Must download sentence-transformers library to run `SentenceTransformerEmbeddingProvider`."
            ) from e

        self.SentenceTransformer = SentenceTransformer
        # TODO - Modify this to be configurable, as `bge-reranker-large` is a `SentenceTransformer` model
        self.CrossEncoder = CrossEncoder

        # Flags are flipped by `_init_model` once the matching stage is set up.
        self.do_search = False
        self.do_rerank = False

        # Initialize separate models for search and rerank
        self.search_encoder = self._init_model(
            config, EmbeddingProvider.PipeStage.BASE
        )
        self.rerank_encoder = self._init_model(
            config, EmbeddingProvider.PipeStage.RERANK
        )

    def _init_model(
        self, config: EmbeddingConfig, stage: EmbeddingProvider.PipeStage
    ):
        """Build the encoder for ``stage`` from ``config``.

        Side effects: sets ``self.do_search`` / ``self.do_rerank`` and stores
        ``self.<stage>_model``, ``self.<stage>_dimension`` and
        ``self.<stage>_transformer_type``.

        Args:
            config: Embedding configuration to read the stage settings from.
            stage: Pipe stage to initialize; its lower-cased ``name`` is the
                prefix of the config keys that are read.

        Returns:
            The instantiated encoder, or ``None`` when the rerank stage is not
            fully configured.

        Raises:
            ValueError: If the base stage is missing its model or dimension,
                or if the rerank stage is configured with the
                ``"SentenceTransformer"`` transformer type.
        """
        stage_name = stage.name.lower()
        # Serialize the config once instead of once per looked-up key.
        settings = config.dict()
        model = settings.get(f"{stage_name}_model", None)
        dimension = settings.get(f"{stage_name}_dimension", None)
        transformer_type = settings.get(
            f"{stage_name}_transformer_type", "SentenceTransformer"
        )
        is_configured = bool(model and dimension and transformer_type)

        if stage == EmbeddingProvider.PipeStage.BASE:
            self.do_search = True
            # The search model is mandatory.
            if not is_configured:
                raise ValueError(
                    f"Must set {stage_name}_model and {stage_name}_dimension for {stage} stage in order to initialize SentenceTransformerEmbeddingProvider."
                )

        if stage == EmbeddingProvider.PipeStage.RERANK:
            # The rerank model is optional: no config means no reranking.
            if not is_configured:
                return None

            self.do_rerank = True
            # NOTE: "SentenceTransformer" is also the default transformer
            # type, so a rerank stage must explicitly request another type.
            if transformer_type == "SentenceTransformer":
                raise ValueError(
                    f"`SentenceTransformer` models are not yet supported for {stage} stage in SentenceTransformerEmbeddingProvider."
                )

        # Save the model_key and dimension into instance variables
        setattr(self, f"{stage_name}_model", model)
        setattr(self, f"{stage_name}_dimension", dimension)
        setattr(self, f"{stage_name}_transformer_type", transformer_type)

        # NOTE(review): `trust_remote_code=True` allows code shipped with the
        # model repository to be executed; only use trusted model names.
        if transformer_type == "SentenceTransformer":
            return self.SentenceTransformer(
                model, truncate_dim=dimension, trust_remote_code=True
            )
        # Any other transformer type is treated as a cross-encoder; the
        # configured dimension is not passed to it.
        return self.CrossEncoder(model, trust_remote_code=True)

    def get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
    ) -> list[float]:
        """Embed a single text with the search encoder.

        Args:
            text: Text to embed.
            stage: Must be ``PipeStage.BASE`` (the error message calls this
                the `SEARCH` stage).

        Returns:
            The embedding vector as a list of floats.

        Raises:
            ValueError: If ``stage`` is not ``PipeStage.BASE`` or no search
                model has been set up.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError("`get_embedding` only supports `SEARCH` stage.")
        if not self.do_search:
            raise ValueError(
                "`get_embedding` can only be called for the search stage if a search model is set."
            )
        # Encode as a batch of one and unwrap the single row.
        return self.search_encoder.encode([text]).tolist()[0]

    def get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
    ) -> list[list[float]]:
        """Embed a batch of texts with the search encoder.

        Args:
            texts: Texts to embed.
            stage: Must be ``PipeStage.BASE`` (the error message calls this
                the `SEARCH` stage).

        Returns:
            One embedding vector per input text, in input order; an empty
            list when ``texts`` is empty.

        Raises:
            ValueError: If ``stage`` is not ``PipeStage.BASE`` or no search
                model has been set up.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError("`get_embeddings` only supports `SEARCH` stage.")
        if not self.do_search:
            raise ValueError(
                "`get_embeddings` can only be called for the search stage if a search model is set."
            )
        if not texts:
            # Nothing to encode; skip the model call entirely.
            return []
        # Only the BASE stage can reach this point, so the search encoder is
        # always the right one.
        return self.search_encoder.encode(texts).tolist()

    def rerank(
        self,
        query: str,
        results: list[VectorSearchResult],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
        limit: int = 10,
    ) -> list[VectorSearchResult]:
        """Re-score ``results`` against ``query`` with the rerank encoder.

        Args:
            query: Query the results are scored against.
            results: Candidate results; each must expose ``metadata["text"]``.
            stage: Must be ``PipeStage.RERANK``.
            limit: Maximum number of results to return.

        Returns:
            When no rerank model is configured, the first ``limit`` entries of
            ``results`` unchanged. Otherwise up to ``limit`` shallow copies of
            the input results carrying the new score, sorted by score in
            descending order. The objects in ``results`` are not modified.

        Raises:
            ValueError: If ``stage`` is not ``PipeStage.RERANK``.
        """
        if stage != EmbeddingProvider.PipeStage.RERANK:
            raise ValueError("`rerank` only supports `RERANK` stage.")
        if not self.do_rerank:
            return results[:limit]
        if not results:
            # Nothing to score; avoid calling the model with an empty corpus.
            return []

        from copy import copy

        texts = [doc.metadata["text"] for doc in results]
        # Use the rank method from the rerank_encoder, which is a CrossEncoder model
        reranked_scores = self.rerank_encoder.rank(
            query, texts, return_documents=False, top_k=limit
        )

        # Map the reranked scores back to the original documents. Each hit is
        # shallow-copied so the caller's result objects keep their original
        # scores.
        reranked_results = []
        for hit in reranked_scores:
            rescored = copy(results[hit["corpus_id"]])
            rescored.score = float(hit["score"])
            reranked_results.append(rescored)

        # Sort the documents by the new scores in descending order
        reranked_results.sort(key=lambda doc: doc.score, reverse=True)
        return reranked_results

    def tokenize_string(
        self,
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
    ) -> list[int]:
        """Unsupported by this provider.

        Raises:
            ValueError: Always.
        """
        raise ValueError(
            "SentenceTransformerEmbeddingProvider does not support tokenize_string."
        )