1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
|
import asyncio
import logging
import uuid
from abc import abstractmethod
from enum import Enum
from typing import Any, AsyncGenerator, Optional
from pydantic import BaseModel
from r2r.base.logging.kv_logger import KVLoggingSingleton
from r2r.base.logging.run_manager import RunManager, manage_run
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class PipeType(Enum):
    """Closed set of pipe categories.

    Every member's value is the lowercase string form of its name, so a
    member can be looked up from that string, e.g. ``PipeType("search")``.
    """

    INGESTOR = "ingestor"
    EVAL = "eval"
    GENERATOR = "generator"
    SEARCH = "search"
    TRANSFORM = "transform"
    OTHER = "other"
class AsyncState:
    """A state object for storing data between pipes.

    Values live in a two-level mapping (``outer_key -> inner_key -> value``)
    held in ``self.data``. Every accessor acquires ``self.lock`` before it
    reads or mutates that mapping.
    """

    def __init__(self):
        self.data = {}
        self.lock = asyncio.Lock()

    async def update(self, outer_key: str, values: dict):
        """Merge ``values`` into the mapping stored under ``outer_key``.

        The inner mapping is created on first use; existing inner keys are
        overwritten by the entries in ``values``.

        Raises:
            ValueError: If ``values`` is not a dictionary.
        """
        async with self.lock:
            if not isinstance(values, dict):
                raise ValueError("Values must be contained in a dictionary.")
            self.data.setdefault(outer_key, {}).update(values)

    async def get(self, outer_key: str, inner_key: str, default=None):
        """Return the value stored under ``outer_key`` / ``inner_key``.

        Args:
            outer_key: First-level key; must already exist.
            inner_key: Second-level key to look up.
            default: Value returned when ``inner_key`` is absent. When left
                as ``None`` an empty dict is returned instead.

        Raises:
            ValueError: If ``outer_key`` is not present in the state.
        """
        async with self.lock:
            if outer_key not in self.data:
                raise ValueError(
                    f"Key {outer_key} does not exist in the state."
                )
            inner = self.data[outer_key]
            if inner_key not in inner:
                # Compare against None explicitly so that falsy defaults the
                # caller asked for (0, "", False, []) are returned as given
                # rather than being replaced by {}.
                return {} if default is None else default
            return inner[inner_key]

    async def delete(self, outer_key: str, inner_key: Optional[str] = None):
        """Delete a whole outer entry, or a single inner key from it.

        When ``inner_key`` is falsy (``None`` or ``""``) the entire mapping
        stored under ``outer_key`` is removed; otherwise only ``inner_key``
        is removed from it.

        Raises:
            KeyError: If ``outer_key`` is not present in the state.
            ValueError: If ``inner_key`` is given but not present under
                ``outer_key``.
        """
        async with self.lock:
            if outer_key not in self.data:
                # Same exception type the bare dict lookup used to produce,
                # now raised deliberately and with a readable message.
                raise KeyError(f"Key {outer_key} does not exist in the state.")
            if not inner_key:
                del self.data[outer_key]
                return
            inner = self.data[outer_key]
            if inner_key not in inner:
                raise ValueError(
                    f"Key {inner_key} does not exist in the state."
                )
            del inner[inner_key]
