diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py | 194 |
1 files changed, 194 insertions, 0 deletions
import logging
import os
from typing import Any

from ollama import AsyncClient, Client

from core.base import (
    ChunkSearchResult,
    EmbeddingConfig,
    EmbeddingProvider,
    EmbeddingPurpose,
    R2RException,
)

logger = logging.getLogger()


class OllamaEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by an Ollama server.

    Only the BASE embedding stage is supported. Reranking is not
    performed: ``rerank``/``arerank`` simply truncate the given results
    to ``limit``. Texts are embedded in batches of ``config.batch_size``
    (default 32), with an optional per-purpose prefix prepended to each
    text (see ``set_prefixes``).
    """

    def __init__(self, config: EmbeddingConfig):
        """Validate *config* and create sync and async Ollama clients.

        Args:
            config: Embedding configuration. ``provider`` must be
                ``"ollama"`` and ``rerank_model`` must be unset.

        Raises:
            ValueError: If the provider is missing, is not ``"ollama"``,
                or a rerank model is configured (unsupported here).
        """
        super().__init__(config)
        provider = config.provider
        if not provider:
            raise ValueError(
                "Must set provider in order to initialize `OllamaEmbeddingProvider`."
            )
        if provider != "ollama":
            raise ValueError(
                "OllamaEmbeddingProvider must be initialized with provider `ollama`."
            )
        if config.rerank_model:
            raise ValueError(
                "OllamaEmbeddingProvider does not support separate reranking."
            )

        self.base_model = config.base_model
        self.base_dimension = config.base_dimension
        # None makes the ollama client fall back to its default host
        # (http://127.0.0.1:11434), which is what the log line reflects.
        self.base_url = os.getenv("OLLAMA_API_BASE")
        logger.info(
            f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
        )
        self.client = Client(host=self.base_url)
        self.aclient = AsyncClient(host=self.base_url)

        self.set_prefixes(config.prefixes or {}, self.base_model)
        self.batch_size = config.batch_size or 32

    def _get_embedding_kwargs(self, **kwargs):
        """Return kwargs for the Ollama ``embed`` call.

        The configured base model is always included; caller-supplied
        kwargs override it if they collide.
        """
        embedding_kwargs = {
            "model": self.base_model,
        }
        embedding_kwargs.update(kwargs)
        return embedding_kwargs

    async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
        """Embed ``task["texts"]`` asynchronously, batch by batch.

        Each text is prefixed according to ``task["purpose"]`` before
        being sent. Any client error is wrapped in an ``R2RException``
        with HTTP status 400.
        """
        texts = task["texts"]
        purpose = task.get("purpose", EmbeddingPurpose.INDEX)
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))

        try:
            embeddings = []
            for i in range(0, len(texts), self.batch_size):
                batch = texts[i : i + self.batch_size]
                prefixed_batch = [
                    self.prefixes.get(purpose, "") + text for text in batch
                ]
                response = await self.aclient.embed(
                    input=prefixed_batch, **kwargs
                )
                embeddings.extend(response["embeddings"])
            return embeddings
        except Exception as e:
            error_msg = f"Error getting embeddings: {str(e)}"
            logger.error(error_msg)
            raise R2RException(error_msg, 400) from e

    def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
        """Synchronous counterpart of ``_execute_task`` (same contract)."""
        texts = task["texts"]
        purpose = task.get("purpose", EmbeddingPurpose.INDEX)
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))

        try:
            embeddings = []
            for i in range(0, len(texts), self.batch_size):
                batch = texts[i : i + self.batch_size]
                prefixed_batch = [
                    self.prefixes.get(purpose, "") + text for text in batch
                ]
                response = self.client.embed(input=prefixed_batch, **kwargs)
                embeddings.extend(response["embeddings"])
            return embeddings
        except Exception as e:
            error_msg = f"Error getting embeddings: {str(e)}"
            logger.error(error_msg)
            raise R2RException(error_msg, 400) from e

    async def async_get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Asynchronously embed a single *text*.

        Raises:
            ValueError: If *stage* is not ``Step.BASE`` — this provider
                supports only the base embedding stage.
        """
        if stage != EmbeddingProvider.Step.BASE:
            # BUGFIX: message previously said "search stage", which
            # contradicted the check — only the BASE stage is supported.
            raise ValueError(
                "OllamaEmbeddingProvider only supports the base embedding stage."
            )

        task = {
            "texts": [text],
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        result = await self._execute_with_backoff_async(task)
        return result[0]

    def get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Synchronously embed a single *text*.

        Raises:
            ValueError: If *stage* is not ``Step.BASE``.
        """
        if stage != EmbeddingProvider.Step.BASE:
            raise ValueError(
                "OllamaEmbeddingProvider only supports the base embedding stage."
            )

        task = {
            "texts": [text],
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        result = self._execute_with_backoff_sync(task)
        return result[0]

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Asynchronously embed a list of *texts* (order preserved).

        Raises:
            ValueError: If *stage* is not ``Step.BASE``.
        """
        if stage != EmbeddingProvider.Step.BASE:
            raise ValueError(
                "OllamaEmbeddingProvider only supports the base embedding stage."
            )

        task = {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        return await self._execute_with_backoff_async(task)

    def get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Synchronously embed a list of *texts* (order preserved).

        Raises:
            ValueError: If *stage* is not ``Step.BASE``.
        """
        if stage != EmbeddingProvider.Step.BASE:
            raise ValueError(
                "OllamaEmbeddingProvider only supports the base embedding stage."
            )

        task = {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        return self._execute_with_backoff_sync(task)

    def rerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
        limit: int = 10,
    ) -> list[ChunkSearchResult]:
        """No-op rerank: return the first *limit* results unchanged."""
        return results[:limit]

    async def arerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
        limit: int = 10,
    ):
        """Async no-op rerank: return the first *limit* results unchanged."""
        return results[:limit]