Diffstat (limited to 'R2R/r2r/providers/embeddings/ollama/ollama_base.py')
-rwxr-xr-x | R2R/r2r/providers/embeddings/ollama/ollama_base.py | 156
1 file changed, 156 insertions, 0 deletions
diff --git a/R2R/r2r/providers/embeddings/ollama/ollama_base.py b/R2R/r2r/providers/embeddings/ollama/ollama_base.py
new file mode 100755
index 00000000..31a8c717
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/ollama/ollama_base.py
@@ -0,0 +1,156 @@
+import asyncio
+import logging
+import os
+import random
+from typing import Any
+
+from ollama import AsyncClient, Client
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaEmbeddingProvider(EmbeddingProvider):
+    def __init__(self, config: EmbeddingConfig):
+        super().__init__(config)
+        provider = config.provider
+        if not provider:
+            raise ValueError(
+                "Must set provider in order to initialize `OllamaEmbeddingProvider`."
+            )
+        if provider != "ollama":
+            raise ValueError(
+                "OllamaEmbeddingProvider must be initialized with provider `ollama`."
+            )
+        if config.rerank_model:
+            raise ValueError(
+                "OllamaEmbeddingProvider does not support separate reranking."
+            )
+
+        self.base_model = config.base_model
+        self.base_dimension = config.base_dimension
+        self.base_url = os.getenv("OLLAMA_API_BASE")
+        logger.info(
+            f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
+        )
+        self.client = Client(host=self.base_url)
+        self.aclient = AsyncClient(host=self.base_url)
+
+        self.request_queue = asyncio.Queue()
+        self.max_retries = 2
+        self.initial_backoff = 1
+        self.max_backoff = 60
+        self.concurrency_limit = 10
+        self.semaphore = asyncio.Semaphore(self.concurrency_limit)
+
+    async def process_queue(self):
+        while True:
+            task = await self.request_queue.get()
+            try:
+                result = await self.execute_task_with_backoff(task)
+                task["future"].set_result(result)
+            except Exception as e:
+                task["future"].set_exception(e)
+            finally:
+                self.request_queue.task_done()
+
+    async def execute_task_with_backoff(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.initial_backoff
+        while retries < self.max_retries:
+            try:
+                async with self.semaphore:
+                    response = await asyncio.wait_for(
+                        self.aclient.embeddings(
+                            prompt=task["text"], model=self.base_model
+                        ),
+                        timeout=30,
+                    )
+                return response["embedding"]
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.max_retries:
+                    raise Exception(
+                        f"Max retries reached. Last error: {str(e)}"
+                    )
+                await asyncio.sleep(backoff + random.uniform(0, 1))
+                backoff = min(backoff * 2, self.max_backoff)
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OllamaEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            response = self.client.embeddings(
+                prompt=text, model=self.base_model
+            )
+            return response["embedding"]
+        except Exception as e:
+            logger.error(f"Error getting embedding: {str(e)}")
+            raise
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        return [self.get_embedding(text, stage) for text in texts]
+
+    async def async_get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OllamaEmbeddingProvider only supports search stage."
+            )
+
+        queue_processor = asyncio.create_task(self.process_queue())
+        futures = []
+        for text in texts:
+            future = asyncio.Future()
+            await self.request_queue.put({"text": text, "future": future})
+            futures.append(future)
+
+        try:
+            results = await asyncio.gather(*futures, return_exceptions=True)
+            # Check if any result is an exception and raise it
+            exceptions = set([r for r in results if isinstance(r, Exception)])
+            if exceptions:
+                raise Exception(
+                    f"Embedding generation failed for one or more embeddings."
+                )
+            return results
+        except Exception as e:
+            logger.error(f"Embedding generation failed: {str(e)}")
+            raise
+        finally:
+            await self.request_queue.join()
+            queue_processor.cancel()
+
+    def rerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+        limit: int = 10,
+    ) -> list[VectorSearchResult]:
+        return results[:limit]
+
+    def tokenize_string(
+        self, text: str, model: str, stage: EmbeddingProvider.PipeStage
+    ) -> list[int]:
+        raise NotImplementedError(
+            "Tokenization is not supported by OllamaEmbeddingProvider."
+        )
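
For reference, a minimal usage sketch of the provider added above (not part of the commit). The config field names follow the attributes read in __init__ (provider, base_model, base_dimension); the EmbeddingConfig constructor call, the import path, and the model name are assumptions for illustration, not confirmed by this diff.

# Hypothetical usage sketch -- config fields, import path, and model name
# are assumed from the diff above, not verified against the r2r package.
import asyncio

from r2r.base import EmbeddingConfig
from r2r.providers.embeddings.ollama.ollama_base import OllamaEmbeddingProvider

config = EmbeddingConfig(
    provider="ollama",               # anything else raises ValueError in __init__
    base_model="mxbai-embed-large",  # placeholder: any embedding model pulled into Ollama
    base_dimension=1024,             # placeholder: dimension of the chosen model
)
provider = OllamaEmbeddingProvider(config)  # reads OLLAMA_API_BASE if set

# Synchronous path: one blocking call per text via ollama.Client
vector = provider.get_embedding("hello world")

# Asynchronous path: texts go through the internal queue and are embedded
# via ollama.AsyncClient with retry/backoff and a concurrency limit
vectors = asyncio.run(
    provider.async_get_embeddings(["first document", "second document"])
)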