Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/embeddings')
4 files changed, 751 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/__init__.py
new file mode 100644
index 00000000..3fa67442
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/__init__.py
@@ -0,0 +1,9 @@
+from .litellm import LiteLLMEmbeddingProvider
+from .ollama import OllamaEmbeddingProvider
+from .openai import OpenAIEmbeddingProvider
+
+__all__ = [
+    "LiteLLMEmbeddingProvider",
+    "OpenAIEmbeddingProvider",
+    "OllamaEmbeddingProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/litellm.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/litellm.py
new file mode 100644
index 00000000..5f705c91
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/litellm.py
@@ -0,0 +1,305 @@
+import logging
+import math
+import os
+from copy import copy
+from typing import Any
+
+import litellm
+import requests
+from aiohttp import ClientError, ClientSession
+from litellm import AuthenticationError, aembedding, embedding
+
+from core.base import (
+    ChunkSearchResult,
+    EmbeddingConfig,
+    EmbeddingProvider,
+    EmbeddingPurpose,
+    R2RException,
+)
+
+logger = logging.getLogger()
+
+
+class LiteLLMEmbeddingProvider(EmbeddingProvider):
+    def __init__(
+        self,
+        config: EmbeddingConfig,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(config)
+
+        self.litellm_embedding = embedding
+        self.litellm_aembedding = aembedding
+
+        provider = config.provider
+        if not provider:
+            raise ValueError(
+                "Must set provider in order to initialize `LiteLLMEmbeddingProvider`."
+            )
+        if provider != "litellm":
+            raise ValueError(
+                "LiteLLMEmbeddingProvider must be initialized with provider `litellm`."
+            )
+
+        self.rerank_url = None
+        if config.rerank_model:
+            if "huggingface" not in config.rerank_model:
+                raise ValueError(
+                    "LiteLLMEmbeddingProvider only supports re-ranking via the HuggingFace text-embeddings-inference API"
+                )
+
+            url = os.getenv("HUGGINGFACE_API_BASE") or config.rerank_url
+            if not url:
+                raise ValueError(
+                    "LiteLLMEmbeddingProvider requires a valid reranking API url to be set via `embedding.rerank_url` in the r2r.toml, or via the environment variable `HUGGINGFACE_API_BASE`."
+                )
+            self.rerank_url = url
+
+        self.base_model = config.base_model
+        if "amazon" in self.base_model:
+            logger.warn("Amazon embedding model detected, dropping params")
+            litellm.drop_params = True
+        self.base_dimension = config.base_dimension
+
+    def _get_embedding_kwargs(self, **kwargs):
+        embedding_kwargs = {
+            "model": self.base_model,
+            "dimensions": self.base_dimension,
+        }
+        embedding_kwargs.update(kwargs)
+        return embedding_kwargs
+
+    async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
+        texts = task["texts"]
+        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+        if "dimensions" in kwargs and math.isnan(kwargs["dimensions"]):
+            kwargs.pop("dimensions")
+            logger.warning("Dropping nan dimensions from kwargs")
+
+        try:
+            response = await self.litellm_aembedding(
+                input=texts,
+                **kwargs,
+            )
+            return [data["embedding"] for data in response.data]
+        except AuthenticationError:
+            logger.error(
+                "Authentication error: Invalid API key or credentials."
+            )
+            raise
+        except Exception as e:
+            error_msg = f"Error getting embeddings: {str(e)}"
+            logger.error(error_msg)
+
+            raise R2RException(error_msg, 400) from e
+
+    def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
+        texts = task["texts"]
+        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+        try:
+            response = self.litellm_embedding(
+                input=texts,
+                **kwargs,
+            )
+            return [data["embedding"] for data in response.data]
+        except AuthenticationError:
+            logger.error(
+                "Authentication error: Invalid API key or credentials."
+            )
+            raise
+        except Exception as e:
+            error_msg = f"Error getting embeddings: {str(e)}"
+            logger.error(error_msg)
+            raise R2RException(error_msg, 400) from e
+
+    async def async_get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+        **kwargs,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.Step.BASE:
+            raise ValueError(
+                "LiteLLMEmbeddingProvider only supports search stage."
+            )
+
+        task = {
+            "texts": [text],
+            "stage": stage,
+            "purpose": purpose,
+            "kwargs": kwargs,
+        }
+        return (await self._execute_with_backoff_async(task))[0]
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+        **kwargs,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.Step.BASE:
+            raise ValueError(
+                "Error getting embeddings: LiteLLMEmbeddingProvider only supports search stage."
+            )
+
+        task = {
+            "texts": [text],
+            "stage": stage,
+            "purpose": purpose,
+            "kwargs": kwargs,
+        }
+        return self._execute_with_backoff_sync(task)[0]
+
+    async def async_get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+        **kwargs,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.Step.BASE:
+            raise ValueError(
+                "LiteLLMEmbeddingProvider only supports search stage."
+            )
+
+        task = {
+            "texts": texts,
+            "stage": stage,
+            "purpose": purpose,
+            "kwargs": kwargs,
+        }
+        return await self._execute_with_backoff_async(task)
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+        **kwargs,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.Step.BASE:
+            raise ValueError(
+                "LiteLLMEmbeddingProvider only supports search stage."
+            )
+
+        task = {
+            "texts": texts,
+            "stage": stage,
+            "purpose": purpose,
+            "kwargs": kwargs,
+        }
+        return self._execute_with_backoff_sync(task)
+
+    def rerank(
+        self,
+        query: str,
+        results: list[ChunkSearchResult],
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+        limit: int = 10,
+    ):
+        if self.config.rerank_model is not None:
+            if not self.rerank_url:
+                raise ValueError(
+                    "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
+                )
+
+            texts = [result.text for result in results]
+
+            payload = {
+                "query": query,
+                "texts": texts,
+                "model-id": self.config.rerank_model.split("huggingface/")[1],
+            }
+
+            headers = {"Content-Type": "application/json"}
+
+            try:
+                response = requests.post(
+                    self.rerank_url, json=payload, headers=headers
+                )
+                response.raise_for_status()
+                reranked_results = response.json()
+
+                # Copy reranked results into new array
+                scored_results = []
+                for rank_info in reranked_results:
+                    original_result = results[rank_info["index"]]
+                    copied_result = copy(original_result)
+                    # Inject the reranking score into the result object
+                    copied_result.score = rank_info["score"]
+                    scored_results.append(copied_result)
+
+                # Return only the ChunkSearchResult objects, limited to specified count
+                return scored_results[:limit]
+
+            except requests.RequestException as e:
+                logger.error(f"Error during reranking: {str(e)}")
+                # Fall back to returning the original results if reranking fails
+                return results[:limit]
+        else:
+            return results[:limit]
+
+    async def arerank(
+        self,
+        query: str,
+        results: list[ChunkSearchResult],
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+        limit: int = 10,
+    ) -> list[ChunkSearchResult]:
+        """Asynchronously rerank search results using the configured rerank
+        model.
+
+        Args:
+            query: The search query string
+            results: List of ChunkSearchResult objects to rerank
+            limit: Maximum number of results to return
+
+        Returns:
+            List of reranked ChunkSearchResult objects, limited to specified count
+        """
+        if self.config.rerank_model is not None:
+            if not self.rerank_url:
+                raise ValueError(
+                    "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
+                )
+
+            texts = [result.text for result in results]
+
+            payload = {
+                "query": query,
+                "texts": texts,
+                "model-id": self.config.rerank_model.split("huggingface/")[1],
+            }
+
+            headers = {"Content-Type": "application/json"}
+
+            try:
+                async with ClientSession() as session:
+                    async with session.post(
+                        self.rerank_url, json=payload, headers=headers
+                    ) as response:
+                        response.raise_for_status()
+                        reranked_results = await response.json()
+
+                        # Copy reranked results into new array
+                        scored_results = []
+                        for rank_info in reranked_results:
+                            original_result = results[rank_info["index"]]
+                            copied_result = copy(original_result)
+                            # Inject the reranking score into the result object
+                            copied_result.score = rank_info["score"]
+                            scored_results.append(copied_result)
+
+                        # Return only the ChunkSearchResult objects, limited to specified count
+                        return scored_results[:limit]
+
+            except (ClientError, Exception) as e:
+                logger.error(f"Error during async reranking: {str(e)}")
+                # Fall back to returning the original results if reranking fails
+                return results[:limit]
+        else:
+            return results[:limit]
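For orientation, a minimal usage sketch of the class above. This assumes that `EmbeddingConfig` accepts the fields `__init__` reads (`provider`, `base_model`, `base_dimension`) as constructor arguments, and that `_execute_with_backoff_sync` comes from the `EmbeddingProvider` base class; neither is shown in this diff.

    # Hypothetical wiring; only the BASE stage is accepted by the getters above.
    config = EmbeddingConfig(
        provider="litellm",
        base_model="openai/text-embedding-3-small",
        base_dimension=512,
    )
    provider = LiteLLMEmbeddingProvider(config)
    vector = provider.get_embedding("what is retrieval-augmented generation?")
    assert len(vector) == 512

Because no `rerank_model` is configured here, `rerank()` would simply truncate the input results to `limit`.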
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py
new file mode 100644
index 00000000..297d9167
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py
@@ -0,0 +1,194 @@
+import logging
+import os
+from typing import Any
+
+from ollama import AsyncClient, Client
+
+from core.base import (
+    ChunkSearchResult,
+    EmbeddingConfig,
+    EmbeddingProvider,
+    EmbeddingPurpose,
+    R2RException,
+)
+
+logger = logging.getLogger()
+
+
+class OllamaEmbeddingProvider(EmbeddingProvider):
+    def __init__(self, config: EmbeddingConfig):
+        super().__init__(config)
+        provider = config.provider
+        if not provider:
+            raise ValueError(
+                "Must set provider in order to initialize `OllamaEmbeddingProvider`."
+            )
+        if provider != "ollama":
+            raise ValueError(
+                "OllamaEmbeddingProvider must be initialized with provider `ollama`."
+            )
+        if config.rerank_model:
+            raise ValueError(
+                "OllamaEmbeddingProvider does not support separate reranking."
+            )
+
+        self.base_model = config.base_model
+        self.base_dimension = config.base_dimension
+        self.base_url = os.getenv("OLLAMA_API_BASE")
+        logger.info(
+            f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
+        )
+        self.client = Client(host=self.base_url)
+        self.aclient = AsyncClient(host=self.base_url)
+
+        self.set_prefixes(config.prefixes or {}, self.base_model)
+        self.batch_size = config.batch_size or 32
+
+    def _get_embedding_kwargs(self, **kwargs):
+        embedding_kwargs = {
+            "model": self.base_model,
+        }
+        embedding_kwargs.update(kwargs)
+        return embedding_kwargs
+
+    async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
+        texts = task["texts"]
+        purpose = task.get("purpose", EmbeddingPurpose.INDEX)
+        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+        try:
+            embeddings = []
+            for i in range(0, len(texts), self.batch_size):
+                batch = texts[i : i + self.batch_size]
+                prefixed_batch = [
+                    self.prefixes.get(purpose, "") + text for text in batch
+                ]
+                response = await self.aclient.embed(
+                    input=prefixed_batch, **kwargs
+                )
+                embeddings.extend(response["embeddings"])
+            return embeddings
+        except Exception as e:
+            error_msg = f"Error getting embeddings: {str(e)}"
+            logger.error(error_msg)
+            raise R2RException(error_msg, 400) from e
+
+    def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
+        texts = task["texts"]
+        purpose = task.get("purpose", EmbeddingPurpose.INDEX)
+        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+        try:
+            embeddings = []
+            for i in range(0, len(texts), self.batch_size):
+                batch = texts[i : i + self.batch_size]
+                prefixed_batch = [
+                    self.prefixes.get(purpose, "") + text for text in batch
+                ]
+                response = self.client.embed(input=prefixed_batch, **kwargs)
+                embeddings.extend(response["embeddings"])
+            return embeddings
+        except Exception as e:
+            error_msg = f"Error getting embeddings: {str(e)}"
+            logger.error(error_msg)
+            raise R2RException(error_msg, 400) from e
+
+    async def async_get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+        **kwargs,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.Step.BASE:
+            raise ValueError(
+                "OllamaEmbeddingProvider only supports search stage."
+            )
+
+        task = {
+            "texts": [text],
+            "stage": stage,
+            "purpose": purpose,
+            "kwargs": kwargs,
+        }
+        result = await self._execute_with_backoff_async(task)
+        return result[0]
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+        **kwargs,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.Step.BASE:
+            raise ValueError(
+                "OllamaEmbeddingProvider only supports search stage."
+            )
+
+        task = {
+            "texts": [text],
+            "stage": stage,
+            "purpose": purpose,
+            "kwargs": kwargs,
+        }
+        result = self._execute_with_backoff_sync(task)
+        return result[0]
+
+    async def async_get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+        **kwargs,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.Step.BASE:
+            raise ValueError(
+                "OllamaEmbeddingProvider only supports search stage."
+            )
+
+        task = {
+            "texts": texts,
+            "stage": stage,
+            "purpose": purpose,
+            "kwargs": kwargs,
+        }
+        return await self._execute_with_backoff_async(task)
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+        **kwargs,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.Step.BASE:
+            raise ValueError(
+                "OllamaEmbeddingProvider only supports search stage."
+            )
+
+        task = {
+            "texts": texts,
+            "stage": stage,
+            "purpose": purpose,
+            "kwargs": kwargs,
+        }
+        return self._execute_with_backoff_sync(task)
+
+    def rerank(
+        self,
+        query: str,
+        results: list[ChunkSearchResult],
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+        limit: int = 10,
+    ) -> list[ChunkSearchResult]:
+        return results[:limit]
+
+    async def arerank(
+        self,
+        query: str,
+        results: list[ChunkSearchResult],
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+        limit: int = 10,
+    ):
+        return results[:limit]
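A similar sketch for the Ollama provider, assuming a local Ollama server (the client falls back to http://127.0.0.1:11434 when `OLLAMA_API_BASE` is unset) and the same hypothetical `EmbeddingConfig` constructor as above. It illustrates the batching and prefixing done in `_execute_task_sync`.

    config = EmbeddingConfig(
        provider="ollama",
        base_model="mxbai-embed-large",  # any embedding model already pulled into Ollama
        base_dimension=1024,
        batch_size=32,  # texts are sent to the /embed endpoint in slices of this size
    )
    provider = OllamaEmbeddingProvider(config)
    # Each text is prepended with the prefix registered for its EmbeddingPurpose, if any.
    vectors = provider.get_embeddings(["first chunk", "second chunk"])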
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
new file mode 100644
index 00000000..907cebd9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
@@ -0,0 +1,243 @@
+import logging
+import os
+from typing import Any
+
+import tiktoken
+from openai import AsyncOpenAI, AuthenticationError, OpenAI
+from openai._types import NOT_GIVEN
+
+from core.base import (
+    ChunkSearchResult,
+    EmbeddingConfig,
+    EmbeddingProvider,
+    EmbeddingPurpose,
+)
+
+logger = logging.getLogger()
+
+
+class OpenAIEmbeddingProvider(EmbeddingProvider):
+    MODEL_TO_TOKENIZER = {
+        "text-embedding-ada-002": "cl100k_base",
+        "text-embedding-3-small": "cl100k_base",
+        "text-embedding-3-large": "cl100k_base",
+    }
+    MODEL_TO_DIMENSIONS = {
+        "text-embedding-ada-002": [1536],
+        "text-embedding-3-small": [512, 1536],
+        "text-embedding-3-large": [256, 1024, 3072],
+    }
+
+    def __init__(self, config: EmbeddingConfig):
+        super().__init__(config)
+        if not config.provider:
+            raise ValueError(
+                "Must set provider in order to initialize OpenAIEmbeddingProvider."
+            )
+
+        if config.provider != "openai":
+            raise ValueError(
+                "OpenAIEmbeddingProvider must be initialized with provider `openai`."
+            )
+        if not os.getenv("OPENAI_API_KEY"):
+            raise ValueError(
+                "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+            )
+        self.client = OpenAI()
+        self.async_client = AsyncOpenAI()
+
+        if config.rerank_model:
+            raise ValueError(
+                "OpenAIEmbeddingProvider does not support separate reranking."
+            )
+
+        if config.base_model and "openai/" in config.base_model:
+            self.base_model = config.base_model.split("/")[-1]
+        else:
+            self.base_model = config.base_model
+        self.base_dimension = config.base_dimension
+
+        if not self.base_model:
+            raise ValueError(
+                "Must set base_model in order to initialize OpenAIEmbeddingProvider."
+            )
+
+        if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+            raise ValueError(
+                f"OpenAI embedding model {self.base_model} not supported."
+            )
+
+        if self.base_dimension:
+            if (
+                self.base_dimension
+                not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                    self.base_model
+                ]
+            ):
+                raise ValueError(
+                    f"Dimensions {self.base_dimension} for {self.base_model} are not supported"
+                )
+        else:
+            # If base_dimension is not set, use the largest available dimension for the model
+            self.base_dimension = max(
+                OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model]
+            )
+
+    def _get_dimensions(self):
+        return (
+            NOT_GIVEN
+            if self.base_model == "text-embedding-ada-002"
+            else self.base_dimension
+            or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model][-1]
+        )
+
+    def _get_embedding_kwargs(self, **kwargs):
+        return {
+            "model": self.base_model,
+            "dimensions": self._get_dimensions(),
+        } | kwargs
+
+    async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
+        texts = task["texts"]
+        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+        try:
+            response = await self.async_client.embeddings.create(
+                input=texts,
+                **kwargs,
+            )
+            return [data.embedding for data in response.data]
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+        except Exception as e:
+            error_msg = f"Error getting embeddings: {str(e)}"
+            logger.error(error_msg)
+            raise ValueError(error_msg) from e
+
+    def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
+        texts = task["texts"]
+        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+        try:
+            response = self.client.embeddings.create(
+                input=texts,
+                **kwargs,
+            )
+            return [data.embedding for data in response.data]
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+        except Exception as e:
+            error_msg = f"Error getting embeddings: {str(e)}"
+            logger.error(error_msg)
+            raise ValueError(error_msg) from e
+
+    async def async_get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+        **kwargs,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.Step.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        task = {
+            "texts": [text],
+            "stage": stage,
+            "purpose": purpose,
+            "kwargs": kwargs,
+        }
+        result = await self._execute_with_backoff_async(task)
+        return result[0]
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+        **kwargs,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.Step.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        task = {
+            "texts": [text],
+            "stage": stage,
+            "purpose": purpose,
+            "kwargs": kwargs,
+        }
+        result = self._execute_with_backoff_sync(task)
+        return result[0]
+
+    async def async_get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+        **kwargs,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.Step.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        task = {
+            "texts": texts,
+            "stage": stage,
+            "purpose": purpose,
+            "kwargs": kwargs,
+        }
+        return await self._execute_with_backoff_async(task)
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+        **kwargs,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.Step.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        task = {
+            "texts": texts,
+            "stage": stage,
+            "purpose": purpose,
+            "kwargs": kwargs,
+        }
+        return self._execute_with_backoff_sync(task)
+
+    def rerank(
+        self,
+        query: str,
+        results: list[ChunkSearchResult],
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+        limit: int = 10,
+    ):
+        return results[:limit]
+
+    async def arerank(
+        self,
+        query: str,
+        results: list[ChunkSearchResult],
+        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+        limit: int = 10,
+    ):
+        return results[:limit]
+
+    def tokenize_string(self, text: str, model: str) -> list[int]:
+        if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+            raise ValueError(f"OpenAI embedding model {model} not supported.")
+        encoding = tiktoken.get_encoding(
+            OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
+        )
+        return encoding.encode(text)
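Finally, a sketch for the OpenAI provider, assuming `OPENAI_API_KEY` is exported and that `EmbeddingConfig` allows `base_dimension` to be omitted. It shows the dimension fallback in `__init__`: with `base_dimension` unset, `text-embedding-3-large` defaults to 3072, the largest entry in `MODEL_TO_DIMENSIONS`.

    config = EmbeddingConfig(provider="openai", base_model="text-embedding-3-large")
    provider = OpenAIEmbeddingProvider(config)
    # tokenize_string uses the cl100k_base tiktoken encoding mapped above.
    token_ids = provider.tokenize_string("hello world", "text-embedding-3-large")
    vector = provider.get_embedding("hello world")
    assert len(vector) == 3072  # largest supported dimension, chosen as the default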