Diffstat (limited to 'R2R/r2r/base/pipeline/base_pipeline.py')
-rwxr-xr-x  R2R/r2r/base/pipeline/base_pipeline.py  233
1 file changed, 233 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/base/pipeline/base_pipeline.py b/R2R/r2r/base/pipeline/base_pipeline.py
new file mode 100755
index 00000000..3c1eff9a
--- /dev/null
+++ b/R2R/r2r/base/pipeline/base_pipeline.py
@@ -0,0 +1,233 @@
+"""Base pipeline class for running a sequence of pipes."""
+
+import asyncio
+import logging
+from enum import Enum
+from typing import Any, AsyncGenerator, Optional
+
+from ..logging.kv_logger import KVLoggingSingleton
+from ..logging.run_manager import RunManager, manage_run
+from ..pipes.base_pipe import AsyncPipe, AsyncState
+
+logger = logging.getLogger(__name__)
+
+
+class PipelineTypes(Enum):
+ EVAL = "eval"
+ INGESTION = "ingestion"
+ SEARCH = "search"
+ RAG = "rag"
+ OTHER = "other"
+
+
+class AsyncPipeline:
+ """Pipeline class for running a sequence of pipes."""
+
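+    # Must correspond to one of the PipelineTypes values; validated at the start of run().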
+ pipeline_type: str = "other"
+
+ def __init__(
+ self,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ run_manager: Optional[RunManager] = None,
+ ):
+ self.pipes: list[AsyncPipe] = []
+ self.upstream_outputs: list[list[dict[str, str]]] = []
+ self.pipe_logger = pipe_logger or KVLoggingSingleton()
+ self.run_manager = run_manager or RunManager(self.pipe_logger)
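+        # Futures are keyed by pipe config name; each one resolves to that pipe's output generator.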
+ self.futures = {}
+ self.level = 0
+
+ def add_pipe(
+ self,
+ pipe: AsyncPipe,
+ add_upstream_outputs: Optional[list[dict[str, str]]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ """Add a pipe to the pipeline."""
+ self.pipes.append(pipe)
+ if not add_upstream_outputs:
+ add_upstream_outputs = []
+ self.upstream_outputs.append(add_upstream_outputs)
+
+ async def run(
+ self,
+ input: Any,
+ state: Optional[AsyncState] = None,
+ stream: bool = False,
+ run_manager: Optional[RunManager] = None,
+ log_run_info: bool = True,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ """Run the pipeline."""
+ run_manager = run_manager or self.run_manager
+
+ try:
+ PipelineTypes(self.pipeline_type)
+ except ValueError:
+ raise ValueError(
+ f"Invalid pipeline type: {self.pipeline_type}, must be one of {PipelineTypes.__members__.keys()}"
+ )
+
+ self.state = state or AsyncState()
+ current_input = input
+ async with manage_run(run_manager, self.pipeline_type):
+ if log_run_info:
+ await run_manager.log_run_info(
+ key="pipeline_type",
+ value=self.pipeline_type,
+ is_info_log=True,
+ )
+ try:
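+                # Chain the pipes: each pipe's output generator becomes the next pipe's input.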
+ for pipe_num in range(len(self.pipes)):
+ config_name = self.pipes[pipe_num].config.name
+ self.futures[config_name] = asyncio.Future()
+
+ current_input = self._run_pipe(
+ pipe_num,
+ current_input,
+ run_manager,
+ *args,
+ **kwargs,
+ )
+ self.futures[config_name].set_result(current_input)
+ if not stream:
+ final_result = await self._consume_all(current_input)
+ return final_result
+ else:
+ return current_input
+ except Exception as error:
+ logger.error(f"Pipeline failed with error: {error}")
+ raise error
+
+ async def _consume_all(self, gen: AsyncGenerator) -> list[Any]:
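+        """Recursively consume an async generator, flattening nested generators into a flat list."""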
+ result = []
+ async for item in gen:
+ if hasattr(
+ item, "__aiter__"
+ ): # Check if the item is an async generator
+ sub_result = await self._consume_all(item)
+ result.extend(sub_result)
+ else:
+ result.append(item)
+ return result
+
+ async def _run_pipe(
+ self,
+ pipe_num: int,
+ input: Any,
+ run_manager: RunManager,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ # Collect inputs, waiting for the necessary futures
+ pipe = self.pipes[pipe_num]
+ add_upstream_outputs = self.sort_upstream_outputs(
+ self.upstream_outputs[pipe_num]
+ )
+ input_dict = {"message": input}
+
+ # Group upstream outputs by prev_pipe_name
+ grouped_upstream_outputs = {}
+ for upstream_input in add_upstream_outputs:
+ upstream_pipe_name = upstream_input["prev_pipe_name"]
+ if upstream_pipe_name not in grouped_upstream_outputs:
+ grouped_upstream_outputs[upstream_pipe_name] = []
+ grouped_upstream_outputs[upstream_pipe_name].append(upstream_input)
+
+ for (
+ upstream_pipe_name,
+ upstream_inputs,
+ ) in grouped_upstream_outputs.items():
+
+ async def resolve_future_output(future):
+ result = future.result()
+ # consume the async generator
+ return [item async for item in result]
+
+ async def replay_items_as_async_gen(items):
+ for item in items:
+ yield item
+
+ temp_results = await resolve_future_output(
+ self.futures[upstream_pipe_name]
+ )
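+            # When the upstream pipe is the immediately preceding pipe, replay its buffered output as this pipe's "message" input.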
+ if upstream_pipe_name == self.pipes[pipe_num - 1].config.name:
+ input_dict["message"] = replay_items_as_async_gen(temp_results)
+
+ for upstream_input in upstream_inputs:
+ outputs = await self.state.get(upstream_pipe_name, "output")
+ prev_output_field = upstream_input.get(
+ "prev_output_field", None
+ )
+ if not prev_output_field:
+ raise ValueError(
+ "`prev_output_field` must be specified in the upstream_input"
+ )
+ input_dict[upstream_input["input_field"]] = outputs[
+ prev_output_field
+ ]
+
+ # Handle the pipe generator
+ async for ele in await pipe.run(
+ pipe.Input(**input_dict),
+ self.state,
+ run_manager,
+ *args,
+ **kwargs,
+ ):
+ yield ele
+
+ def sort_upstream_outputs(
+ self, add_upstream_outputs: list[dict[str, str]]
+ ) -> list[dict[str, str]]:
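+        """Sort upstream outputs so that entries from later pipes come first."""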
+ pipe_name_to_index = {
+ pipe.config.name: index for index, pipe in enumerate(self.pipes)
+ }
+
+ def get_pipe_index(upstream_output):
+ return pipe_name_to_index[upstream_output["prev_pipe_name"]]
+
+ sorted_outputs = sorted(
+ add_upstream_outputs, key=get_pipe_index, reverse=True
+ )
+ return sorted_outputs
+
+
+class EvalPipeline(AsyncPipeline):
+ """A pipeline for evaluation."""
+
+ pipeline_type: str = "eval"
+
+ async def run(
+ self,
+ input: Any,
+ state: Optional[AsyncState] = None,
+ stream: bool = False,
+ run_manager: Optional[RunManager] = None,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ return await super().run(
+ input, state, stream, run_manager, *args, **kwargs
+ )
+
+ def add_pipe(
+ self,
+ pipe: AsyncPipe,
+ add_upstream_outputs: Optional[list[dict[str, str]]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ logger.debug(f"Adding pipe {pipe.config.name} to the EvalPipeline")
+ return super().add_pipe(pipe, add_upstream_outputs, *args, **kwargs)
+
+
+async def dequeue_requests(queue: asyncio.Queue) -> AsyncGenerator:
+ """Create an async generator to dequeue requests."""
+ while True:
+ request = await queue.get()
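+        # A None item is the sentinel that ends the stream.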
+ if request is None:
+ break
+ yield request
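
Illustrative usage sketch (not part of the commit): dequeue_requests can be driven from an
asyncio.Queue, with None acting as the terminating sentinel. The function name main and the
queue payload below are hypothetical.

    import asyncio

    async def main():
        queue: asyncio.Queue = asyncio.Queue()
        await queue.put({"query": "example request"})  # hypothetical payload
        await queue.put(None)  # sentinel: ends dequeue_requests
        async for request in dequeue_requests(queue):
            print(request)

    asyncio.run(main())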