import asyncio
import logging
from typing import Any, Optional

from ..base.abstractions.llm import GenerationConfig
from ..base.abstractions.search import KGSearchSettings, VectorSearchSettings
from ..base.logging.kv_logger import KVLoggingSingleton
from ..base.logging.run_manager import RunManager, manage_run
from ..base.pipeline.base_pipeline import AsyncPipeline
from ..base.pipes.base_pipe import AsyncPipe, AsyncState
from ..base.utils import to_async_generator

logger = logging.getLogger(__name__)


class RAGPipeline(AsyncPipeline):
    """A pipeline for RAG."""

    pipeline_type: str = "rag"

    def __init__(
        self,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        run_manager: Optional[RunManager] = None,
    ):
        super().__init__(pipe_logger, run_manager)
        self._search_pipeline = None  # runs vector / knowledge-graph search per query
        self._rag_pipeline = None  # consumes (query, results) pairs and generates

    async def run(
        self,
        input: Any,
        state: Optional[AsyncState] = None,
        run_manager: Optional[RunManager] = None,
        log_run_info: bool = True,
        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
        kg_search_settings: KGSearchSettings = KGSearchSettings(),
        rag_generation_config: GenerationConfig = GenerationConfig(),
        *args: Any,
        **kwargs: Any,
    ):
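        """Execute the full RAG flow for every query yielded by `input`.

        `input` is expected to be an async iterable of queries; each query
        is run through the search pipeline, and the resulting
        (query, search results) pairs are passed to the RAG pipeline
        for generation.
        """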
        self.state = state or AsyncState()
        # Fall back to the run manager configured at construction time so
        # the log_run_info call below cannot dereference a None manager.
        run_manager = run_manager or self.run_manager
        async with manage_run(run_manager, self.pipeline_type):
            if log_run_info:
                await run_manager.log_run_info(
                    key="pipeline_type",
                    value=self.pipeline_type,
                    is_info_log=True,
                )

            if not self._search_pipeline:
                raise ValueError(
                    "_search_pipeline must be set before running the RAG pipeline"
                )
            if not self._rag_pipeline:
                raise ValueError(
                    "_rag_pipeline must be set before running the RAG pipeline"
                )

            async def multi_query_generator(input):
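                """Fan out: start one search-pipeline run per incoming
                query, then yield (query, search_results) pairs in the
                order the queries arrived."""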
                tasks = []
                async for query in input:
                    task = asyncio.create_task(
                        self._search_pipeline.run(
                            to_async_generator([query]),
                            *args,
                            state=state,
                            stream=False,  # do not stream the search results
                            run_manager=run_manager,
                            log_run_info=False,  # run info was already logged above
                            vector_search_settings=vector_search_settings,
                            kg_search_settings=kg_search_settings,
                            **kwargs,
                        )
                    )
                    tasks.append((query, task))

                for query, task in tasks:
                    yield (query, await task)

            rag_results = await self._rag_pipeline.run(
                *args,
                input=multi_query_generator(input),
                state=state,
                stream=rag_generation_config.stream,
                run_manager=run_manager,
                log_run_info=False,
                rag_generation_config=rag_generation_config,
                **kwargs,
            )
            return rag_results

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        rag_pipe: bool = True,
        *args,
        **kwargs,
    ) -> None:
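        """Register `pipe` with the underlying RAG pipeline, creating
        that pipeline on first use. Only RAG pipes are accepted here;
        the search pipeline is attached via `set_search_pipeline`.
        """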
        logger.debug(f"Adding pipe {pipe.config.name} to the RAGPipeline")
        if not rag_pipe:
            raise ValueError(
                "Only RAG pipes may be added to the RAGPipeline; attach a "
                "search pipeline with `set_search_pipeline` instead."
            )
        if not self._rag_pipeline:
            self._rag_pipeline = AsyncPipeline()
        self._rag_pipeline.add_pipe(
            pipe, add_upstream_outputs, *args, **kwargs
        )

    def set_search_pipeline(
        self,
        _search_pipeline: AsyncPipeline,
        *args,
        **kwargs,
    ) -> None:
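        """Set the search pipeline that `run` fans each query out to."""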
        logger.debug(f"Setting search pipeline for the RAGPipeline")
        self._search_pipeline = _search_pipeline
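

# Example wiring (a minimal sketch, not part of the module API): the concrete
# `search_pipeline` and `rag_pipe` objects below are hypothetical stand-ins
# for whatever pipes the application actually constructs.
#
#     pipeline = RAGPipeline()
#     pipeline.set_search_pipeline(search_pipeline)
#     pipeline.add_pipe(rag_pipe)
#     results = await pipeline.run(
#         to_async_generator(["What does this repository do?"]),
#         rag_generation_config=GenerationConfig(stream=False),
#     )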