Diffstat (limited to 'R2R/r2r/providers/embeddings')
-rwxr-xr-x  R2R/r2r/providers/embeddings/__init__.py                                        11
-rwxr-xr-x  R2R/r2r/providers/embeddings/ollama/ollama_base.py                             156
-rwxr-xr-x  R2R/r2r/providers/embeddings/openai/openai_base.py                             194
-rwxr-xr-x  R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py 154
4 files changed, 515 insertions, 0 deletions
diff --git a/R2R/r2r/providers/embeddings/__init__.py b/R2R/r2r/providers/embeddings/__init__.py
new file mode 100755
index 00000000..6b0c8b83
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/__init__.py
@@ -0,0 +1,11 @@
+from .ollama.ollama_base import OllamaEmbeddingProvider
+from .openai.openai_base import OpenAIEmbeddingProvider
+from .sentence_transformer.sentence_transformer_base import (
+    SentenceTransformerEmbeddingProvider,
+)
+
+__all__ = [
+    "OllamaEmbeddingProvider",
+    "OpenAIEmbeddingProvider",
+    "SentenceTransformerEmbeddingProvider",
+]
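A minimal sketch of selecting one of these providers by name. The helper and its name are hypothetical (not part of this diff); the provider strings match the values each constructor validates in the files below:

    from r2r.providers.embeddings import (
        OllamaEmbeddingProvider,
        OpenAIEmbeddingProvider,
        SentenceTransformerEmbeddingProvider,
    )

    # Provider strings match the checks in each constructor below.
    PROVIDERS = {
        "ollama": OllamaEmbeddingProvider,
        "openai": OpenAIEmbeddingProvider,
        "sentence-transformers": SentenceTransformerEmbeddingProvider,
    }

    def build_provider(config):
        # Hypothetical helper; each constructor re-validates config.provider.
        try:
            return PROVIDERS[config.provider](config)
        except KeyError as e:
            raise ValueError(f"Unknown embedding provider: {config.provider}") from e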
diff --git a/R2R/r2r/providers/embeddings/ollama/ollama_base.py b/R2R/r2r/providers/embeddings/ollama/ollama_base.py
new file mode 100755
index 00000000..31a8c717
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/ollama/ollama_base.py
@@ -0,0 +1,156 @@
+import asyncio
+import logging
+import os
+import random
+from typing import Any
+
+from ollama import AsyncClient, Client
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaEmbeddingProvider(EmbeddingProvider):
+    def __init__(self, config: EmbeddingConfig):
+        super().__init__(config)
+        provider = config.provider
+        if not provider:
+            raise ValueError(
+                "Must set provider in order to initialize `OllamaEmbeddingProvider`."
+            )
+        if provider != "ollama":
+            raise ValueError(
+                "OllamaEmbeddingProvider must be initialized with provider `ollama`."
+            )
+        if config.rerank_model:
+            raise ValueError(
+                "OllamaEmbeddingProvider does not support separate reranking."
+            )
+
+        self.base_model = config.base_model
+        self.base_dimension = config.base_dimension
+        self.base_url = os.getenv("OLLAMA_API_BASE")
+        logger.info(
+            f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
+        )
+        self.client = Client(host=self.base_url)
+        self.aclient = AsyncClient(host=self.base_url)
+
+        self.request_queue = asyncio.Queue()
+        self.max_retries = 2
+        self.initial_backoff = 1
+        self.max_backoff = 60
+        self.concurrency_limit = 10
+        self.semaphore = asyncio.Semaphore(self.concurrency_limit)
+
+    async def process_queue(self):
+        while True:
+            task = await self.request_queue.get()
+            try:
+                result = await self.execute_task_with_backoff(task)
+                task["future"].set_result(result)
+            except Exception as e:
+                task["future"].set_exception(e)
+            finally:
+                self.request_queue.task_done()
+
+    async def execute_task_with_backoff(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.initial_backoff
+        while retries < self.max_retries:
+            try:
+                async with self.semaphore:
+                    response = await asyncio.wait_for(
+                        self.aclient.embeddings(
+                            prompt=task["text"], model=self.base_model
+                        ),
+                        timeout=30,
+                    )
+                return response["embedding"]
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.max_retries:
+                    raise Exception(
+                        f"Max retries reached. Last error: {str(e)}"
+                    ) from e
+                await asyncio.sleep(backoff + random.uniform(0, 1))
+                backoff = min(backoff * 2, self.max_backoff)
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OllamaEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            response = self.client.embeddings(
+                prompt=text, model=self.base_model
+            )
+            return response["embedding"]
+        except Exception as e:
+            logger.error(f"Error getting embedding: {str(e)}")
+            raise
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        return [self.get_embedding(text, stage) for text in texts]
+
+    async def async_get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OllamaEmbeddingProvider only supports search stage."
+            )
+
+        queue_processor = asyncio.create_task(self.process_queue())
+        futures = []
+        for text in texts:
+            future = asyncio.Future()
+            await self.request_queue.put({"text": text, "future": future})
+            futures.append(future)
+
+        try:
+            results = await asyncio.gather(*futures, return_exceptions=True)
+            # Check if any result is an exception and surface the first one
+            exceptions = [r for r in results if isinstance(r, Exception)]
+            if exceptions:
+                raise Exception(
+                    "Embedding generation failed for one or more embeddings."
+                ) from exceptions[0]
+            return results
+        except Exception as e:
+            logger.error(f"Embedding generation failed: {str(e)}")
+            raise
+        finally:
+            await self.request_queue.join()
+            queue_processor.cancel()
+
+    def rerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+        limit: int = 10,
+    ) -> list[VectorSearchResult]:
+        return results[:limit]
+
+    def tokenize_string(
+        self, text: str, model: str, stage: EmbeddingProvider.PipeStage
+    ) -> list[int]:
+        raise NotImplementedError(
+            "Tokenization is not supported by OllamaEmbeddingProvider."
+        )
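A usage sketch for the Ollama provider, assuming `EmbeddingConfig` accepts the fields read above (`provider`, `base_model`, `base_dimension`) as keyword arguments and an Ollama server is reachable at `OLLAMA_API_BASE` or the default `http://127.0.0.1:11434`:

    import asyncio

    from r2r.base import EmbeddingConfig
    from r2r.providers.embeddings import OllamaEmbeddingProvider

    # Assumed config shape; the model name is illustrative and must be an
    # embedding model already pulled into the local Ollama instance.
    config = EmbeddingConfig(
        provider="ollama",
        base_model="mxbai-embed-large",
        base_dimension=1024,
    )
    provider = OllamaEmbeddingProvider(config)

    # Synchronous single-text path:
    vector = provider.get_embedding("hello world")

    # Async batch path: texts are queued, at most 10 requests run
    # concurrently, and failures retry with exponential backoff
    # (1s doubling toward a 60s cap, plus jitter), up to 2 attempts.
    vectors = asyncio.run(
        provider.async_get_embeddings(["first text", "second text"])
    )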
diff --git a/R2R/r2r/providers/embeddings/openai/openai_base.py b/R2R/r2r/providers/embeddings/openai/openai_base.py
new file mode 100755
index 00000000..7e7d32aa
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/openai/openai_base.py
@@ -0,0 +1,194 @@
+import logging
+import os
+
+from openai import AsyncOpenAI, AuthenticationError, OpenAI
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIEmbeddingProvider(EmbeddingProvider):
+    MODEL_TO_TOKENIZER = {
+        "text-embedding-ada-002": "cl100k_base",
+        "text-embedding-3-small": "cl100k_base",
+        "text-embedding-3-large": "cl100k_base",
+    }
+    MODEL_TO_DIMENSIONS = {
+        "text-embedding-ada-002": [1536],
+        "text-embedding-3-small": [512, 1536],
+        "text-embedding-3-large": [256, 1024, 3072],
+    }
+
+    def __init__(self, config: EmbeddingConfig):
+        super().__init__(config)
+        provider = config.provider
+        if not provider:
+            raise ValueError(
+                "Must set provider in order to initialize OpenAIEmbeddingProvider."
+            )
+
+        if provider != "openai":
+            raise ValueError(
+                "OpenAIEmbeddingProvider must be initialized with provider `openai`."
+            )
+        if not os.getenv("OPENAI_API_KEY"):
+            raise ValueError(
+                "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+            )
+        self.client = OpenAI()
+        self.async_client = AsyncOpenAI()
+
+        if config.rerank_model:
+            raise ValueError(
+                "OpenAIEmbeddingProvider does not support separate reranking."
+            )
+        self.base_model = config.base_model
+        self.base_dimension = config.base_dimension
+
+        if not self.base_model or not self.base_dimension:
+            raise ValueError(
+                "Must set base_model and base_dimension in order to initialize OpenAIEmbeddingProvider."
+            )
+
+        if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+            raise ValueError(
+                f"OpenAI embedding model {self.base_model} not supported."
+            )
+        if (
+            self.base_dimension
+            not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model]
+        ):
+            raise ValueError(
+                f"Dimensions {self.base_dimension} for {self.base_model} are not supported"
+            )
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            return (
+                self.client.embeddings.create(
+                    input=[text],
+                    model=self.base_model,
+                    dimensions=self.base_dimension
+                    or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                        self.base_model
+                    ][-1],
+                )
+                .data[0]
+                .embedding
+            )
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+
+    async def async_get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            response = await self.async_client.embeddings.create(
+                input=[text],
+                model=self.base_model,
+                dimensions=self.base_dimension
+                or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                    self.base_model
+                ][-1],
+            )
+            return response.data[0].embedding
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            return [
+                ele.embedding
+                for ele in self.client.embeddings.create(
+                    input=texts,
+                    model=self.base_model,
+                    dimensions=self.base_dimension
+                    or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                        self.base_model
+                    ][-1],
+                ).data
+            ]
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+
+    async def async_get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            response = await self.async_client.embeddings.create(
+                input=texts,
+                model=self.base_model,
+                dimensions=self.base_dimension
+                or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                    self.base_model
+                ][-1],
+            )
+            return [ele.embedding for ele in response.data]
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+
+    def rerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+        limit: int = 10,
+    ) -> list[VectorSearchResult]:
+        return results[:limit]
+
+    def tokenize_string(self, text: str, model: str) -> list[int]:
+        try:
+            import tiktoken
+        except ImportError as e:
+            raise ValueError(
+                "Must install the `tiktoken` library to run `tokenize_string`."
+            ) from e
+        # tiktoken encoding: cl100k_base covers gpt-4, gpt-3.5-turbo,
+        # text-embedding-ada-002, text-embedding-3-small, and text-embedding-3-large
+        if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+            raise ValueError(f"OpenAI embedding model {model} not supported.")
+        encoding = tiktoken.get_encoding(
+            OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
+        )
+        return encoding.encode(text)
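A corresponding sketch for the OpenAI provider, under the same assumption about constructing `EmbeddingConfig` with keyword arguments; `base_dimension` must be one of the sizes listed in `MODEL_TO_DIMENSIONS` for the chosen model, and `OPENAI_API_KEY` must be set in the environment:

    from r2r.base import EmbeddingConfig
    from r2r.providers.embeddings import OpenAIEmbeddingProvider

    # Requires a valid OPENAI_API_KEY in the environment.
    config = EmbeddingConfig(
        provider="openai",
        base_model="text-embedding-3-small",
        base_dimension=512,  # 512 or 1536 for this model, per MODEL_TO_DIMENSIONS
    )
    provider = OpenAIEmbeddingProvider(config)

    vectors = provider.get_embeddings(["alpha", "beta"])

    # Token counting routes through tiktoken's cl100k_base encoding:
    tokens = provider.tokenize_string("alpha", "text-embedding-3-small")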
diff --git a/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
new file mode 100755
index 00000000..3316cb60
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
@@ -0,0 +1,154 @@
+import logging
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class SentenceTransformerEmbeddingProvider(EmbeddingProvider):
+    def __init__(
+        self,
+        config: EmbeddingConfig,
+    ):
+        super().__init__(config)
+        logger.info(
+            "Initializing `SentenceTransformerEmbeddingProvider` with separate models for search and rerank."
+        )
+        provider = config.provider
+        if not provider:
+            raise ValueError(
+                "Must set provider in order to initialize SentenceTransformerEmbeddingProvider."
+            )
+        if provider != "sentence-transformers":
+            raise ValueError(
+                "SentenceTransformerEmbeddingProvider must be initialized with provider `sentence-transformers`."
+            )
+        try:
+            from sentence_transformers import CrossEncoder, SentenceTransformer
+
+            self.SentenceTransformer = SentenceTransformer
+            # TODO - Modify this to be configurable, as `bge-reranker-large` is a `SentenceTransformer` model
+            self.CrossEncoder = CrossEncoder
+        except ImportError as e:
+            raise ValueError(
+                "Must download sentence-transformers library to run `SentenceTransformerEmbeddingProvider`."
+            ) from e
+
+        # Initialize separate models for search and rerank
+        self.do_search = False
+        self.do_rerank = False
+
+        self.search_encoder = self._init_model(
+            config, EmbeddingProvider.PipeStage.BASE
+        )
+        self.rerank_encoder = self._init_model(
+            config, EmbeddingProvider.PipeStage.RERANK
+        )
+
+    def _init_model(self, config: EmbeddingConfig, stage: EmbeddingProvider.PipeStage):
+        stage_name = stage.name.lower()
+        model = config.dict().get(f"{stage_name}_model", None)
+        dimension = config.dict().get(f"{stage_name}_dimension", None)
+
+        transformer_type = config.dict().get(
+            f"{stage_name}_transformer_type", "SentenceTransformer"
+        )
+
+        if stage == EmbeddingProvider.PipeStage.BASE:
+            self.do_search = True
+            # Check if a model is set for the stage
+            if not (model and dimension and transformer_type):
+                raise ValueError(
+                    f"Must set {stage.name.lower()}_model and {stage.name.lower()}_dimension for {stage} stage in order to initialize SentenceTransformerEmbeddingProvider."
+                )
+
+        if stage == EmbeddingProvider.PipeStage.RERANK:
+            # Check if a model is set for the stage
+            if not (model and dimension and transformer_type):
+                return None
+
+            self.do_rerank = True
+            if transformer_type == "SentenceTransformer":
+                raise ValueError(
+                    f"`SentenceTransformer` models are not yet supported for {stage} stage in SentenceTransformerEmbeddingProvider."
+                )
+
+        # Save the model_key and dimension into instance variables
+        setattr(self, f"{stage_name}_model", model)
+        setattr(self, f"{stage_name}_dimension", dimension)
+        setattr(self, f"{stage_name}_transformer_type", transformer_type)
+
+        # Initialize the model
+        encoder = (
+            self.SentenceTransformer(
+                model, truncate_dim=dimension, trust_remote_code=True
+            )
+            if transformer_type == "SentenceTransformer"
+            else self.CrossEncoder(model, trust_remote_code=True)
+        )
+        return encoder
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError("`get_embedding` only supports `SEARCH` stage.")
+        if not self.do_search:
+            raise ValueError(
+                "`get_embedding` can only be called for the search stage if a search model is set."
+            )
+        encoder = self.search_encoder
+        return encoder.encode([text]).tolist()[0]
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError("`get_embeddings` only supports `SEARCH` stage.")
+        if not self.do_search:
+            raise ValueError(
+                "`get_embeddings` can only be called for the search stage if a search model is set."
+            )
+        # `stage` is guaranteed to be BASE here, so use the search encoder.
+        encoder = self.search_encoder
+        return encoder.encode(texts).tolist()
+
+    def rerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+        limit: int = 10,
+    ) -> list[VectorSearchResult]:
+        if stage != EmbeddingProvider.PipeStage.RERANK:
+            raise ValueError("`rerank` only supports `RERANK` stage.")
+        if not self.do_rerank:
+            return results[:limit]
+
+        texts = [doc.metadata["text"] for doc in results]
+        # Use the rank method from the rerank_encoder, which is a CrossEncoder model
+        reranked_scores = self.rerank_encoder.rank(
+            query, texts, return_documents=False, top_k=limit
+        )
+        # Map the reranked scores back to the original documents
+        reranked_results = []
+        for score in reranked_scores:
+            corpus_id = score["corpus_id"]
+            new_result = results[corpus_id]
+            new_result.score = float(score["score"])
+            reranked_results.append(new_result)
+
+        # Sort the documents by the new scores in descending order
+        reranked_results.sort(key=lambda doc: doc.score, reverse=True)
+        return reranked_results
+
+    def tokenize_string(
+        self, text: str, model: str, stage: EmbeddingProvider.PipeStage
+    ) -> list[int]:
+        raise NotImplementedError(
+            "SentenceTransformerEmbeddingProvider does not support tokenize_string."
+        )
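And a sketch for the local provider, assuming `EmbeddingConfig` carries the stage-prefixed keys that `_init_model` reads via `config.dict()` (the `base_*` fields are required; the `rerank_*` fields are optional and enable CrossEncoder reranking). Model names and the `rerank_dimension` placeholder are illustrative assumptions:

    from r2r.base import EmbeddingConfig
    from r2r.providers.embeddings import SentenceTransformerEmbeddingProvider

    # rerank_transformer_type must not be "SentenceTransformer", which the
    # rerank stage rejects above; a CrossEncoder model is expected instead.
    config = EmbeddingConfig(
        provider="sentence-transformers",
        base_model="all-MiniLM-L6-v2",
        base_dimension=384,
        base_transformer_type="SentenceTransformer",
        rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
        rerank_dimension=1,  # only presence is checked for the rerank stage
        rerank_transformer_type="CrossEncoder",
    )
    provider = SentenceTransformerEmbeddingProvider(config)

    query_vector = provider.get_embedding("what is r2r?")
    # rerank() scores each result's metadata["text"] against the query with
    # the CrossEncoder and returns the top `limit` results by score.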