path: root/R2R/r2r/providers/embeddings/ollama/ollama_base.py
Diffstat (limited to 'R2R/r2r/providers/embeddings/ollama/ollama_base.py')
-rwxr-xr-x  R2R/r2r/providers/embeddings/ollama/ollama_base.py  156
1 file changed, 156 insertions, 0 deletions
diff --git a/R2R/r2r/providers/embeddings/ollama/ollama_base.py b/R2R/r2r/providers/embeddings/ollama/ollama_base.py
new file mode 100755
index 00000000..31a8c717
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/ollama/ollama_base.py
@@ -0,0 +1,156 @@
+import asyncio
+import contextlib
+import logging
+import os
+import random
+from typing import Any
+
+from ollama import AsyncClient, Client
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaEmbeddingProvider(EmbeddingProvider):
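+    """Embedding provider backed by an Ollama server (sync and async clients)."""
+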
+ def __init__(self, config: EmbeddingConfig):
+ super().__init__(config)
+ provider = config.provider
+ if not provider:
+ raise ValueError(
+ "Must set provider in order to initialize `OllamaEmbeddingProvider`."
+ )
+ if provider != "ollama":
+ raise ValueError(
+ "OllamaEmbeddingProvider must be initialized with provider `ollama`."
+ )
+ if config.rerank_model:
+ raise ValueError(
+ "OllamaEmbeddingProvider does not support separate reranking."
+ )
+
+ self.base_model = config.base_model
+ self.base_dimension = config.base_dimension
+ self.base_url = os.getenv("OLLAMA_API_BASE")
+ logger.info(
+ f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
+ )
+ self.client = Client(host=self.base_url)
+ self.aclient = AsyncClient(host=self.base_url)
+
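+        # Retry/backoff and concurrency settings for the async request queue.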
+ self.request_queue = asyncio.Queue()
+ self.max_retries = 2
+ self.initial_backoff = 1
+ self.max_backoff = 60
+ self.concurrency_limit = 10
+ self.semaphore = asyncio.Semaphore(self.concurrency_limit)
+
+ async def process_queue(self):
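+        """Worker loop: drain queued embedding requests and resolve their futures."""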
+ while True:
+ task = await self.request_queue.get()
+ try:
+ result = await self.execute_task_with_backoff(task)
+ task["future"].set_result(result)
+ except Exception as e:
+ task["future"].set_exception(e)
+ finally:
+ self.request_queue.task_done()
+
+ async def execute_task_with_backoff(self, task: dict[str, Any]):
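+        """Run a single embedding request with exponential backoff and jitter."""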
+ retries = 0
+ backoff = self.initial_backoff
+ while retries < self.max_retries:
+ try:
+ async with self.semaphore:
+ response = await asyncio.wait_for(
+ self.aclient.embeddings(
+ prompt=task["text"], model=self.base_model
+ ),
+ timeout=30,
+ )
+ return response["embedding"]
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+                if retries == self.max_retries:
+                    raise Exception(
+                        f"Max retries reached. Last error: {str(e)}"
+                    ) from e
+ await asyncio.sleep(backoff + random.uniform(0, 1))
+ backoff = min(backoff * 2, self.max_backoff)
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[float]:
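+        """Synchronously embed one text via the Ollama embeddings endpoint."""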
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ try:
+ response = self.client.embeddings(
+ prompt=text, model=self.base_model
+ )
+ return response["embedding"]
+ except Exception as e:
+ logger.error(f"Error getting embedding: {str(e)}")
+ raise
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[list[float]]:
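+        # Sequential convenience path; use async_get_embeddings for concurrency.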
+ return [self.get_embedding(text, stage) for text in texts]
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
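+        # One background worker drains the queue; per-request concurrency is
+        # still capped by the semaphore inside execute_task_with_backoff.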
+ queue_processor = asyncio.create_task(self.process_queue())
+ futures = []
+ for text in texts:
+            future = asyncio.get_running_loop().create_future()
+ await self.request_queue.put({"text": text, "future": future})
+ futures.append(future)
+
+ try:
+ results = await asyncio.gather(*futures, return_exceptions=True)
+            # Surface the first underlying error rather than discarding it.
+            errors = [r for r in results if isinstance(r, Exception)]
+            if errors:
+                raise Exception(
+                    f"Embedding generation failed for {len(errors)} of {len(texts)} texts."
+                ) from errors[0]
+ return results
+ except Exception as e:
+ logger.error(f"Embedding generation failed: {str(e)}")
+ raise
+        finally:
+            await self.request_queue.join()
+            queue_processor.cancel()
+            with contextlib.suppress(asyncio.CancelledError):
+                await queue_processor
+
+ def rerank(
+ self,
+ query: str,
+ results: list[VectorSearchResult],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+ limit: int = 10,
+ ) -> list[VectorSearchResult]:
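+        # No reranking model is configured; return results truncated to `limit`.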
+ return results[:limit]
+
+ def tokenize_string(
+ self, text: str, model: str, stage: EmbeddingProvider.PipeStage
+ ) -> list[int]:
+ raise NotImplementedError(
+ "Tokenization is not supported by OllamaEmbeddingProvider."
+ )