# Ollama-backed embedding provider.
import asyncio
import logging
import os
import random
from typing import Any

from ollama import AsyncClient, Client

from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult

logger = logging.getLogger(__name__)


class OllamaEmbeddingProvider(EmbeddingProvider):
    """Embedding provider that delegates to an Ollama server.

    Synchronous calls go through ``ollama.Client``; asynchronous batch calls
    are pushed onto ``request_queue`` and drained by worker tasks running
    ``process_queue``, with per-request timeout, retry and backoff.

    Reranking is a pass-through (truncation only) and tokenization is not
    supported.
    """

    def __init__(self, config: EmbeddingConfig):
        """Validate ``config`` and create the sync/async Ollama clients.

        Raises:
            ValueError: if ``config.provider`` is missing or is not
                ``"ollama"``, or if ``config.rerank_model`` is set.
        """
        super().__init__(config)
        provider = config.provider
        if not provider:
            raise ValueError(
                "Must set provider in order to initialize `OllamaEmbeddingProvider`."
            )
        if provider != "ollama":
            raise ValueError(
                "OllamaEmbeddingProvider must be initialized with provider `ollama`."
            )
        if config.rerank_model:
            raise ValueError(
                "OllamaEmbeddingProvider does not support separate reranking."
            )

        self.base_model = config.base_model
        self.base_dimension = config.base_dimension
        # May be None; the client is then constructed with host=None.
        self.base_url = os.getenv("OLLAMA_API_BASE")
        logger.info(
            f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
        )
        self.client = Client(host=self.base_url)
        self.aclient = AsyncClient(host=self.base_url)

        # Shared by every async_get_embeddings call on this instance.
        self.request_queue = asyncio.Queue()
        # Total number of attempts per text (not "retries after the first").
        self.max_retries = 2
        # Backoff delays are in seconds and double after each failed attempt.
        self.initial_backoff = 1
        self.max_backoff = 60
        # Upper bound on in-flight embedding requests.
        self.concurrency_limit = 10
        self.semaphore = asyncio.Semaphore(self.concurrency_limit)

    async def process_queue(self):
        """Worker loop: take tasks off ``request_queue`` and resolve their futures.

        Each task is a dict with a ``"text"`` entry and a ``"future"`` entry.
        The loop never returns on its own; the owner cancels the worker task
        once the queue has been joined.
        """
        while True:
            task = await self.request_queue.get()
            future = task["future"]
            try:
                if future.done():
                    # The caller already gave up on this item (e.g. it was
                    # cancelled); skip the request instead of wasting it.
                    continue
                result = await self.execute_task_with_backoff(task)
                # Guard against the future having been cancelled while the
                # request was in flight; set_result would otherwise raise
                # InvalidStateError and kill this worker.
                if not future.done():
                    future.set_result(result)
            except Exception as e:
                if not future.done():
                    future.set_exception(e)
            finally:
                # Always balance get() so request_queue.join() can complete.
                self.request_queue.task_done()

    async def execute_task_with_backoff(self, task: dict[str, Any]):
        """Request one embedding, retrying failures with exponential backoff.

        Args:
            task: dict whose ``"text"`` entry is the prompt to embed.

        Returns:
            The ``"embedding"`` entry of the Ollama response.

        Raises:
            Exception: once ``max_retries`` attempts have failed; the last
                underlying error is attached as ``__cause__``.
        """
        retries = 0
        backoff = self.initial_backoff
        while retries < self.max_retries:
            try:
                # The semaphore caps concurrent requests across all workers.
                async with self.semaphore:
                    response = await asyncio.wait_for(
                        self.aclient.embeddings(
                            prompt=task["text"], model=self.base_model
                        ),
                        timeout=30,
                    )
                return response["embedding"]
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.max_retries:
                    # Chain the original error so its traceback is preserved.
                    raise Exception(
                        f"Max retries reached. Last error: {str(e)}"
                    ) from e
                # Jitter keeps concurrent retries from firing in lockstep.
                await asyncio.sleep(backoff + random.uniform(0, 1))
                backoff = min(backoff * 2, self.max_backoff)

    def get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
    ) -> list[float]:
        """Synchronously embed a single text with the base model.

        Raises:
            ValueError: if ``stage`` is not ``PipeStage.BASE``.
            Exception: any error from the Ollama client is logged and
                re-raised unchanged.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError(
                "OllamaEmbeddingProvider only supports search stage."
            )

        try:
            response = self.client.embeddings(
                prompt=text, model=self.base_model
            )
            return response["embedding"]
        except Exception as e:
            logger.error(f"Error getting embedding: {str(e)}")
            raise

    def get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
    ) -> list[list[float]]:
        """Synchronously embed each text in order, one request at a time."""
        return [self.get_embedding(text, stage) for text in texts]

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
    ) -> list[list[float]]:
        """Embed ``texts`` concurrently, returning results in input order.

        Up to ``concurrency_limit`` worker tasks drain the shared queue so
        that the semaphore's limit can actually be reached.

        Raises:
            ValueError: if ``stage`` is not ``PipeStage.BASE``.
            Exception: if any text failed to embed; the first underlying
                failure is attached as ``__cause__``.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError(
                "OllamaEmbeddingProvider only supports search stage."
            )

        if not texts:
            # Nothing to embed: avoid spawning workers for an empty batch.
            return []

        worker_count = max(1, min(self.concurrency_limit, len(texts)))
        workers = [
            asyncio.create_task(self.process_queue())
            for _ in range(worker_count)
        ]
        futures = []

        try:
            for text in texts:
                future = asyncio.Future()
                await self.request_queue.put({"text": text, "future": future})
                futures.append(future)

            results = await asyncio.gather(*futures, return_exceptions=True)
            # With return_exceptions=True a cancelled future shows up as a
            # CancelledError instance, which is a BaseException.
            failures = [r for r in results if isinstance(r, BaseException)]
            if failures:
                raise Exception(
                    "Embedding generation failed for one or more embeddings."
                ) from failures[0]
            return results
        except Exception as e:
            logger.error(f"Embedding generation failed: {str(e)}")
            raise
        finally:
            # Let the workers finish (or skip) everything still queued before
            # stopping them, so no item is left without a task_done().
            await self.request_queue.join()
            for worker in workers:
                worker.cancel()

    def rerank(
        self,
        query: str,
        results: list[VectorSearchResult],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
        limit: int = 10,
    ) -> list[VectorSearchResult]:
        """Return the first ``limit`` results unchanged; no reranking is done."""
        return results[:limit]

    def tokenize_string(
        self, text: str, model: str, stage: EmbeddingProvider.PipeStage
    ) -> list[int]:
        """Always raise ``NotImplementedError``; tokenization is unsupported."""
        raise NotImplementedError(
            "Tokenization is not supported by OllamaEmbeddingProvider."
        )