class AsyncPipe:
    """An asynchronous pipe for processing data with logging capabilities.

    Subclasses supply the processing step by overriding ``_run_logic``.
    ``run`` wraps that step in a ``manage_run`` context and starts a
    background task (``log_worker``) that forwards entries queued through
    ``enqueue_log`` to ``pipe_logger``.
    """

    class PipeConfig(BaseModel):
        """Configuration for a pipe."""

        name: str = "default_pipe"
        # Upper bound on pending log entries; see ``enqueue_log``.
        max_log_queue_size: int = 100

        class Config:
            extra = "forbid"
            arbitrary_types_allowed = True

    class Input(BaseModel):
        """Input for a pipe."""

        message: AsyncGenerator[Any, None]

        class Config:
            extra = "forbid"
            arbitrary_types_allowed = True

    def __init__(
        self,
        type: PipeType = PipeType.OTHER,
        config: Optional[PipeConfig] = None,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        run_manager: Optional[RunManager] = None,
    ):
        """Set up the pipe's configuration, logger, log queue and run manager.

        Args:
            type: Category of this pipe.
            config: Pipe configuration; a default ``PipeConfig`` is created
                when omitted.
            pipe_logger: Logger that receives queued log entries; a
                ``KVLoggingSingleton`` is created when omitted.
            run_manager: Run manager used by ``run`` unless one is passed to
                ``run`` itself; a ``RunManager`` built on ``pipe_logger`` is
                created when omitted.
        """
        self._config = config or self.PipeConfig()
        self._type = type
        self.pipe_logger = pipe_logger or KVLoggingSingleton()
        self.log_queue = asyncio.Queue()
        self.log_worker_task = None
        self._run_manager = run_manager or RunManager(self.pipe_logger)

        logger.debug(
            "Initialized pipe %s of type %s", self.config.name, self.type
        )

    @property
    def config(self) -> PipeConfig:
        """The configuration this pipe was constructed with."""
        return self._config

    @property
    def type(self) -> PipeType:
        """The category this pipe was constructed with."""
        return self._type

    async def log_worker(self):
        """Forward queued ``(run_id, key, value)`` entries to ``pipe_logger``.

        Loops until the task is cancelled. Each entry is marked done on the
        queue it was taken from even if logging it fails, so a failing
        logger can never leave ``log_queue.join()`` waiting forever.
        """
        while True:
            # Bind the queue first: ``run`` replaces ``self.log_queue`` when
            # it finishes, and ``task_done`` must hit the queue ``get`` used.
            queue = self.log_queue
            run_id, key, value = await queue.get()
            try:
                await self.pipe_logger.log(run_id, key, value)
            except Exception:
                # Logging is auxiliary: report the failure and keep serving
                # the queue instead of letting the worker task die.
                logger.exception(
                    "Failed to log key %r for run %s in pipe %s",
                    key,
                    run_id,
                    self.config.name,
                )
            finally:
                queue.task_done()

    async def enqueue_log(self, run_id: uuid.UUID, key: str, value: str):
        """Queue a log entry for the background worker.

        The entry is silently dropped when the queue already holds
        ``config.max_log_queue_size`` entries.
        """
        if self.log_queue.qsize() < self.config.max_log_queue_size:
            await self.log_queue.put((run_id, key, value))

    async def run(
        self,
        input: Input,
        state: AsyncState,
        run_manager: Optional[RunManager] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[Any, None]:
        """Run the pipe with logging capabilities.

        Awaiting this coroutine returns an async generator; nothing runs
        until that generator is iterated.

        Args:
            input: Input handed to ``_run_logic``.
            state: Shared state handed to ``_run_logic``.
            run_manager: Run manager for this call; falls back to the one
                set at construction time.
            *args: Extra positional arguments forwarded to ``_run_logic``
                after ``run_id``.
            **kwargs: Extra keyword arguments forwarded to ``_run_logic``.

        Returns:
            An async generator yielding whatever ``_run_logic`` yields.
        """
        run_manager = run_manager or self._run_manager

        async def wrapped_run() -> AsyncGenerator[Any, None]:
            async with manage_run(run_manager, self.config.name) as run_id:
                worker = asyncio.create_task(
                    self.log_worker(), name=f"log-worker-{self.config.name}"
                )
                self.log_worker_task = worker
                try:
                    # ``run_id`` goes in positionally, matching the
                    # ``_run_logic`` signature. Passing it as a keyword ahead
                    # of ``*args`` raised "multiple values for argument
                    # 'run_id'" whenever extra positional arguments were
                    # supplied.
                    async for result in self._run_logic(
                        input, state, run_id, *args, **kwargs
                    ):
                        yield result
                finally:
                    try:
                        # Let the worker drain pending entries first.
                        await self.log_queue.join()
                    finally:
                        # Stop this run's worker even if the wait above was
                        # interrupted, then start from an empty queue.
                        worker.cancel()
                        self.log_queue = asyncio.Queue()

        return wrapped_run()

    # NOTE: this class does not use ``ABCMeta``, so ``@abstractmethod`` is
    # not enforced at instantiation time; subclasses must still override it.
    @abstractmethod
    async def _run_logic(
        self,
        input: Input,
        state: AsyncState,
        run_id: uuid.UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[Any, None]:
        """Processing step of the pipe; overridden by subclasses.

        ``run`` iterates the result with ``async for``, so overrides are
        expected to be async generators.

        Args:
            input: Input passed to ``run``.
            state: Shared state passed to ``run``.
            run_id: Identifier supplied by ``manage_run`` for this run.
            *args: Extra positional arguments passed to ``run``.
            **kwargs: Extra keyword arguments passed to ``run``.
        """
        pass
|