aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py194
1 files changed, 194 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py
new file mode 100644
index 00000000..297d9167
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py
@@ -0,0 +1,194 @@
+import logging
+import os
+from typing import Any
+
+from ollama import AsyncClient, Client
+
+from core.base import (
+ ChunkSearchResult,
+ EmbeddingConfig,
+ EmbeddingProvider,
+ EmbeddingPurpose,
+ R2RException,
+)
+
+# Module-level logger used for the startup info line and embedding errors.
+# NOTE(review): getLogger() without a name returns the root logger, so
+# records are not attributed to this module; logging.getLogger(__name__)
+# is the usual idiom -- confirm nothing relies on the root logger here.
+logger = logging.getLogger()
+
+
+class OllamaEmbeddingProvider(EmbeddingProvider):
+ def __init__(self, config: EmbeddingConfig):
+ super().__init__(config)
+ provider = config.provider
+ if not provider:
+ raise ValueError(
+ "Must set provider in order to initialize `OllamaEmbeddingProvider`."
+ )
+ if provider != "ollama":
+ raise ValueError(
+ "OllamaEmbeddingProvider must be initialized with provider `ollama`."
+ )
+ if config.rerank_model:
+ raise ValueError(
+ "OllamaEmbeddingProvider does not support separate reranking."
+ )
+
+ self.base_model = config.base_model
+ self.base_dimension = config.base_dimension
+ self.base_url = os.getenv("OLLAMA_API_BASE")
+ logger.info(
+ f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
+ )
+ self.client = Client(host=self.base_url)
+ self.aclient = AsyncClient(host=self.base_url)
+
+ self.set_prefixes(config.prefixes or {}, self.base_model)
+ self.batch_size = config.batch_size or 32
+
+ def _get_embedding_kwargs(self, **kwargs):
+ embedding_kwargs = {
+ "model": self.base_model,
+ }
+ embedding_kwargs.update(kwargs)
+ return embedding_kwargs
+
+ async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ purpose = task.get("purpose", EmbeddingPurpose.INDEX)
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+ try:
+ embeddings = []
+ for i in range(0, len(texts), self.batch_size):
+ batch = texts[i : i + self.batch_size]
+ prefixed_batch = [
+ self.prefixes.get(purpose, "") + text for text in batch
+ ]
+ response = await self.aclient.embed(
+ input=prefixed_batch, **kwargs
+ )
+ embeddings.extend(response["embeddings"])
+ return embeddings
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+ raise R2RException(error_msg, 400) from e
+
+ def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ purpose = task.get("purpose", EmbeddingPurpose.INDEX)
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+ try:
+ embeddings = []
+ for i in range(0, len(texts), self.batch_size):
+ batch = texts[i : i + self.batch_size]
+ prefixed_batch = [
+ self.prefixes.get(purpose, "") + text for text in batch
+ ]
+ response = self.client.embed(input=prefixed_batch, **kwargs)
+ embeddings.extend(response["embeddings"])
+ return embeddings
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+ raise R2RException(error_msg, 400) from e
+
+ async def async_get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ result = await self._execute_with_backoff_async(task)
+ return result[0]
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ result = self._execute_with_backoff_sync(task)
+ return result[0]
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ def rerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ) -> list[ChunkSearchResult]:
+ return results[:limit]
+
+ async def arerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ):
+ return results[:limit]