import asyncio
from typing import Any, AsyncGenerator
import pytest
from r2r import AsyncPipe, AsyncPipeline, PipeType
class MultiplierPipe(AsyncPipe):
    """Test pipe that scales every streamed item by a fixed multiplier.

    Integers are multiplied directly and lists are multiplied element-wise.
    Any other item type is rejected with ``ValueError``.
    """

    def __init__(self, multiplier=1, delay=0, name="multiplier_pipe"):
        """Configure the pipe.

        Args:
            multiplier: Factor applied to each item (or each list element).
            delay: Seconds to sleep before processing each item; values
                that are not greater than zero disable the sleep.
            name: Name passed to the pipe's ``PipeConfig``.
        """
        pipe_config = self.PipeConfig(name=name)
        super().__init__(type=PipeType.OTHER, config=pipe_config)
        self.multiplier, self.delay = multiplier, delay

    async def _run_logic(
        self,
        input: AsyncGenerator[Any, None],
        state,
        run_id=None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[Any, None]:
        """Yield each item read from ``input.message`` times the multiplier.

        Raises:
            ValueError: If an item is neither a ``list`` nor an ``int``.
        """
        async for element in input.message:
            # Simulate processing delay.
            if self.delay > 0:
                await asyncio.sleep(self.delay)

            if not isinstance(element, (list, int)):
                raise ValueError(f"Unsupported type: {type(element)}")

            if isinstance(element, list):
                yield [member * self.multiplier for member in element]
            else:
                yield element * self.multiplier
class FanOutPipe(AsyncPipe):
    """Test pipe that turns one input stream into several scaled copies.

    The whole input is collected first; then ``multiplier`` lists are
    emitted, where the k-th list (1-based) is the collected input with
    every element multiplied by k.
    """

    def __init__(self, multiplier=1, delay=0, name="fan_out_pipe"):
        """Configure the pipe.

        Args:
            multiplier: Number of scaled copies to emit.
            delay: Seconds to sleep before emitting each copy; values that
                are not greater than zero disable the sleep.
            name: Name passed to the pipe's ``PipeConfig``.
        """
        pipe_config = self.PipeConfig(name=name)
        super().__init__(type=PipeType.OTHER, config=pipe_config)
        self.multiplier, self.delay = multiplier, delay

    async def _run_logic(
        self,
        input: AsyncGenerator[Any, None],
        state,
        run_id=None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[Any, None]:
        """Drain ``input.message`` and yield one scaled list per copy."""
        # The full input must be known before any copy can be produced.
        collected = [element async for element in input.message]

        for copy_index in range(self.multiplier):
            if self.delay > 0:
                await asyncio.sleep(self.delay)
            scale = copy_index + 1
            yield [scale * element for element in collected]
class FanInPipe(AsyncPipe):
    """Test pipe that collapses a stream of batches into one grand total."""

    def __init__(self, delay=0, name="fan_in_pipe"):
        """Configure the pipe.

        Args:
            delay: Seconds to sleep before adding each batch; values that
                are not greater than zero disable the sleep.
            name: Name passed to the pipe's ``PipeConfig``.
        """
        pipe_config = self.PipeConfig(name=name)
        super().__init__(type=PipeType.OTHER, config=pipe_config)
        self.delay = delay

    async def _run_logic(
        self,
        input: AsyncGenerator[Any, None],
        state,
        run_id=None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[Any, None]:
        """Sum every batch from ``input.message`` and yield the total once."""
        running_total = 0
        async for chunk in input.message:
            # Simulate processing delay.
            if self.delay > 0:
                await asyncio.sleep(self.delay)
            # Each chunk is assumed to be an iterable of numeric values.
            chunk_sum = sum(chunk)
            running_total += chunk_sum
        # A single value is emitted only after the input is exhausted.
        yield running_total
@pytest.fixture
def pipe_factory():
    """Provide a callable that builds one of the test pipes by kind name.

    The returned ``create_pipe(type, **kwargs)`` accepts ``"multiplier"``,
    ``"fan_out"`` or ``"fan_in"`` and forwards ``kwargs`` to the matching
    pipe class. Any other kind raises ``ValueError``.
    """

    def create_pipe(type, **kwargs):
        pipe_classes = {
            "multiplier": MultiplierPipe,
            "fan_out": FanOutPipe,
            "fan_in": FanInPipe,
        }
        # Only string kinds can match; anything else is unsupported.
        pipe_cls = pipe_classes.get(type) if isinstance(type, str) else None
        if pipe_cls is None:
            raise ValueError("Unsupported pipe type")
        return pipe_cls(**kwargs)

    return create_pipe
@pytest.mark.asyncio
@pytest.mark.parametrize("multiplier, delay, name", [(2, 0.1, "pipe")])
async def test_single_multiplier(pipe_factory, multiplier, delay, name):
    """A lone multiplier pipe scales each of the inputs 1, 2, 3."""
    scaler = pipe_factory(
        "multiplier", multiplier=multiplier, delay=delay, name=name
    )

    async def input_generator():
        for value in (1, 2, 3):
            yield value

    pipeline = AsyncPipeline()
    pipeline.add_pipe(scaler)

    result = list(await pipeline.run(input_generator()))

    expected_result = [multiplier, 2 * multiplier, 3 * multiplier]
    assert (
        result == expected_result
    ), "Pipeline output did not match expected multipliers"
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b",
    [(2, 0.1, "pipe_a", 2, 0.1, "pipe_b")],
)
async def test_double_multiplier(
    pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b
):
    """Two chained multiplier pipes scale each input by both factors."""
    first_stage = pipe_factory(
        "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a
    )
    second_stage = pipe_factory(
        "multiplier", multiplier=multiplier_b, delay=delay_b, name=name_b
    )

    async def input_generator():
        for value in (1, 2, 3):
            yield value

    pipeline = AsyncPipeline()
    for stage in (first_stage, second_stage):
        pipeline.add_pipe(stage)

    result = list(await pipeline.run(input_generator()))

    expected_result = [
        value * multiplier_a * multiplier_b for value in (1, 2, 3)
    ]
    assert (
        result == expected_result
    ), "Pipeline output did not match expected multipliers"
@pytest.mark.asyncio
@pytest.mark.parametrize("multiplier, delay, name", [(3, 0.1, "pipe")])
async def test_fan_out(pipe_factory, multiplier, delay, name):
    """A fan-out pipe emits the input scaled by 1, 2, ..., ``multiplier``."""
    fan_out = pipe_factory(
        "fan_out", multiplier=multiplier, delay=delay, name=name
    )

    async def input_generator():
        for value in (1, 2, 3):
            yield value

    pipeline = AsyncPipeline()
    pipeline.add_pipe(fan_out)

    result = list(await pipeline.run(input_generator()))

    # One list per copy; copy number k holds k*1, k*2, k*3.
    expected_result = [
        [scale, 2 * scale, 3 * scale] for scale in range(1, multiplier + 1)
    ]
    assert (
        result == expected_result
    ), "Pipeline output did not match expected multipliers"
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b",
    [
        (2, 0.1, "pipe_a", 2, 0.1, "pipe_b"),
        (4, 0.1, "pipe_a", 3, 0.1, "pipe_b"),
    ],
)
async def test_multiply_then_fan_out(
    pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b
):
    """A multiplier pipe feeding a fan-out pipe yields scaled copies.

    The inputs 1, 2, 3 are first multiplied by ``multiplier_a``; the fan-out
    pipe then emits ``multiplier_b`` lists, the k-th (1-based) being the
    multiplied inputs scaled by k.

    This function previously lacked the ``test_`` prefix and was therefore
    never collected by pytest. It also iterated the pipeline result with
    ``async for`` and compared against a value the fan-out pipe cannot
    produce; all three issues are corrected here.
    """
    pipe_a = pipe_factory(
        "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a
    )
    pipe_b = pipe_factory(
        "fan_out", multiplier=multiplier_b, delay=delay_b, name=name_b
    )

    async def input_generator():
        for i in [1, 2, 3]:
            yield i

    pipeline = AsyncPipeline()
    pipeline.add_pipe(pipe_a)
    pipeline.add_pipe(pipe_b)

    # The awaited run result is iterated synchronously, as in sibling tests.
    result = list(await pipeline.run(input_generator()))

    expected_result = [
        [scale * i * multiplier_a for i in [1, 2, 3]]
        for scale in range(1, multiplier_b + 1)
    ]
    assert (
        result == expected_result
    ), "Pipeline output did not match expected multipliers"


# Keep the previous (uncollected) name resolvable for any existing reference.
multiply_then_fan_out = test_multiply_then_fan_out
@pytest.mark.asyncio
@pytest.mark.parametrize("multiplier, delay, name", [(3, 0.1, "pipe")])
async def test_fan_in_sum(pipe_factory, multiplier, delay, name):
    """Fan-out followed by fan-in yields the sum of all scaled copies."""
    # First stage: produce ``multiplier`` scaled copies of the input.
    fan_out_pipe = pipe_factory(
        "fan_out", multiplier=multiplier, delay=delay, name=name + "_a"
    )
    # Second stage: add every element of every copy into one total.
    fan_in_sum_pipe = pipe_factory("fan_in", delay=delay, name=name + "_b")

    async def input_generator():
        for value in (1, 2, 3):
            yield value

    pipeline = AsyncPipeline()
    for stage in (fan_out_pipe, fan_in_sum_pipe):
        pipeline.add_pipe(stage)

    result = await pipeline.run(input_generator())

    # Total over every copy (scale 1..multiplier) of every scaled input.
    expected_result = sum(
        value * scale
        for scale in range(1, multiplier + 1)
        for value in (1, 2, 3)
    )
    assert (
        result[0] == expected_result
    ), "Pipeline output did not match expected sums"
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b",
    [
        (3, 0.1, "pipe_a", 2, 0.1, "pipe_b"),
        (4, 0.1, "pipe_a", 3, 0.1, "pipe_b"),
    ],
)
async def test_fan_out_then_multiply(
    pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b
):
    """Three-stage pipeline: multiply, fan out, then fan in to a total.

    Note that the pipes are added in the order multiplier -> fan-out ->
    fan-in, so the multiplication actually happens before the fan-out.
    """
    multiply_stage = pipe_factory(
        "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a
    )
    fan_out_stage = pipe_factory(
        "fan_out", multiplier=multiplier_b, delay=delay_b, name=name_b
    )
    fan_in_stage = pipe_factory("fan_in", delay=0.1, name="pipe_c")

    async def input_generator():
        for value in (1, 2, 3):
            yield value

    pipeline = AsyncPipeline()
    for stage in (multiply_stage, fan_out_stage, fan_in_stage):
        pipeline.add_pipe(stage)

    result = await pipeline.run(input_generator())

    # Each fan-out copy (scale 1..multiplier_b) holds the multiplied inputs.
    expected_result = sum(
        value * scale * multiplier_a
        for scale in range(1, multiplier_b + 1)
        for value in (1, 2, 3)
    )
    assert (
        result[0] == expected_result
    ), "Pipeline output did not match expected multipliers"