about summary refs log tree commit diff
path: root/R2R/r2r/base/pipeline
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/base/pipeline
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to 'R2R/r2r/base/pipeline')
-rwxr-xr-xR2R/r2r/base/pipeline/__init__.py0
-rwxr-xr-xR2R/r2r/base/pipeline/base_pipeline.py233
2 files changed, 233 insertions, 0 deletions
diff --git a/R2R/r2r/base/pipeline/__init__.py b/R2R/r2r/base/pipeline/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/base/pipeline/__init__.py
diff --git a/R2R/r2r/base/pipeline/base_pipeline.py b/R2R/r2r/base/pipeline/base_pipeline.py
new file mode 100755
index 00000000..3c1eff9a
--- /dev/null
+++ b/R2R/r2r/base/pipeline/base_pipeline.py
@@ -0,0 +1,233 @@
+"""Base pipeline class for running a sequence of pipes."""
+
+import asyncio
+import logging
+from enum import Enum
+from typing import Any, AsyncGenerator, Optional
+
+from ..logging.kv_logger import KVLoggingSingleton
+from ..logging.run_manager import RunManager, manage_run
+from ..pipes.base_pipe import AsyncPipe, AsyncState
+
+logger = logging.getLogger(__name__)
+
+
+class PipelineTypes(Enum):
+    EVAL = "eval"
+    INGESTION = "ingestion"
+    SEARCH = "search"
+    RAG = "rag"
+    OTHER = "other"
+
+
+class AsyncPipeline:
+    """Pipeline class for running a sequence of pipes."""
+
+    pipeline_type: str = "other"
+
+    def __init__(
+        self,
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        run_manager: Optional[RunManager] = None,
+    ):
+        self.pipes: list[AsyncPipe] = []
+        self.upstream_outputs: list[list[dict[str, str]]] = []
+        self.pipe_logger = pipe_logger or KVLoggingSingleton()
+        self.run_manager = run_manager or RunManager(self.pipe_logger)
+        self.futures = {}
+        self.level = 0
+
+    def add_pipe(
+        self,
+        pipe: AsyncPipe,
+        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        """Add a pipe to the pipeline."""
+        self.pipes.append(pipe)
+        if not add_upstream_outputs:
+            add_upstream_outputs = []
+        self.upstream_outputs.append(add_upstream_outputs)
+
+    async def run(
+        self,
+        input: Any,
+        state: Optional[AsyncState] = None,
+        stream: bool = False,
+        run_manager: Optional[RunManager] = None,
+        log_run_info: bool = True,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        """Run the pipeline."""
+        run_manager = run_manager or self.run_manager
+
+        try:
+            PipelineTypes(self.pipeline_type)
+        except ValueError:
+            raise ValueError(
+                f"Invalid pipeline type: {self.pipeline_type}, must be one of {PipelineTypes.__members__.keys()}"
+            )
+
+        self.state = state or AsyncState()
+        current_input = input
+        async with manage_run(run_manager, self.pipeline_type):
+            if log_run_info:
+                await run_manager.log_run_info(
+                    key="pipeline_type",
+                    value=self.pipeline_type,
+                    is_info_log=True,
+                )
+            try:
+                for pipe_num in range(len(self.pipes)):
+                    config_name = self.pipes[pipe_num].config.name
+                    self.futures[config_name] = asyncio.Future()
+
+                    current_input = self._run_pipe(
+                        pipe_num,
+                        current_input,
+                        run_manager,
+                        *args,
+                        **kwargs,
+                    )
+                    self.futures[config_name].set_result(current_input)
+                if not stream:
+                    final_result = await self._consume_all(current_input)
+                    return final_result
+                else:
+                    return current_input
+            except Exception as error:
+                logger.error(f"Pipeline failed with error: {error}")
+                raise error
+
+    async def _consume_all(self, gen: AsyncGenerator) -> list[Any]:
+        result = []
+        async for item in gen:
+            if hasattr(
+                item, "__aiter__"
+            ):  # Check if the item is an async generator
+                sub_result = await self._consume_all(item)
+                result.extend(sub_result)
+            else:
+                result.append(item)
+        return result
+
+    async def _run_pipe(
+        self,
+        pipe_num: int,
+        input: Any,
+        run_manager: RunManager,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        # Collect inputs, waiting for the necessary futures
+        pipe = self.pipes[pipe_num]
+        add_upstream_outputs = self.sort_upstream_outputs(
+            self.upstream_outputs[pipe_num]
+        )
+        input_dict = {"message": input}
+
+        # Group upstream outputs by prev_pipe_name
+        grouped_upstream_outputs = {}
+        for upstream_input in add_upstream_outputs:
+            upstream_pipe_name = upstream_input["prev_pipe_name"]
+            if upstream_pipe_name not in grouped_upstream_outputs:
+                grouped_upstream_outputs[upstream_pipe_name] = []
+            grouped_upstream_outputs[upstream_pipe_name].append(upstream_input)
+
+        for (
+            upstream_pipe_name,
+            upstream_inputs,
+        ) in grouped_upstream_outputs.items():
+
+            async def resolve_future_output(future):
+                result = future.result()
+                # consume the async generator
+                return [item async for item in result]
+
+            async def replay_items_as_async_gen(items):
+                for item in items:
+                    yield item
+
+            temp_results = await resolve_future_output(
+                self.futures[upstream_pipe_name]
+            )
+            if upstream_pipe_name == self.pipes[pipe_num - 1].config.name:
+                input_dict["message"] = replay_items_as_async_gen(temp_results)
+
+            for upstream_input in upstream_inputs:
+                outputs = await self.state.get(upstream_pipe_name, "output")
+                prev_output_field = upstream_input.get(
+                    "prev_output_field", None
+                )
+                if not prev_output_field:
+                    raise ValueError(
+                        "`prev_output_field` must be specified in the upstream_input"
+                    )
+                input_dict[upstream_input["input_field"]] = outputs[
+                    prev_output_field
+                ]
+
+        # Handle the pipe generator
+        async for ele in await pipe.run(
+            pipe.Input(**input_dict),
+            self.state,
+            run_manager,
+            *args,
+            **kwargs,
+        ):
+            yield ele
+
+    def sort_upstream_outputs(
+        self, add_upstream_outputs: list[dict[str, str]]
+    ) -> list[dict[str, str]]:
+        pipe_name_to_index = {
+            pipe.config.name: index for index, pipe in enumerate(self.pipes)
+        }
+
+        def get_pipe_index(upstream_output):
+            return pipe_name_to_index[upstream_output["prev_pipe_name"]]
+
+        sorted_outputs = sorted(
+            add_upstream_outputs, key=get_pipe_index, reverse=True
+        )
+        return sorted_outputs
+
+
+class EvalPipeline(AsyncPipeline):
+    """A pipeline for evaluation."""
+
+    pipeline_type: str = "eval"
+
+    async def run(
+        self,
+        input: Any,
+        state: Optional[AsyncState] = None,
+        stream: bool = False,
+        run_manager: Optional[RunManager] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        return await super().run(
+            input, state, stream, run_manager, *args, **kwargs
+        )
+
+    def add_pipe(
+        self,
+        pipe: AsyncPipe,
+        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        logger.debug(f"Adding pipe {pipe.config.name} to the EvalPipeline")
+        return super().add_pipe(pipe, add_upstream_outputs, *args, **kwargs)
+
+
+async def dequeue_requests(queue: asyncio.Queue) -> AsyncGenerator:
+    """Create an async generator to dequeue requests."""
+    while True:
+        request = await queue.get()
+        if request is None:
+            break
+        yield request