aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/base/pipes/base_pipe.py
blob: 63e3d04e8e497961cf53ee2ebb0da987a33b4f1f (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import asyncio
import logging
import uuid
from abc import abstractmethod
from enum import Enum
from typing import Any, AsyncGenerator, Optional

from pydantic import BaseModel

from r2r.base.logging.kv_logger import KVLoggingSingleton
from r2r.base.logging.run_manager import RunManager, manage_run

logger = logging.getLogger(__name__)


class PipeType(Enum):
    """The role a pipe plays within a pipeline.

    Member order and string values are part of the public contract
    (iteration order and serialized values) and must not change.
    """

    INGESTOR = "ingestor"  # brings documents/data into the system
    EVAL = "eval"  # evaluates outputs
    GENERATOR = "generator"  # produces generated content (e.g. LLM output)
    SEARCH = "search"  # retrieves results from an index/store
    TRANSFORM = "transform"  # reshapes data between stages
    OTHER = "other"  # catch-all for uncategorized pipes


class AsyncState:
    """An asyncio-safe, two-level key/value store shared between pipes.

    Data is laid out as ``data[outer_key][inner_key] -> value``.  Every
    access is serialized through a single ``asyncio.Lock`` so concurrent
    pipes never observe a partially-applied update.
    """

    def __init__(self):
        self.data = {}
        self.lock = asyncio.Lock()

    async def update(self, outer_key: str, values: dict):
        """Merge ``values`` into the dict stored under ``outer_key``.

        Args:
            outer_key: Namespace to update; created if absent.
            values: Mapping of inner keys to values to merge in.

        Raises:
            ValueError: If ``values`` is not a dictionary.
        """
        # Validate before taking the lock — no state is touched yet, and
        # this keeps the critical section as short as possible.
        if not isinstance(values, dict):
            raise ValueError("Values must be contained in a dictionary.")
        async with self.lock:
            self.data.setdefault(outer_key, {}).update(values)

    async def get(self, outer_key: str, inner_key: str, default=None):
        """Return ``data[outer_key][inner_key]``.

        Args:
            outer_key: Namespace to read; must already exist.
            inner_key: Key within the namespace.
            default: Value returned when ``inner_key`` is absent.  When no
                default is supplied, ``{}`` is returned (historical
                contract of this class).

        Raises:
            ValueError: If ``outer_key`` does not exist in the state.
        """
        async with self.lock:
            if outer_key not in self.data:
                raise ValueError(
                    f"Key {outer_key} does not exist in the state."
                )
            if inner_key not in self.data[outer_key]:
                # Fix: the old `default or {}` discarded falsy defaults
                # (0, "", []); only None now falls back to {}.
                return {} if default is None else default
            return self.data[outer_key][inner_key]

    async def delete(self, outer_key: str, inner_key: Optional[str] = None):
        """Delete ``outer_key`` entirely, or a single ``inner_key`` in it.

        Args:
            outer_key: Namespace to delete from.
            inner_key: When given, delete only this key; otherwise the
                whole namespace is removed.

        Raises:
            ValueError: If the requested outer or inner key is missing.
        """
        async with self.lock:
            if outer_key not in self.data:
                # Fix: previously a missing outer_key fell into the else
                # branch and raised a bare KeyError via indexing; raise
                # ValueError for consistency with get().
                raise ValueError(
                    f"Key {outer_key} does not exist in the state."
                )
            if inner_key is None:
                del self.data[outer_key]
            else:
                if inner_key not in self.data[outer_key]:
                    raise ValueError(
                        f"Key {inner_key} does not exist in the state."
                    )
                del self.data[outer_key][inner_key]


class AsyncPipe:
    """An asynchronous pipe for processing data with logging capabilities.

    Subclasses implement ``_run_logic``; ``run`` wraps it with run
    management and a background worker that drains queued log entries.
    """

    class PipeConfig(BaseModel):
        """Configuration for a pipe."""

        name: str = "default_pipe"
        # Entries enqueued beyond this depth are silently dropped
        # (see enqueue_log) to bound memory use.
        max_log_queue_size: int = 100

        class Config:
            extra = "forbid"
            arbitrary_types_allowed = True

    class Input(BaseModel):
        """Input for a pipe."""

        message: AsyncGenerator[Any, None]

        class Config:
            extra = "forbid"
            arbitrary_types_allowed = True

    def __init__(
        self,
        type: PipeType = PipeType.OTHER,
        config: Optional[PipeConfig] = None,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        run_manager: Optional[RunManager] = None,
    ):
        """Initialize the pipe.

        Args:
            type: Role of this pipe in the pipeline.
            config: Pipe configuration; defaults to ``PipeConfig()``.
            pipe_logger: Key/value logger; defaults to the singleton.
            run_manager: Run manager; a new one wrapping ``pipe_logger``
                is created when not supplied.
        """
        self._config = config or self.PipeConfig()
        self._type = type
        self.pipe_logger = pipe_logger or KVLoggingSingleton()
        self.log_queue = asyncio.Queue()
        self.log_worker_task = None
        self._run_manager = run_manager or RunManager(self.pipe_logger)

        logger.debug(
            f"Initialized pipe {self.config.name} of type {self.type}"
        )

    @property
    def config(self) -> PipeConfig:
        return self._config

    @property
    def type(self) -> PipeType:
        return self._type

    async def log_worker(self):
        """Drain the log queue, forwarding each entry to the pipe logger.

        Runs until cancelled.  ``task_done()`` is always called — even
        when logging fails — so ``log_queue.join()`` in ``run`` cannot
        deadlock waiting on an item the dead worker never acknowledged.
        """
        while True:
            log_data = await self.log_queue.get()
            try:
                run_id, key, value = log_data
                await self.pipe_logger.log(run_id, key, value)
            except Exception as e:
                # Fix: an exception from the logger previously killed the
                # worker task without task_done(), hanging join() forever.
                # CancelledError is a BaseException and still propagates.
                logger.error(f"Error in pipe log worker: {e}")
            finally:
                self.log_queue.task_done()

    async def enqueue_log(self, run_id: uuid.UUID, key: str, value: str):
        """Queue a log entry; drops it when the queue is at capacity."""
        if self.log_queue.qsize() < self.config.max_log_queue_size:
            await self.log_queue.put((run_id, key, value))

    async def run(
        self,
        input: Input,
        state: AsyncState,
        run_manager: Optional[RunManager] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[Any, None]:
        """Run the pipe with logging capabilities.

        Wraps ``_run_logic`` in a managed run: starts a log worker,
        yields results through, then flushes and resets the log queue.

        Returns:
            An async generator yielding ``_run_logic``'s results.
        """

        run_manager = run_manager or self._run_manager

        async def wrapped_run() -> AsyncGenerator[Any, None]:
            async with manage_run(run_manager, self.config.name) as run_id:
                self.log_worker_task = asyncio.create_task(
                    self.log_worker(), name=f"log-worker-{self.config.name}"
                )
                try:
                    async for result in self._run_logic(
                        input, state, run_id=run_id, *args, **kwargs
                    ):
                        yield result
                finally:
                    # Flush all pending entries before tearing the worker
                    # down; a fresh queue isolates subsequent runs.
                    await self.log_queue.join()
                    self.log_worker_task.cancel()
                    self.log_queue = asyncio.Queue()

        return wrapped_run()

    @abstractmethod
    async def _run_logic(
        self,
        input: Input,
        state: AsyncState,
        run_id: uuid.UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[Any, None]:
        """Subclass hook containing the pipe's actual processing logic."""
        pass