"""Base pipeline class for running a sequence of pipes.""" import asyncio import logging from enum import Enum from typing import Any, AsyncGenerator, Optional from ..logging.kv_logger import KVLoggingSingleton from ..logging.run_manager import RunManager, manage_run from ..pipes.base_pipe import AsyncPipe, AsyncState logger = logging.getLogger(__name__) class PipelineTypes(Enum): EVAL = "eval" INGESTION = "ingestion" SEARCH = "search" RAG = "rag" OTHER = "other" class AsyncPipeline: """Pipeline class for running a sequence of pipes.""" pipeline_type: str = "other" def __init__( self, pipe_logger: Optional[KVLoggingSingleton] = None, run_manager: Optional[RunManager] = None, ): self.pipes: list[AsyncPipe] = [] self.upstream_outputs: list[list[dict[str, str]]] = [] self.pipe_logger = pipe_logger or KVLoggingSingleton() self.run_manager = run_manager or RunManager(self.pipe_logger) self.futures = {} self.level = 0 def add_pipe( self, pipe: AsyncPipe, add_upstream_outputs: Optional[list[dict[str, str]]] = None, *args, **kwargs, ) -> None: """Add a pipe to the pipeline.""" self.pipes.append(pipe) if not add_upstream_outputs: add_upstream_outputs = [] self.upstream_outputs.append(add_upstream_outputs) async def run( self, input: Any, state: Optional[AsyncState] = None, stream: bool = False, run_manager: Optional[RunManager] = None, log_run_info: bool = True, *args: Any, **kwargs: Any, ): """Run the pipeline.""" run_manager = run_manager or self.run_manager try: PipelineTypes(self.pipeline_type) except ValueError: raise ValueError( f"Invalid pipeline type: {self.pipeline_type}, must be one of {PipelineTypes.__members__.keys()}" ) self.state = state or AsyncState() current_input = input async with manage_run(run_manager, self.pipeline_type): if log_run_info: await run_manager.log_run_info( key="pipeline_type", value=self.pipeline_type, is_info_log=True, ) try: for pipe_num in range(len(self.pipes)): config_name = self.pipes[pipe_num].config.name self.futures[config_name] = asyncio.Future() current_input = self._run_pipe( pipe_num, current_input, run_manager, *args, **kwargs, ) self.futures[config_name].set_result(current_input) if not stream: final_result = await self._consume_all(current_input) return final_result else: return current_input except Exception as error: logger.error(f"Pipeline failed with error: {error}") raise error async def _consume_all(self, gen: AsyncGenerator) -> list[Any]: result = [] async for item in gen: if hasattr( item, "__aiter__" ): # Check if the item is an async generator sub_result = await self._consume_all(item) result.extend(sub_result) else: result.append(item) return result async def _run_pipe( self, pipe_num: int, input: Any, run_manager: RunManager, *args: Any, **kwargs: Any, ): # Collect inputs, waiting for the necessary futures pipe = self.pipes[pipe_num] add_upstream_outputs = self.sort_upstream_outputs( self.upstream_outputs[pipe_num] ) input_dict = {"message": input} # Group upstream outputs by prev_pipe_name grouped_upstream_outputs = {} for upstream_input in add_upstream_outputs: upstream_pipe_name = upstream_input["prev_pipe_name"] if upstream_pipe_name not in grouped_upstream_outputs: grouped_upstream_outputs[upstream_pipe_name] = [] grouped_upstream_outputs[upstream_pipe_name].append(upstream_input) for ( upstream_pipe_name, upstream_inputs, ) in grouped_upstream_outputs.items(): async def resolve_future_output(future): result = future.result() # consume the async generator return [item async for item in result] async def replay_items_as_async_gen(items): for item in items: yield item temp_results = await resolve_future_output( self.futures[upstream_pipe_name] ) if upstream_pipe_name == self.pipes[pipe_num - 1].config.name: input_dict["message"] = replay_items_as_async_gen(temp_results) for upstream_input in upstream_inputs: outputs = await self.state.get(upstream_pipe_name, "output") prev_output_field = upstream_input.get( "prev_output_field", None ) if not prev_output_field: raise ValueError( "`prev_output_field` must be specified in the upstream_input" ) input_dict[upstream_input["input_field"]] = outputs[ prev_output_field ] # Handle the pipe generator async for ele in await pipe.run( pipe.Input(**input_dict), self.state, run_manager, *args, **kwargs, ): yield ele def sort_upstream_outputs( self, add_upstream_outputs: list[dict[str, str]] ) -> list[dict[str, str]]: pipe_name_to_index = { pipe.config.name: index for index, pipe in enumerate(self.pipes) } def get_pipe_index(upstream_output): return pipe_name_to_index[upstream_output["prev_pipe_name"]] sorted_outputs = sorted( add_upstream_outputs, key=get_pipe_index, reverse=True ) return sorted_outputs class EvalPipeline(AsyncPipeline): """A pipeline for evaluation.""" pipeline_type: str = "eval" async def run( self, input: Any, state: Optional[AsyncState] = None, stream: bool = False, run_manager: Optional[RunManager] = None, *args: Any, **kwargs: Any, ): return await super().run( input, state, stream, run_manager, *args, **kwargs ) def add_pipe( self, pipe: AsyncPipe, add_upstream_outputs: Optional[list[dict[str, str]]] = None, *args, **kwargs, ) -> None: logger.debug(f"Adding pipe {pipe.config.name} to the EvalPipeline") return super().add_pipe(pipe, add_upstream_outputs, *args, **kwargs) async def dequeue_requests(queue: asyncio.Queue) -> AsyncGenerator: """Create an async generator to dequeue requests.""" while True: request = await queue.get() if request is None: break yield request