diff options
Diffstat (limited to 'R2R/r2r/pipelines/search_pipeline.py')
-rwxr-xr-x | R2R/r2r/pipelines/search_pipeline.py | 140 |
1 files changed, 140 insertions, 0 deletions
diff --git a/R2R/r2r/pipelines/search_pipeline.py b/R2R/r2r/pipelines/search_pipeline.py new file mode 100755 index 00000000..25e0c7bb --- /dev/null +++ b/R2R/r2r/pipelines/search_pipeline.py @@ -0,0 +1,140 @@ +import asyncio +import logging +from asyncio import Queue +from typing import Any, Optional + +from ..base.abstractions.search import ( + AggregateSearchResult, + KGSearchSettings, + VectorSearchSettings, +) +from ..base.logging.kv_logger import KVLoggingSingleton +from ..base.logging.run_manager import RunManager, manage_run +from ..base.pipeline.base_pipeline import AsyncPipeline, dequeue_requests +from ..base.pipes.base_pipe import AsyncPipe, AsyncState + +logger = logging.getLogger(__name__) + + +class SearchPipeline(AsyncPipeline): + """A pipeline for search.""" + + pipeline_type: str = "search" + + def __init__( + self, + pipe_logger: Optional[KVLoggingSingleton] = None, + run_manager: Optional[RunManager] = None, + ): + super().__init__(pipe_logger, run_manager) + self._parsing_pipe = None + self._vector_search_pipeline = None + self._kg_search_pipeline = None + + async def run( + self, + input: Any, + state: Optional[AsyncState] = None, + stream: bool = False, + run_manager: Optional[RunManager] = None, + log_run_info: bool = True, + vector_search_settings: VectorSearchSettings = VectorSearchSettings(), + kg_search_settings: KGSearchSettings = KGSearchSettings(), + *args: Any, + **kwargs: Any, + ): + self.state = state or AsyncState() + do_vector_search = ( + self._vector_search_pipeline is not None + and vector_search_settings.use_vector_search + ) + do_kg = ( + self._kg_search_pipeline is not None + and kg_search_settings.use_kg_search + ) + async with manage_run(run_manager, self.pipeline_type): + if log_run_info: + await run_manager.log_run_info( + key="pipeline_type", + value=self.pipeline_type, + is_info_log=True, + ) + + vector_search_queue = Queue() + kg_queue = Queue() + + async def enqueue_requests(): + async for message in input: + if do_vector_search: + await vector_search_queue.put(message) + if do_kg: + await kg_queue.put(message) + + await vector_search_queue.put(None) + await kg_queue.put(None) + + # Start the document enqueuing process + enqueue_task = asyncio.create_task(enqueue_requests()) + + # Start the embedding and KG pipelines in parallel + if do_vector_search: + vector_search_task = asyncio.create_task( + self._vector_search_pipeline.run( + dequeue_requests(vector_search_queue), + state, + stream, + run_manager, + log_run_info=False, + vector_search_settings=vector_search_settings, + ) + ) + + if do_kg: + kg_task = asyncio.create_task( + self._kg_search_pipeline.run( + dequeue_requests(kg_queue), + state, + stream, + run_manager, + log_run_info=False, + kg_search_settings=kg_search_settings, + ) + ) + + await enqueue_task + + vector_search_results = ( + await vector_search_task if do_vector_search else None + ) + kg_results = await kg_task if do_kg else None + + return AggregateSearchResult( + vector_search_results=vector_search_results, + kg_search_results=kg_results, + ) + + def add_pipe( + self, + pipe: AsyncPipe, + add_upstream_outputs: Optional[list[dict[str, str]]] = None, + kg_pipe: bool = False, + vector_search_pipe: bool = False, + *args, + **kwargs, + ) -> None: + logger.debug(f"Adding pipe {pipe.config.name} to the SearchPipeline") + + if kg_pipe: + if not self._kg_search_pipeline: + self._kg_search_pipeline = AsyncPipeline() + self._kg_search_pipeline.add_pipe( + pipe, add_upstream_outputs, *args, **kwargs + ) + elif vector_search_pipe: + if not self._vector_search_pipeline: + self._vector_search_pipeline = AsyncPipeline() + self._vector_search_pipeline.add_pipe( + pipe, add_upstream_outputs, *args, **kwargs + ) + else: + raise ValueError("Pipe must be a vector search or KG pipe") |