diff options
Diffstat (limited to 'R2R/r2r/base/pipeline/base_pipeline.py')
-rwxr-xr-x | R2R/r2r/base/pipeline/base_pipeline.py | 233 |
1 file changed, 233 insertions, 0 deletions
"""Base pipeline class for running a sequence of pipes."""

import asyncio
import logging
from enum import Enum
from typing import Any, AsyncGenerator, Optional

from ..logging.kv_logger import KVLoggingSingleton
from ..logging.run_manager import RunManager, manage_run
from ..pipes.base_pipe import AsyncPipe, AsyncState

logger = logging.getLogger(__name__)


class PipelineTypes(Enum):
    """Closed set of recognized pipeline categories."""

    EVAL = "eval"
    INGESTION = "ingestion"
    SEARCH = "search"
    RAG = "rag"
    OTHER = "other"


class AsyncPipeline:
    """Pipeline class for running a sequence of pipes.

    Pipes are chained lazily as async generators: each pipe's output
    generator is stored in an ``asyncio.Future`` keyed by the pipe's config
    name, so downstream pipes can look up upstream outputs by name.
    """

    # Subclasses override this with one of the PipelineTypes values.
    pipeline_type: str = "other"

    def __init__(
        self,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        run_manager: Optional[RunManager] = None,
    ):
        self.pipes: list[AsyncPipe] = []
        # Parallel to `self.pipes`: upstream-output specs for each pipe.
        self.upstream_outputs: list[list[dict[str, str]]] = []
        self.pipe_logger = pipe_logger or KVLoggingSingleton()
        self.run_manager = run_manager or RunManager(self.pipe_logger)
        # Maps pipe config name -> Future resolving to that pipe's output
        # generator; populated during `run`.
        self.futures: dict[str, asyncio.Future] = {}
        self.level = 0

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        *args,
        **kwargs,
    ) -> None:
        """Add a pipe to the pipeline.

        Args:
            pipe: The pipe to append to the execution sequence.
            add_upstream_outputs: Optional specs describing which fields of
                earlier pipes' outputs to feed into this pipe's input. Each
                spec requires `prev_pipe_name`, `prev_output_field`, and
                `input_field` keys (validated in `_run_pipe`).
        """
        self.pipes.append(pipe)
        self.upstream_outputs.append(add_upstream_outputs or [])

    async def run(
        self,
        input: Any,
        state: Optional[AsyncState] = None,
        stream: bool = False,
        run_manager: Optional[RunManager] = None,
        log_run_info: bool = True,
        *args: Any,
        **kwargs: Any,
    ):
        """Run the pipeline.

        Args:
            input: Input to the first pipe (typically an async generator).
            state: Shared state passed to every pipe; created if omitted.
            stream: If True, return the final pipe's async generator
                un-consumed; otherwise drain it and return a list.
            run_manager: Overrides the pipeline's default run manager.
            log_run_info: Whether to log the pipeline type for this run.

        Raises:
            ValueError: If `pipeline_type` is not a valid PipelineTypes value.
        """
        run_manager = run_manager or self.run_manager

        try:
            PipelineTypes(self.pipeline_type)
        except ValueError as exc:
            raise ValueError(
                f"Invalid pipeline type: {self.pipeline_type}, must be one of {PipelineTypes.__members__.keys()}"
            ) from exc

        self.state = state or AsyncState()
        current_input = input
        async with manage_run(run_manager, self.pipeline_type):
            if log_run_info:
                await run_manager.log_run_info(
                    key="pipeline_type",
                    value=self.pipeline_type,
                    is_info_log=True,
                )
            try:
                for pipe_num, pipe in enumerate(self.pipes):
                    config_name = pipe.config.name
                    self.futures[config_name] = asyncio.Future()

                    # _run_pipe returns an async generator; nothing executes
                    # until it is consumed downstream.
                    current_input = self._run_pipe(
                        pipe_num,
                        current_input,
                        run_manager,
                        *args,
                        **kwargs,
                    )
                    # Resolve the future with the (unconsumed) generator so
                    # later pipes can replay this pipe's output by name.
                    self.futures[config_name].set_result(current_input)
                if not stream:
                    # Drain inside the managed run so run cleanup happens
                    # only after every pipe has finished.
                    return await self._consume_all(current_input)
                # NOTE(review): in streaming mode the generator is returned
                # before it is consumed, so `manage_run` exits before the
                # pipes actually execute — confirm this is intended.
                return current_input
            except Exception as error:
                logger.error("Pipeline failed with error: %s", error)
                raise

    async def _consume_all(self, gen: AsyncGenerator) -> list[Any]:
        """Exhaust `gen`, recursively flattening nested async generators."""
        result: list[Any] = []
        async for item in gen:
            if hasattr(item, "__aiter__"):
                # The item is itself an async generator; flatten it.
                result.extend(await self._consume_all(item))
            else:
                result.append(item)
        return result

    async def _run_pipe(
        self,
        pipe_num: int,
        input: Any,
        run_manager: RunManager,
        *args: Any,
        **kwargs: Any,
    ):
        """Yield the output of pipe `pipe_num`, wiring in upstream outputs.

        Waits on the futures of any upstream pipes referenced by this pipe's
        upstream-output specs, injects the requested fields into the pipe's
        input, then streams the pipe's results.
        """
        pipe = self.pipes[pipe_num]
        add_upstream_outputs = self.sort_upstream_outputs(
            self.upstream_outputs[pipe_num]
        )
        input_dict = {"message": input}

        # Group upstream-output specs by the pipe that produced them.
        grouped_upstream_outputs: dict[str, list[dict[str, str]]] = {}
        for upstream_input in add_upstream_outputs:
            grouped_upstream_outputs.setdefault(
                upstream_input["prev_pipe_name"], []
            ).append(upstream_input)

        async def resolve_future_output(future):
            # The future resolves to an async generator; drain it so its
            # items can be replayed below.
            result = future.result()
            return [item async for item in result]

        async def replay_items_as_async_gen(items):
            for item in items:
                yield item

        for (
            upstream_pipe_name,
            upstream_inputs,
        ) in grouped_upstream_outputs.items():
            temp_results = await resolve_future_output(
                self.futures[upstream_pipe_name]
            )
            # The upstream generator was just exhausted; if it belonged to
            # the immediately preceding pipe, replay its items as this
            # pipe's message stream.
            # NOTE(review): for pipe_num == 0 this compares against
            # self.pipes[-1]; presumably the first pipe never declares
            # upstream outputs — confirm.
            if upstream_pipe_name == self.pipes[pipe_num - 1].config.name:
                input_dict["message"] = replay_items_as_async_gen(
                    temp_results
                )

            for upstream_input in upstream_inputs:
                outputs = await self.state.get(upstream_pipe_name, "output")
                prev_output_field = upstream_input.get(
                    "prev_output_field", None
                )
                if not prev_output_field:
                    raise ValueError(
                        "`prev_output_field` must be specified in the upstream_input"
                    )
                input_dict[upstream_input["input_field"]] = outputs[
                    prev_output_field
                ]

        # Delegate to the pipe itself and stream its results.
        async for ele in await pipe.run(
            pipe.Input(**input_dict),
            self.state,
            run_manager,
            *args,
            **kwargs,
        ):
            yield ele

    def sort_upstream_outputs(
        self, add_upstream_outputs: list[dict[str, str]]
    ) -> list[dict[str, str]]:
        """Sort upstream-output specs by producing pipe, latest pipe first."""
        pipe_name_to_index = {
            pipe.config.name: index for index, pipe in enumerate(self.pipes)
        }

        def get_pipe_index(upstream_output):
            return pipe_name_to_index[upstream_output["prev_pipe_name"]]

        return sorted(add_upstream_outputs, key=get_pipe_index, reverse=True)


class EvalPipeline(AsyncPipeline):
    """A pipeline for evaluation."""

    pipeline_type: str = "eval"

    async def run(
        self,
        input: Any,
        state: Optional[AsyncState] = None,
        stream: bool = False,
        run_manager: Optional[RunManager] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Run the evaluation pipeline (delegates to the base pipeline)."""
        return await super().run(
            input, state, stream, run_manager, *args, **kwargs
        )

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        *args,
        **kwargs,
    ) -> None:
        """Add a pipe, logging the addition for debugging."""
        logger.debug(f"Adding pipe {pipe.config.name} to the EvalPipeline")
        return super().add_pipe(pipe, add_upstream_outputs, *args, **kwargs)


async def dequeue_requests(queue: asyncio.Queue) -> AsyncGenerator:
    """Create an async generator to dequeue requests.

    Yields items from `queue` until a `None` sentinel is received.
    """
    while True:
        request = await queue.get()
        if request is None:
            break
        yield request