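"""Retrieval service for R2R: exposes vector/knowledge-graph search,
retrieval-augmented generation (RAG), and evaluation over the assembled
pipelines."""
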
import logging
import time
import uuid
from typing import Optional

from r2r.base import (
    GenerationConfig,
    KGSearchSettings,
    KVLoggingSingleton,
    R2RException,
    RunManager,
    VectorSearchSettings,
    manage_run,
    to_async_generator,
)
from r2r.pipes import EvalPipe
from r2r.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RPipelines, R2RProviders
from ..assembly.config import R2RConfig
from .base import Service

logger = logging.getLogger(__name__)


class RetrievalService(Service):
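    """Service layer for the retrieval side of R2R: vector and knowledge-graph
    search, optionally streaming RAG, and evaluation, with per-run latency
    logging through the KV logging connection."""
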
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipelines: R2RPipelines,
        run_manager: RunManager,
        logging_connection: KVLoggingSingleton,
    ):
        super().__init__(
            config, providers, pipelines, run_manager, logging_connection
        )

    @telemetry_event("Search")
    async def search(
        self,
        query: str,
        vector_search_settings: Optional[VectorSearchSettings] = None,
        kg_search_settings: Optional[KGSearchSettings] = None,
        *args,
        **kwargs,
    ):
        """Run vector and/or knowledge-graph search for `query` and return
        the search pipeline results as a dict."""
        # Instantiate defaults per call: these settings objects are mutated
        # below (their search filters are rewritten in place), so sharing a
        # single default instance across calls would leak state.
        vector_search_settings = vector_search_settings or VectorSearchSettings()
        kg_search_settings = kg_search_settings or KGSearchSettings()

        async with manage_run(self.run_manager, "search_app") as run_id:
            t0 = time.time()

            if (
                kg_search_settings.use_kg_search
                and self.config.kg.provider is None
            ):
                raise R2RException(
                    status_code=400,
                    message="Knowledge Graph search is not enabled in the configuration.",
                )

            if (
                vector_search_settings.use_vector_search
                and self.config.vector_database.provider is None
            ):
                raise R2RException(
                    status_code=400,
                    message="Vector search is not enabled in the configuration.",
                )

            # TODO - Remove these transforms once we have a better way to handle this
            for key, value in vector_search_settings.search_filters.items():
                if isinstance(value, uuid.UUID):
                    # Reassigning an existing key during iteration is safe:
                    # the dict's size does not change.
                    vector_search_settings.search_filters[key] = str(value)

            results = await self.pipelines.search_pipeline.run(
                *args,
                input=to_async_generator([query]),
                vector_search_settings=vector_search_settings,
                kg_search_settings=kg_search_settings,
                run_manager=self.run_manager,
                **kwargs,
            )

            t1 = time.time()
            latency = f"{t1 - t0:.2f}"

            await self.logging_connection.log(
                log_id=run_id,
                key="search_latency",
                value=latency,
                is_info_log=False,
            )

            return results.dict()

    @telemetry_event("RAG")
    async def rag(
        self,
        query: str,
        rag_generation_config: GenerationConfig,
        vector_search_settings: Optional[VectorSearchSettings] = None,
        kg_search_settings: Optional[KGSearchSettings] = None,
        *args,
        **kwargs,
    ):
        """Run retrieval-augmented generation for `query`. Returns the
        completed result, or an async generator of chunks when
        `rag_generation_config.stream` is set."""
        # As in `search`, instantiate settings defaults per call rather than
        # sharing mutable default instances across requests.
        vector_search_settings = vector_search_settings or VectorSearchSettings()
        kg_search_settings = kg_search_settings or KGSearchSettings()

        async with manage_run(self.run_manager, "rag_app") as run_id:
            try:
                t0 = time.time()

                # TODO - Remove these transforms once we have a better way to handle this
                for key, value in vector_search_settings.search_filters.items():
                    if isinstance(value, uuid.UUID):
                        vector_search_settings.search_filters[key] = str(value)

                if rag_generation_config.stream:
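                    # Logged before the stream is consumed, so this captures
                    # setup latency only, not end-to-end generation time.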
                    t1 = time.time()
                    latency = f"{t1 - t0:.2f}"

                    await self.logging_connection.log(
                        log_id=run_id,
                        key="rag_generation_latency",
                        value=latency,
                        is_info_log=False,
                    )

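                    # The outer "rag_app" run context closes when this method
                    # returns, so the generator opens its own run ("arag")
                    # for the duration of the stream.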
                    async def stream_response():
                        async with manage_run(self.run_manager, "arag"):
                            async for (
                                chunk
                            ) in await self.pipelines.streaming_rag_pipeline.run(
                                input=to_async_generator([query]),
                                run_manager=self.run_manager,
                                vector_search_settings=vector_search_settings,
                                kg_search_settings=kg_search_settings,
                                rag_generation_config=rag_generation_config,
                            ):
                                yield chunk

                    return stream_response()

                results = await self.pipelines.rag_pipeline.run(
                    *args,
                    input=to_async_generator([query]),
                    run_manager=self.run_manager,
                    vector_search_settings=vector_search_settings,
                    kg_search_settings=kg_search_settings,
                    rag_generation_config=rag_generation_config,
                    **kwargs,
                )

                t1 = time.time()
                latency = f"{t1 - t0:.2f}"

                await self.logging_connection.log(
                    log_id=run_id,
                    key="rag_generation_latency",
                    value=latency,
                    is_info_log=False,
                )

                if len(results) == 0:
                    raise R2RException(
                        status_code=404, message="No results found"
                    )
                if len(results) > 1:
                    logger.warning(
                        f"Multiple results found for query: {query}"
                    )
                # unpack the first result
                return results[0]

            except R2RException:
                # Let intentional HTTP errors (e.g. the 404 above) propagate
                # rather than being remapped to a 500 below.
                raise
            except Exception as e:
                logger.error(f"Pipeline error: {str(e)}")
                if "NoneType" in str(e):
                    raise R2RException(
                        status_code=502,
                        message="Ollama server not reachable or returned an invalid response",
                    ) from e
                raise R2RException(
                    status_code=500, message="Internal Server Error"
                ) from e

    @telemetry_event("Evaluate")
    async def evaluate(
        self,
        query: str,
        context: str,
        completion: str,
        eval_generation_config: Optional[GenerationConfig] = None,
        *args,
        **kwargs,
    ):
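        """Score a (query, context, completion) triple with the eval
        pipeline."""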
        eval_payload = EvalPipe.EvalPayload(
            query=query,
            context=context,
            completion=completion,
        )
        result = await self.pipelines.eval_pipeline.run(
            input=to_async_generator([eval_payload]),
            run_manager=self.run_manager,
            eval_generation_config=eval_generation_config,
        )
        return result
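

# Example usage (a minimal sketch, assuming `config`, `providers`,
# `pipelines`, `run_manager`, and `kv_logger` come from an assembled R2R
# application; the query strings and model name are illustrative):
#
#     service = RetrievalService(
#         config, providers, pipelines, run_manager, kv_logger
#     )
#     search_results = await service.search("What is a knowledge graph?")
#     rag_result = await service.rag(
#         "Summarize the findings on knowledge graphs.",
#         rag_generation_config=GenerationConfig(model="gpt-4o"),
#     )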