path: root/R2R/r2r/providers/embeddings
author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/providers/embeddings
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/providers/embeddings')
-rwxr-xr-x  R2R/r2r/providers/embeddings/__init__.py                                        11
-rwxr-xr-x  R2R/r2r/providers/embeddings/ollama/ollama_base.py                             156
-rwxr-xr-x  R2R/r2r/providers/embeddings/openai/openai_base.py                             195
-rwxr-xr-x  R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py 160
4 files changed, 522 insertions, 0 deletions
diff --git a/R2R/r2r/providers/embeddings/__init__.py b/R2R/r2r/providers/embeddings/__init__.py
new file mode 100755
index 00000000..6b0c8b83
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/__init__.py
@@ -0,0 +1,11 @@
+from .ollama.ollama_base import OllamaEmbeddingProvider
+from .openai.openai_base import OpenAIEmbeddingProvider
+from .sentence_transformer.sentence_transformer_base import (
+ SentenceTransformerEmbeddingProvider,
+)
+
+__all__ = [
+ "OllamaEmbeddingProvider",
+ "OpenAIEmbeddingProvider",
+ "SentenceTransformerEmbeddingProvider",
+]
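The package root simply re-exports the three providers. As a minimal sketch of how a caller might dispatch on the configured provider name (the EmbeddingConfig field names here are assumptions inferred from the constructors in this diff, not a confirmed r2r API):

# Hypothetical dispatch helper -- not part of this diff.
from r2r.base import EmbeddingConfig
from r2r.providers.embeddings import (
    OllamaEmbeddingProvider,
    OpenAIEmbeddingProvider,
    SentenceTransformerEmbeddingProvider,
)

_PROVIDERS = {
    "ollama": OllamaEmbeddingProvider,
    "openai": OpenAIEmbeddingProvider,
    "sentence-transformers": SentenceTransformerEmbeddingProvider,
}

def make_provider(config: EmbeddingConfig):
    # Each constructor re-validates config.provider, so this table
    # only needs to pick the class.
    try:
        return _PROVIDERS[config.provider](config)
    except KeyError:
        raise ValueError(f"Unknown embedding provider: {config.provider}")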
diff --git a/R2R/r2r/providers/embeddings/ollama/ollama_base.py b/R2R/r2r/providers/embeddings/ollama/ollama_base.py
new file mode 100755
index 00000000..31a8c717
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/ollama/ollama_base.py
@@ -0,0 +1,156 @@
+import asyncio
+import logging
+import os
+import random
+from typing import Any
+
+from ollama import AsyncClient, Client
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaEmbeddingProvider(EmbeddingProvider):
+ def __init__(self, config: EmbeddingConfig):
+ super().__init__(config)
+ provider = config.provider
+ if not provider:
+ raise ValueError(
+ "Must set provider in order to initialize `OllamaEmbeddingProvider`."
+ )
+ if provider != "ollama":
+ raise ValueError(
+ "OllamaEmbeddingProvider must be initialized with provider `ollama`."
+ )
+ if config.rerank_model:
+ raise ValueError(
+ "OllamaEmbeddingProvider does not support separate reranking."
+ )
+
+ self.base_model = config.base_model
+ self.base_dimension = config.base_dimension
+ self.base_url = os.getenv("OLLAMA_API_BASE")
+ logger.info(
+ f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
+ )
+ self.client = Client(host=self.base_url)
+ self.aclient = AsyncClient(host=self.base_url)
+
+ self.request_queue = asyncio.Queue()
+ self.max_retries = 2
+ self.initial_backoff = 1
+ self.max_backoff = 60
+ self.concurrency_limit = 10
+ self.semaphore = asyncio.Semaphore(self.concurrency_limit)
+
+ async def process_queue(self):
+ while True:
+ task = await self.request_queue.get()
+ try:
+ result = await self.execute_task_with_backoff(task)
+ task["future"].set_result(result)
+ except Exception as e:
+ task["future"].set_exception(e)
+ finally:
+ self.request_queue.task_done()
+
+ async def execute_task_with_backoff(self, task: dict[str, Any]):
+ retries = 0
+ backoff = self.initial_backoff
+ while retries < self.max_retries:
+ try:
+ async with self.semaphore:
+ response = await asyncio.wait_for(
+ self.aclient.embeddings(
+ prompt=task["text"], model=self.base_model
+ ),
+ timeout=30,
+ )
+ return response["embedding"]
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+                if retries == self.max_retries:
+                    raise Exception(
+                        f"Max retries reached. Last error: {str(e)}"
+                    ) from e
+ await asyncio.sleep(backoff + random.uniform(0, 1))
+ backoff = min(backoff * 2, self.max_backoff)
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ try:
+ response = self.client.embeddings(
+ prompt=text, model=self.base_model
+ )
+ return response["embedding"]
+ except Exception as e:
+ logger.error(f"Error getting embedding: {str(e)}")
+ raise
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[list[float]]:
+ return [self.get_embedding(text, stage) for text in texts]
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ queue_processor = asyncio.create_task(self.process_queue())
+ futures = []
+ for text in texts:
+ future = asyncio.Future()
+ await self.request_queue.put({"text": text, "future": future})
+ futures.append(future)
+
+ try:
+ results = await asyncio.gather(*futures, return_exceptions=True)
+            # Surface the first failure instead of returning mixed results
+            for result in results:
+                if isinstance(result, Exception):
+                    raise Exception(
+                        "Embedding generation failed for one or more texts."
+                    ) from result
+ return results
+ except Exception as e:
+ logger.error(f"Embedding generation failed: {str(e)}")
+ raise
+ finally:
+ await self.request_queue.join()
+ queue_processor.cancel()
+
+ def rerank(
+ self,
+ query: str,
+ results: list[VectorSearchResult],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+ limit: int = 10,
+ ) -> list[VectorSearchResult]:
+ return results[:limit]
+
+ def tokenize_string(
+ self, text: str, model: str, stage: EmbeddingProvider.PipeStage
+ ) -> list[int]:
+ raise NotImplementedError(
+ "Tokenization is not supported by OllamaEmbeddingProvider."
+ )
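A minimal synchronous usage sketch for the provider above, assuming a local Ollama daemon with the example model already pulled; EmbeddingConfig lives in r2r.base, and its constructor arguments here are inferred from the fields the provider reads:

# Hypothetical usage sketch -- nomic-embed-text is an example model
# that must already be available in the local Ollama instance.
import os
from r2r.base import EmbeddingConfig
from r2r.providers.embeddings import OllamaEmbeddingProvider

os.environ.setdefault("OLLAMA_API_BASE", "http://127.0.0.1:11434")
config = EmbeddingConfig(
    provider="ollama",
    base_model="nomic-embed-text",
    base_dimension=768,
)
provider = OllamaEmbeddingProvider(config)
vector = provider.get_embedding("What is R2R?")
print(len(vector))  # should equal config.base_dimension

The async path (async_get_embeddings) queues every text, fans requests out under a semaphore of 10 concurrent calls, and retries each request with exponential backoff plus jitter before giving up.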
diff --git a/R2R/r2r/providers/embeddings/openai/openai_base.py b/R2R/r2r/providers/embeddings/openai/openai_base.py
new file mode 100755
index 00000000..7e7d32aa
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/openai/openai_base.py
@@ -0,0 +1,195 @@
+import logging
+import os
+
+from openai import AsyncOpenAI, AuthenticationError, OpenAI
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIEmbeddingProvider(EmbeddingProvider):
+ MODEL_TO_TOKENIZER = {
+ "text-embedding-ada-002": "cl100k_base",
+ "text-embedding-3-small": "cl100k_base",
+ "text-embedding-3-large": "cl100k_base",
+ }
+ MODEL_TO_DIMENSIONS = {
+ "text-embedding-ada-002": [1536],
+ "text-embedding-3-small": [512, 1536],
+ "text-embedding-3-large": [256, 1024, 3072],
+ }
+
+ def __init__(self, config: EmbeddingConfig):
+ super().__init__(config)
+ provider = config.provider
+ if not provider:
+ raise ValueError(
+ "Must set provider in order to initialize OpenAIEmbeddingProvider."
+ )
+
+ if provider != "openai":
+ raise ValueError(
+ "OpenAIEmbeddingProvider must be initialized with provider `openai`."
+ )
+ if not os.getenv("OPENAI_API_KEY"):
+ raise ValueError(
+ "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+ )
+ self.client = OpenAI()
+ self.async_client = AsyncOpenAI()
+
+ if config.rerank_model:
+ raise ValueError(
+ "OpenAIEmbeddingProvider does not support separate reranking."
+ )
+ self.base_model = config.base_model
+ self.base_dimension = config.base_dimension
+
+ if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+ raise ValueError(
+ f"OpenAI embedding model {self.base_model} not supported."
+ )
+ if (
+ self.base_dimension
+ and self.base_dimension
+ not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model]
+ ):
+            raise ValueError(
+                f"Dimensions {self.base_dimension} for {self.base_model} are not supported"
+            )
+
+ if not self.base_model or not self.base_dimension:
+ raise ValueError(
+ "Must set base_model and base_dimension in order to initialize OpenAIEmbeddingProvider."
+ )
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ try:
+ return (
+ self.client.embeddings.create(
+ input=[text],
+ model=self.base_model,
+ dimensions=self.base_dimension
+ or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+ self.base_model
+ ][-1],
+ )
+ .data[0]
+ .embedding
+ )
+ except AuthenticationError as e:
+ raise ValueError(
+ "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+ ) from e
+
+ async def async_get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ try:
+ response = await self.async_client.embeddings.create(
+ input=[text],
+ model=self.base_model,
+ dimensions=self.base_dimension
+ or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+ self.base_model
+ ][-1],
+ )
+ return response.data[0].embedding
+ except AuthenticationError as e:
+ raise ValueError(
+ "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+ ) from e
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ try:
+ return [
+ ele.embedding
+ for ele in self.client.embeddings.create(
+ input=texts,
+ model=self.base_model,
+ dimensions=self.base_dimension
+ or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+ self.base_model
+ ][-1],
+ ).data
+ ]
+ except AuthenticationError as e:
+ raise ValueError(
+ "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+ ) from e
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ try:
+ response = await self.async_client.embeddings.create(
+ input=texts,
+ model=self.base_model,
+ dimensions=self.base_dimension
+ or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+ self.base_model
+ ][-1],
+ )
+ return [ele.embedding for ele in response.data]
+ except AuthenticationError as e:
+ raise ValueError(
+ "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+ ) from e
+
+ def rerank(
+ self,
+ query: str,
+ results: list[VectorSearchResult],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+ limit: int = 10,
+ ):
+ return results[:limit]
+
+ def tokenize_string(self, text: str, model: str) -> list[int]:
+ try:
+ import tiktoken
+        except ImportError as e:
+            raise ValueError(
+                "Must install the `tiktoken` library to run `tokenize_string`."
+            ) from e
+ # tiktoken encoding -
+ # cl100k_base - gpt-4, gpt-3.5-turbo, text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large
+ if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+ raise ValueError(f"OpenAI embedding model {model} not supported.")
+ encoding = tiktoken.get_encoding(
+ OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
+ )
+ return encoding.encode(text)
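A minimal usage sketch for the OpenAI provider, assuming OPENAI_API_KEY is exported and the model/dimension pair is one of the combinations in MODEL_TO_DIMENSIONS; the EmbeddingConfig arguments are again inferred from the constructor:

# Hypothetical usage sketch -- requires a valid OPENAI_API_KEY.
import asyncio
from r2r.base import EmbeddingConfig
from r2r.providers.embeddings import OpenAIEmbeddingProvider

config = EmbeddingConfig(
    provider="openai",
    base_model="text-embedding-3-small",
    base_dimension=512,  # must appear in MODEL_TO_DIMENSIONS for this model
)
provider = OpenAIEmbeddingProvider(config)

tokens = provider.tokenize_string("hello world", "text-embedding-3-small")
vectors = asyncio.run(provider.async_get_embeddings(["hello", "world"]))
print(len(tokens), len(vectors), len(vectors[0]))  # e.g. 2 2 512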
diff --git a/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
new file mode 100755
index 00000000..3316cb60
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
@@ -0,0 +1,160 @@
+import logging
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class SentenceTransformerEmbeddingProvider(EmbeddingProvider):
+ def __init__(
+ self,
+ config: EmbeddingConfig,
+ ):
+ super().__init__(config)
+ logger.info(
+ "Initializing `SentenceTransformerEmbeddingProvider` with separate models for search and rerank."
+ )
+ provider = config.provider
+ if not provider:
+ raise ValueError(
+ "Must set provider in order to initialize SentenceTransformerEmbeddingProvider."
+ )
+ if provider != "sentence-transformers":
+ raise ValueError(
+ "SentenceTransformerEmbeddingProvider must be initialized with provider `sentence-transformers`."
+ )
+ try:
+ from sentence_transformers import CrossEncoder, SentenceTransformer
+
+ self.SentenceTransformer = SentenceTransformer
+ # TODO - Modify this to be configurable, as `bge-reranker-large` is a `SentenceTransformer` model
+ self.CrossEncoder = CrossEncoder
+ except ImportError as e:
+ raise ValueError(
+ "Must download sentence-transformers library to run `SentenceTransformerEmbeddingProvider`."
+ ) from e
+
+ # Initialize separate models for search and rerank
+ self.do_search = False
+ self.do_rerank = False
+
+ self.search_encoder = self._init_model(
+ config, EmbeddingProvider.PipeStage.BASE
+ )
+ self.rerank_encoder = self._init_model(
+ config, EmbeddingProvider.PipeStage.RERANK
+ )
+
+    def _init_model(
+        self, config: EmbeddingConfig, stage: EmbeddingProvider.PipeStage
+    ):
+ stage_name = stage.name.lower()
+ model = config.dict().get(f"{stage_name}_model", None)
+ dimension = config.dict().get(f"{stage_name}_dimension", None)
+
+ transformer_type = config.dict().get(
+ f"{stage_name}_transformer_type", "SentenceTransformer"
+ )
+
+ if stage == EmbeddingProvider.PipeStage.BASE:
+ self.do_search = True
+ # Check if a model is set for the stage
+ if not (model and dimension and transformer_type):
+ raise ValueError(
+ f"Must set {stage.name.lower()}_model and {stage.name.lower()}_dimension for {stage} stage in order to initialize SentenceTransformerEmbeddingProvider."
+ )
+
+ if stage == EmbeddingProvider.PipeStage.RERANK:
+ # Check if a model is set for the stage
+ if not (model and dimension and transformer_type):
+ return None
+
+ self.do_rerank = True
+ if transformer_type == "SentenceTransformer":
+ raise ValueError(
+ f"`SentenceTransformer` models are not yet supported for {stage} stage in SentenceTransformerEmbeddingProvider."
+ )
+
+ # Save the model_key and dimension into instance variables
+ setattr(self, f"{stage_name}_model", model)
+ setattr(self, f"{stage_name}_dimension", dimension)
+ setattr(self, f"{stage_name}_transformer_type", transformer_type)
+
+ # Initialize the model
+ encoder = (
+ self.SentenceTransformer(
+ model, truncate_dim=dimension, trust_remote_code=True
+ )
+ if transformer_type == "SentenceTransformer"
+ else self.CrossEncoder(model, trust_remote_code=True)
+ )
+ return encoder
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError("`get_embedding` only supports `SEARCH` stage.")
+ if not self.do_search:
+ raise ValueError(
+ "`get_embedding` can only be called for the search stage if a search model is set."
+ )
+ encoder = self.search_encoder
+ return encoder.encode([text]).tolist()[0]
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError("`get_embeddings` only supports `SEARCH` stage.")
+ if not self.do_search:
+ raise ValueError(
+ "`get_embeddings` can only be called for the search stage if a search model is set."
+ )
+        # The stage check above guarantees BASE, so this is always the search encoder
+        encoder = self.search_encoder
+ return encoder.encode(texts).tolist()
+
+ def rerank(
+ self,
+ query: str,
+ results: list[VectorSearchResult],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+ limit: int = 10,
+ ) -> list[VectorSearchResult]:
+ if stage != EmbeddingProvider.PipeStage.RERANK:
+ raise ValueError("`rerank` only supports `RERANK` stage.")
+ if not self.do_rerank:
+ return results[:limit]
+
+        # The comprehension already builds a fresh list, so no copy is needed
+        texts = [doc.metadata["text"] for doc in results]
+ # Use the rank method from the rerank_encoder, which is a CrossEncoder model
+ reranked_scores = self.rerank_encoder.rank(
+ query, texts, return_documents=False, top_k=limit
+ )
+ # Map the reranked scores back to the original documents
+ reranked_results = []
+ for score in reranked_scores:
+ corpus_id = score["corpus_id"]
+ new_result = results[corpus_id]
+ new_result.score = float(score["score"])
+ reranked_results.append(new_result)
+
+ # Sort the documents by the new scores in descending order
+ reranked_results.sort(key=lambda doc: doc.score, reverse=True)
+ return reranked_results
+
+    def tokenize_string(
+        self,
+        text: str,
+        model: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[int]:
+        raise NotImplementedError(
+            "SentenceTransformerEmbeddingProvider does not support tokenize_string."
+        )
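A minimal usage sketch for the provider above with reranking enabled, assuming the sentence-transformers library is installed; the model names are illustrative only, and the config keys mirror what _init_model reads via config.dict():

# Hypothetical usage sketch -- both model names are examples.
from r2r.base import EmbeddingConfig
from r2r.providers.embeddings import SentenceTransformerEmbeddingProvider

config = EmbeddingConfig(
    provider="sentence-transformers",
    base_model="all-MiniLM-L6-v2",
    base_dimension=384,
    rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    rerank_dimension=384,
    rerank_transformer_type="CrossEncoder",  # rerank stage rejects SentenceTransformer
)
provider = SentenceTransformerEmbeddingProvider(config)
query_vector = provider.get_embedding("what is retrieval-augmented generation?")
# rerank() then takes VectorSearchResult objects whose metadata["text"]
# holds the passage, re-scores them with the cross-encoder, and returns
# the top `limit` results sorted by score.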