aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/pipelines/rag_pipeline.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/pipelines/rag_pipeline.py')
-rwxr-xr-xR2R/r2r/pipelines/rag_pipeline.py115
1 files changed, 115 insertions, 0 deletions
diff --git a/R2R/r2r/pipelines/rag_pipeline.py b/R2R/r2r/pipelines/rag_pipeline.py
new file mode 100755
index 00000000..b257ccaa
--- /dev/null
+++ b/R2R/r2r/pipelines/rag_pipeline.py
@@ -0,0 +1,115 @@
+import asyncio
+import logging
+from typing import Any, Optional
+
+from ..base.abstractions.llm import GenerationConfig
+from ..base.abstractions.search import KGSearchSettings, VectorSearchSettings
+from ..base.logging.kv_logger import KVLoggingSingleton
+from ..base.logging.run_manager import RunManager, manage_run
+from ..base.pipeline.base_pipeline import AsyncPipeline
+from ..base.pipes.base_pipe import AsyncPipe, AsyncState
+from ..base.utils import to_async_generator
+
+logger = logging.getLogger(__name__)
+
+
+class RAGPipeline(AsyncPipeline):
+ """A pipeline for RAG."""
+
+ pipeline_type: str = "rag"
+
+ def __init__(
+ self,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ run_manager: Optional[RunManager] = None,
+ ):
+ super().__init__(pipe_logger, run_manager)
+ self._search_pipeline = None
+ self._rag_pipeline = None
+
+ async def run(
+ self,
+ input: Any,
+ state: Optional[AsyncState] = None,
+ run_manager: Optional[RunManager] = None,
+ log_run_info=True,
+ vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+ kg_search_settings: KGSearchSettings = KGSearchSettings(),
+ rag_generation_config: GenerationConfig = GenerationConfig(),
+ *args: Any,
+ **kwargs: Any,
+ ):
+ self.state = state or AsyncState()
+ async with manage_run(run_manager, self.pipeline_type):
+ if log_run_info:
+ await run_manager.log_run_info(
+ key="pipeline_type",
+ value=self.pipeline_type,
+ is_info_log=True,
+ )
+
+ if not self._search_pipeline:
+ raise ValueError(
+ "_search_pipeline must be set before running the RAG pipeline"
+ )
+
+ async def multi_query_generator(input):
+ tasks = []
+ async for query in input:
+ task = asyncio.create_task(
+ self._search_pipeline.run(
+ to_async_generator([query]),
+ state=state,
+ stream=False, # do not stream the search results
+ run_manager=run_manager,
+ log_run_info=False, # do not log the run info as it is already logged above
+ vector_search_settings=vector_search_settings,
+ kg_search_settings=kg_search_settings,
+ *args,
+ **kwargs,
+ )
+ )
+ tasks.append((query, task))
+
+ for query, task in tasks:
+ yield (query, await task)
+
+ rag_results = await self._rag_pipeline.run(
+ input=multi_query_generator(input),
+ state=state,
+ stream=rag_generation_config.stream,
+ run_manager=run_manager,
+ log_run_info=False,
+ rag_generation_config=rag_generation_config,
+ *args,
+ **kwargs,
+ )
+ return rag_results
+
+ def add_pipe(
+ self,
+ pipe: AsyncPipe,
+ add_upstream_outputs: Optional[list[dict[str, str]]] = None,
+ rag_pipe: bool = True,
+ *args,
+ **kwargs,
+ ) -> None:
+ logger.debug(f"Adding pipe {pipe.config.name} to the RAGPipeline")
+ if not rag_pipe:
+ raise ValueError(
+ "Only pipes that are part of the RAG pipeline can be added to the RAG pipeline"
+ )
+ if not self._rag_pipeline:
+ self._rag_pipeline = AsyncPipeline()
+ self._rag_pipeline.add_pipe(
+ pipe, add_upstream_outputs, *args, **kwargs
+ )
+
+ def set_search_pipeline(
+ self,
+ _search_pipeline: AsyncPipeline,
+ *args,
+ **kwargs,
+ ) -> None:
+ logger.debug(f"Setting search pipeline for the RAGPipeline")
+ self._search_pipeline = _search_pipeline