From 4a52a71956a8d46fcb7294ac71734504bb09bcc2 Mon Sep 17 00:00:00 2001 From: S. Solomon Darnell Date: Fri, 28 Mar 2025 21:52:21 -0500 Subject: two version of R2R are here --- R2R/r2r/pipelines/rag_pipeline.py | 115 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100755 R2R/r2r/pipelines/rag_pipeline.py (limited to 'R2R/r2r/pipelines/rag_pipeline.py') diff --git a/R2R/r2r/pipelines/rag_pipeline.py b/R2R/r2r/pipelines/rag_pipeline.py new file mode 100755 index 00000000..b257ccaa --- /dev/null +++ b/R2R/r2r/pipelines/rag_pipeline.py @@ -0,0 +1,115 @@ +import asyncio +import logging +from typing import Any, Optional + +from ..base.abstractions.llm import GenerationConfig +from ..base.abstractions.search import KGSearchSettings, VectorSearchSettings +from ..base.logging.kv_logger import KVLoggingSingleton +from ..base.logging.run_manager import RunManager, manage_run +from ..base.pipeline.base_pipeline import AsyncPipeline +from ..base.pipes.base_pipe import AsyncPipe, AsyncState +from ..base.utils import to_async_generator + +logger = logging.getLogger(__name__) + + +class RAGPipeline(AsyncPipeline): + """A pipeline for RAG.""" + + pipeline_type: str = "rag" + + def __init__( + self, + pipe_logger: Optional[KVLoggingSingleton] = None, + run_manager: Optional[RunManager] = None, + ): + super().__init__(pipe_logger, run_manager) + self._search_pipeline = None + self._rag_pipeline = None + + async def run( + self, + input: Any, + state: Optional[AsyncState] = None, + run_manager: Optional[RunManager] = None, + log_run_info=True, + vector_search_settings: VectorSearchSettings = VectorSearchSettings(), + kg_search_settings: KGSearchSettings = KGSearchSettings(), + rag_generation_config: GenerationConfig = GenerationConfig(), + *args: Any, + **kwargs: Any, + ): + self.state = state or AsyncState() + async with manage_run(run_manager, self.pipeline_type): + if log_run_info: + await run_manager.log_run_info( + key="pipeline_type", + value=self.pipeline_type, + is_info_log=True, + ) + + if not self._search_pipeline: + raise ValueError( + "_search_pipeline must be set before running the RAG pipeline" + ) + + async def multi_query_generator(input): + tasks = [] + async for query in input: + task = asyncio.create_task( + self._search_pipeline.run( + to_async_generator([query]), + state=state, + stream=False, # do not stream the search results + run_manager=run_manager, + log_run_info=False, # do not log the run info as it is already logged above + vector_search_settings=vector_search_settings, + kg_search_settings=kg_search_settings, + *args, + **kwargs, + ) + ) + tasks.append((query, task)) + + for query, task in tasks: + yield (query, await task) + + rag_results = await self._rag_pipeline.run( + input=multi_query_generator(input), + state=state, + stream=rag_generation_config.stream, + run_manager=run_manager, + log_run_info=False, + rag_generation_config=rag_generation_config, + *args, + **kwargs, + ) + return rag_results + + def add_pipe( + self, + pipe: AsyncPipe, + add_upstream_outputs: Optional[list[dict[str, str]]] = None, + rag_pipe: bool = True, + *args, + **kwargs, + ) -> None: + logger.debug(f"Adding pipe {pipe.config.name} to the RAGPipeline") + if not rag_pipe: + raise ValueError( + "Only pipes that are part of the RAG pipeline can be added to the RAG pipeline" + ) + if not self._rag_pipeline: + self._rag_pipeline = AsyncPipeline() + self._rag_pipeline.add_pipe( + pipe, add_upstream_outputs, *args, **kwargs + ) + + def set_search_pipeline( + self, + _search_pipeline: AsyncPipeline, + *args, + **kwargs, + ) -> None: + logger.debug(f"Setting search pipeline for the RAGPipeline") + self._search_pipeline = _search_pipeline -- cgit v1.2.3