Diffstat (limited to 'R2R/r2r/base/pipes/base_pipe.py')
-rwxr-xr-x  R2R/r2r/base/pipes/base_pipe.py  163
1 file changed, 163 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/base/pipes/base_pipe.py b/R2R/r2r/base/pipes/base_pipe.py
new file mode 100755
index 00000000..63e3d04e
--- /dev/null
+++ b/R2R/r2r/base/pipes/base_pipe.py
@@ -0,0 +1,163 @@
+import asyncio
+import logging
+import uuid
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, AsyncGenerator, Optional
+
+from pydantic import BaseModel
+
+from r2r.base.logging.kv_logger import KVLoggingSingleton
+from r2r.base.logging.run_manager import RunManager, manage_run
+
+logger = logging.getLogger(__name__)
+
+
+class PipeType(Enum):
+ INGESTOR = "ingestor"
+ EVAL = "eval"
+ GENERATOR = "generator"
+ SEARCH = "search"
+ TRANSFORM = "transform"
+ OTHER = "other"
+
+
+class AsyncState:
+ """A state object for storing data between pipes."""
+
+ def __init__(self):
+ self.data = {}
+ self.lock = asyncio.Lock()
+
+ async def update(self, outer_key: str, values: dict):
+ """Update the state with new values."""
+ async with self.lock:
+ if not isinstance(values, dict):
+ raise ValueError("Values must be contained in a dictionary.")
+ if outer_key not in self.data:
+ self.data[outer_key] = {}
+ for inner_key, inner_value in values.items():
+ self.data[outer_key][inner_key] = inner_value
+
+ async def get(self, outer_key: str, inner_key: str, default=None):
+ """Get a value from the state."""
+ async with self.lock:
+ if outer_key not in self.data:
+ raise ValueError(
+ f"Key {outer_key} does not exist in the state."
+ )
+ if inner_key not in self.data[outer_key]:
+ return default or {}
+ return self.data[outer_key][inner_key]
+
+ async def delete(self, outer_key: str, inner_key: Optional[str] = None):
+ """Delete a value from the state."""
+ async with self.lock:
+ if outer_key in self.data and not inner_key:
+ del self.data[outer_key]
+ else:
+ if inner_key not in self.data[outer_key]:
+ raise ValueError(
+ f"Key {inner_key} does not exist in the state."
+ )
+ del self.data[outer_key][inner_key]
+
+
+class AsyncPipe:
+ """An asynchronous pipe for processing data with logging capabilities."""
+
+ class PipeConfig(BaseModel):
+ """Configuration for a pipe."""
+
+ name: str = "default_pipe"
+ max_log_queue_size: int = 100
+
+ class Config:
+ extra = "forbid"
+ arbitrary_types_allowed = True
+
+ class Input(BaseModel):
+ """Input for a pipe."""
+
+ message: AsyncGenerator[Any, None]
+
+ class Config:
+ extra = "forbid"
+ arbitrary_types_allowed = True
+
+ def __init__(
+ self,
+ type: PipeType = PipeType.OTHER,
+ config: Optional[PipeConfig] = None,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ run_manager: Optional[RunManager] = None,
+ ):
+ self._config = config or self.PipeConfig()
+ self._type = type
+ self.pipe_logger = pipe_logger or KVLoggingSingleton()
+ self.log_queue = asyncio.Queue()
+ self.log_worker_task = None
+ self._run_manager = run_manager or RunManager(self.pipe_logger)
+
+ logger.debug(
+ f"Initialized pipe {self.config.name} of type {self.type}"
+ )
+
+ @property
+ def config(self) -> PipeConfig:
+ return self._config
+
+ @property
+ def type(self) -> PipeType:
+ return self._type
+
+ async def log_worker(self):
+ while True:
+ log_data = await self.log_queue.get()
+ run_id, key, value = log_data
+ await self.pipe_logger.log(run_id, key, value)
+ self.log_queue.task_done()
+
+ async def enqueue_log(self, run_id: uuid.UUID, key: str, value: str):
+ if self.log_queue.qsize() < self.config.max_log_queue_size:
+ await self.log_queue.put((run_id, key, value))
+
+ async def run(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_manager: Optional[RunManager] = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[Any, None]:
+ """Run the pipe with logging capabilities."""
+
+ run_manager = run_manager or self._run_manager
+
+ async def wrapped_run() -> AsyncGenerator[Any, None]:
+ async with manage_run(run_manager, self.config.name) as run_id:
+ self.log_worker_task = asyncio.create_task(
+ self.log_worker(), name=f"log-worker-{self.config.name}"
+ )
+ try:
+ async for result in self._run_logic(
+ input, state, run_id=run_id, *args, **kwargs
+ ):
+ yield result
+ finally:
+ await self.log_queue.join()
+ self.log_worker_task.cancel()
+ self.log_queue = asyncio.Queue()
+
+ return wrapped_run()
+
+ @abstractmethod
+ async def _run_logic(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[Any, None]:
+ pass
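
Usage sketch (not part of the commit): the snippet below shows how a concrete pipe built on the classes added in this file might be wired up. It assumes the module is importable as r2r.base.pipes.base_pipe and that the default KVLoggingSingleton/RunManager construction in AsyncPipe.__init__ needs no further configuration; UpperCasePipe and source are hypothetical names used only for illustration.

import asyncio
import uuid
from typing import Any, AsyncGenerator

from r2r.base.pipes.base_pipe import AsyncPipe, AsyncState, PipeType


class UpperCasePipe(AsyncPipe):
    """Hypothetical pipe that upper-cases each incoming message."""

    def __init__(self):
        super().__init__(
            type=PipeType.TRANSFORM,
            config=AsyncPipe.PipeConfig(name="uppercase_pipe"),
        )

    async def _run_logic(
        self,
        input: AsyncPipe.Input,
        state: AsyncState,
        run_id: uuid.UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
        async for message in input.message:
            # Stash the most recent message under this pipe's namespace so a
            # downstream pipe could read it via state.get("uppercase_pipe", ...).
            await state.update(self.config.name, {"last_message": message})
            yield message.upper()


async def main():
    async def source() -> AsyncGenerator[str, None]:
        for word in ("alpha", "beta", "gamma"):
            yield word

    pipe = UpperCasePipe()
    state = AsyncState()

    # AsyncPipe.run is a coroutine that returns an async generator, so it is
    # awaited first and then iterated.
    results = await pipe.run(AsyncPipe.Input(message=source()), state)
    async for item in results:
        print(item)  # ALPHA, BETA, GAMMA

    print(await state.get("uppercase_pipe", "last_message"))  # gamma


if __name__ == "__main__":
    asyncio.run(main())

The sketch exercises the two pieces the diff introduces: the shared AsyncState namespace keyed by pipe name, and the run() wrapper that starts a log worker and yields results from the subclass's _run_logic.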