about summary refs log tree commit diff
path: root/R2R/r2r/base/pipes/base_pipe.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/base/pipes/base_pipe.py')
-rwxr-xr-xR2R/r2r/base/pipes/base_pipe.py163
1 file changed, 163 insertions, 0 deletions
diff --git a/R2R/r2r/base/pipes/base_pipe.py b/R2R/r2r/base/pipes/base_pipe.py
new file mode 100755
index 00000000..63e3d04e
--- /dev/null
+++ b/R2R/r2r/base/pipes/base_pipe.py
@@ -0,0 +1,163 @@
+import asyncio
+import logging
+import uuid
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, AsyncGenerator, Optional
+
+from pydantic import BaseModel
+
+from r2r.base.logging.kv_logger import KVLoggingSingleton
+from r2r.base.logging.run_manager import RunManager, manage_run
+
+logger = logging.getLogger(__name__)
+
+
class PipeType(Enum):
    """Categorizes the role a pipe plays within a processing pipeline."""

    INGESTOR = "ingestor"
    EVAL = "eval"
    GENERATOR = "generator"
    SEARCH = "search"
    TRANSFORM = "transform"
    # Fallback category; also the default used by AsyncPipe.__init__.
    OTHER = "other"
+
+
class AsyncState:
    """An asyncio-safe, two-level state store shared between pipes.

    Values live under ``data[outer_key][inner_key]``; every operation is
    serialized through a single ``asyncio.Lock``.
    """

    def __init__(self):
        self.data = {}  # outer_key -> {inner_key: value}
        self.lock = asyncio.Lock()

    async def update(self, outer_key: str, values: dict):
        """Merge ``values`` into the dict stored under ``outer_key``.

        Raises:
            ValueError: If ``values`` is not a dictionary.
        """
        async with self.lock:
            if not isinstance(values, dict):
                raise ValueError("Values must be contained in a dictionary.")
            self.data.setdefault(outer_key, {}).update(values)

    async def get(self, outer_key: str, inner_key: str, default=None):
        """Get a value from the state.

        Returns ``default`` when ``inner_key`` is missing; when no default
        is supplied, returns ``{}`` (preserving historical behavior). The
        old ``default or {}`` expression silently discarded falsy defaults
        such as ``0`` or ``""``; those are now honored.

        Raises:
            ValueError: If ``outer_key`` does not exist in the state.
        """
        async with self.lock:
            if outer_key not in self.data:
                raise ValueError(
                    f"Key {outer_key} does not exist in the state."
                )
            if inner_key not in self.data[outer_key]:
                return {} if default is None else default
            return self.data[outer_key][inner_key]

    async def delete(self, outer_key: str, inner_key: Optional[str] = None):
        """Delete ``outer_key`` entirely, or one ``inner_key`` beneath it.

        Raises:
            ValueError: If the targeted outer or inner key does not exist.
                (Previously a missing ``outer_key`` leaked a raw ``KeyError``
                from the dict lookup; ValueError matches ``get``.)
        """
        async with self.lock:
            if outer_key not in self.data:
                raise ValueError(
                    f"Key {outer_key} does not exist in the state."
                )
            # Compare against None explicitly so an empty-string inner key
            # is not misread as "delete the whole outer key".
            if inner_key is None:
                del self.data[outer_key]
            else:
                if inner_key not in self.data[outer_key]:
                    raise ValueError(
                        f"Key {inner_key} does not exist in the state."
                    )
                del self.data[outer_key][inner_key]
+
+
class AsyncPipe:
    """An asynchronous pipe for processing data with logging capabilities.

    Subclasses implement ``_run_logic``; callers invoke ``run``, which wraps
    the logic in a managed run context and a background log-flushing task.
    """

    class PipeConfig(BaseModel):
        """Configuration for a pipe."""

        # Human-readable identifier, used in log output and run management.
        name: str = "default_pipe"
        # enqueue_log silently drops entries once the queue reaches this size.
        max_log_queue_size: int = 100

        class Config:
            extra = "forbid"
            arbitrary_types_allowed = True

    class Input(BaseModel):
        """Input for a pipe."""

        # The stream of items this pipe consumes.
        message: AsyncGenerator[Any, None]

        class Config:
            extra = "forbid"
            arbitrary_types_allowed = True

    def __init__(
        self,
        type: PipeType = PipeType.OTHER,
        config: Optional[PipeConfig] = None,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        run_manager: Optional[RunManager] = None,
    ):
        """Initialize the pipe.

        Args:
            type: Category of this pipe (defaults to ``PipeType.OTHER``).
            config: Pipe configuration; a default ``PipeConfig`` if omitted.
            pipe_logger: Key-value logger; a fresh singleton if omitted.
            run_manager: Run manager; one is built from the logger if omitted.
        """
        self._config = config or self.PipeConfig()
        self._type = type
        self.pipe_logger = pipe_logger or KVLoggingSingleton()
        self.log_queue = asyncio.Queue()
        self.log_worker_task = None
        self._run_manager = run_manager or RunManager(self.pipe_logger)

        logger.debug(
            f"Initialized pipe {self.config.name} of type {self.type}"
        )

    @property
    def config(self) -> PipeConfig:
        return self._config

    @property
    def type(self) -> PipeType:
        return self._type

    async def log_worker(self):
        """Drain the log queue, forwarding each entry to the pipe logger.

        Runs until cancelled by ``run``'s cleanup.
        """
        while True:
            run_id, key, value = await self.log_queue.get()
            try:
                await self.pipe_logger.log(run_id, key, value)
            except Exception:
                # A failed log write must not kill the worker: previously an
                # exception here skipped task_done(), so log_queue.join() in
                # run()'s finally block would hang forever.
                logger.exception(
                    f"Error logging data in pipe {self.config.name}"
                )
            finally:
                self.log_queue.task_done()

    async def enqueue_log(self, run_id: uuid.UUID, key: str, value: str):
        """Queue a log entry; drops it if the queue is at capacity."""
        if self.log_queue.qsize() < self.config.max_log_queue_size:
            await self.log_queue.put((run_id, key, value))

    async def run(
        self,
        input: Input,
        state: AsyncState,
        run_manager: Optional[RunManager] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[Any, None]:
        """Run the pipe with logging capabilities.

        Returns an async generator that yields ``_run_logic``'s results
        inside a managed run context. Queued logs are flushed, and the log
        worker cancelled, when the generator is exhausted or closed.
        """

        run_manager = run_manager or self._run_manager

        async def wrapped_run() -> AsyncGenerator[Any, None]:
            async with manage_run(run_manager, self.config.name) as run_id:
                self.log_worker_task = asyncio.create_task(
                    self.log_worker(), name=f"log-worker-{self.config.name}"
                )
                try:
                    async for result in self._run_logic(
                        input, state, run_id=run_id, *args, **kwargs
                    ):
                        yield result
                finally:
                    # Flush remaining entries, then retire the worker and
                    # give the next run a clean queue.
                    await self.log_queue.join()
                    self.log_worker_task.cancel()
                    self.log_queue = asyncio.Queue()

        return wrapped_run()

    @abstractmethod
    async def _run_logic(
        self,
        input: Input,
        state: AsyncState,
        run_id: uuid.UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[Any, None]:
        """Subclass hook containing the pipe's actual processing logic."""
        pass