about summary refs log tree commit diff
import asyncio
from typing import Any, AsyncGenerator

import pytest

from r2r import AsyncPipe, AsyncPipeline, PipeType


class MultiplierPipe(AsyncPipe):
    """Test pipe that scales every incoming item by ``multiplier``.

    Integer items are multiplied directly and list items element-wise; any
    other item type raises ``ValueError``. A positive ``delay`` makes the
    pipe ``asyncio.sleep`` that many seconds before handling each item, to
    simulate processing time.
    """

    def __init__(self, multiplier=1, delay=0, name="multiplier_pipe"):
        pipe_config = self.PipeConfig(name=name)
        super().__init__(type=PipeType.OTHER, config=pipe_config)
        self.multiplier = multiplier  # factor applied to each item/element
        self.delay = delay  # seconds slept before each item; 0 skips it

    async def _run_logic(
        self,
        input: AsyncGenerator[Any, None],
        state,
        run_id=None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[Any, None]:
        # NOTE(review): the annotation says AsyncGenerator, but items are read
        # from ``input.message``; confirm the real input type with r2r.
        async for item in input.message:
            if self.delay > 0:
                await asyncio.sleep(self.delay)
            if not isinstance(item, (list, int)):
                raise ValueError(f"Unsupported type: {type(item)}")
            if isinstance(item, list):
                yield [value * self.multiplier for value in item]
            else:
                yield item * self.multiplier


class FanOutPipe(AsyncPipe):
    """Test pipe that buffers its whole input and replays it several times.

    Once the input stream is exhausted it yields ``multiplier`` lists; the
    k-th list (1-based) holds every buffered item multiplied by ``k``. A
    positive ``delay`` is slept (in seconds) before each emitted list.
    """

    def __init__(self, multiplier=1, delay=0, name="fan_out_pipe"):
        pipe_config = self.PipeConfig(name=name)
        super().__init__(type=PipeType.OTHER, config=pipe_config)
        self.multiplier = multiplier  # how many lists are emitted
        self.delay = delay  # seconds slept before each emitted list

    async def _run_logic(
        self,
        input: AsyncGenerator[Any, None],
        state,
        run_id=None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[Any, None]:
        # NOTE(review): items are read from ``input.message`` despite the
        # AsyncGenerator annotation; confirm the real input type with r2r.
        # Drain the whole input first: every output list needs all items.
        buffered = [item async for item in input.message]
        for factor in range(1, self.multiplier + 1):
            if self.delay > 0:
                await asyncio.sleep(self.delay)
            yield [factor * value for value in buffered]


class FanInPipe(AsyncPipe):
    """Test pipe that collapses a stream of numeric batches into one total.

    Each incoming batch is reduced with ``sum`` and added to a running
    total, which is yielded exactly once after the input stream ends. A
    positive ``delay`` is slept (in seconds) before each batch.
    """

    def __init__(self, delay=0, name="fan_in_pipe"):
        pipe_config = self.PipeConfig(name=name)
        super().__init__(type=PipeType.OTHER, config=pipe_config)
        self.delay = delay  # seconds slept before each batch; 0 skips it

    async def _run_logic(
        self,
        input: AsyncGenerator[Any, None],
        state,
        run_id=None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[Any, None]:
        # NOTE(review): items are read from ``input.message`` despite the
        # AsyncGenerator annotation; confirm the real input type with r2r.
        grand_total = 0
        async for batch in input.message:
            if self.delay > 0:
                await asyncio.sleep(self.delay)
            # ``sum`` requires each batch to be an iterable of numbers.
            grand_total += sum(batch)
        yield grand_total


@pytest.fixture
def pipe_factory():
    """Return a callable that builds one of the test pipes by short name.

    ``create_pipe(type, **kwargs)`` maps ``"multiplier"``, ``"fan_out"`` and
    ``"fan_in"`` to the matching pipe class, forwards ``kwargs`` to that
    class's constructor, and raises ``ValueError`` for any other name.
    """

    def create_pipe(type, **kwargs):
        registry = (
            ("multiplier", MultiplierPipe),
            ("fan_out", FanOutPipe),
            ("fan_in", FanInPipe),
        )
        for pipe_name, pipe_cls in registry:
            if type == pipe_name:
                return pipe_cls(**kwargs)
        raise ValueError("Unsupported pipe type")

    return create_pipe


@pytest.mark.asyncio
@pytest.mark.parametrize("multiplier, delay, name", [(2, 0.1, "pipe")])
async def test_single_multiplier(pipe_factory, multiplier, delay, name):
    """A lone multiplier pipe scales each of 1, 2, 3 by ``multiplier``."""
    multiplier_pipe = pipe_factory(
        "multiplier", multiplier=multiplier, delay=delay, name=name
    )
    inputs = [1, 2, 3]

    async def input_generator():
        for value in inputs:
            yield value

    pipeline = AsyncPipeline()
    pipeline.add_pipe(multiplier_pipe)

    result = list(await pipeline.run(input_generator()))

    expected_result = [value * multiplier for value in inputs]
    assert (
        result == expected_result
    ), "Pipeline output did not match expected multipliers"


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b",
    [(2, 0.1, "pipe_a", 2, 0.1, "pipe_b")],
)
async def test_double_multiplier(
    pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b
):
    """Two chained multiplier pipes apply both factors to every input."""
    first = pipe_factory(
        "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a
    )
    second = pipe_factory(
        "multiplier", multiplier=multiplier_b, delay=delay_b, name=name_b
    )
    inputs = [1, 2, 3]

    async def input_generator():
        for value in inputs:
            yield value

    pipeline = AsyncPipeline()
    for stage in (first, second):
        pipeline.add_pipe(stage)

    result = list(await pipeline.run(input_generator()))

    expected_result = [value * multiplier_a * multiplier_b for value in inputs]
    assert (
        result == expected_result
    ), "Pipeline output did not match expected multipliers"


@pytest.mark.asyncio
@pytest.mark.parametrize("multiplier, delay, name", [(3, 0.1, "pipe")])
async def test_fan_out(pipe_factory, multiplier, delay, name):
    """A fan-out pipe emits ``multiplier`` scaled copies of the input list."""
    fan_out = pipe_factory(
        "fan_out", multiplier=multiplier, delay=delay, name=name
    )
    inputs = [1, 2, 3]

    async def input_generator():
        for value in inputs:
            yield value

    pipeline = AsyncPipeline()
    pipeline.add_pipe(fan_out)

    result = list(await pipeline.run(input_generator()))

    # The k-th emitted list (1-based) is every input multiplied by k.
    expected_result = [
        [k * value for value in inputs] for k in range(1, multiplier + 1)
    ]
    assert (
        result == expected_result
    ), "Pipeline output did not match expected multipliers"


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b",
    [
        (2, 0.1, "pipe_a", 2, 0.1, "pipe_b"),
        (4, 0.1, "pipe_a", 3, 0.1, "pipe_b"),
    ],
)
async def test_multiply_then_fan_out(
    pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b
):
    """A multiplier pipe feeding a fan-out pipe.

    The multiplier maps the inputs 1, 2, 3 to ``i * multiplier_a``. The
    fan-out pipe buffers those values and emits ``multiplier_b`` lists; the
    k-th list (1-based) holds every buffered value times ``k``.
    """
    # The ``test_`` prefix is required for pytest to collect this function.
    # Under its old name (``multiply_then_fan_out``) it never ran. That is
    # how its inconsistent ``async for`` loop and its wrong expected value
    # ([[2], [4], [6]] instead of [2, 4, 6]) went unnoticed.
    pipe_a = pipe_factory(
        "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a
    )
    pipe_b = pipe_factory(
        "fan_out", multiplier=multiplier_b, delay=delay_b, name=name_b
    )

    inputs = [1, 2, 3]

    async def input_generator():
        for i in inputs:
            yield i

    pipeline = AsyncPipeline()
    pipeline.add_pipe(pipe_a)
    pipeline.add_pipe(pipe_b)

    # Iterate the awaited result synchronously, as every other test here does.
    result = []
    for output in await pipeline.run(input_generator()):
        result.append(output)

    # Check every fan-out list, not only the first one.
    multiplied = [i * multiplier_a for i in inputs]
    expected_result = [
        [k * value for value in multiplied]
        for k in range(1, multiplier_b + 1)
    ]
    assert (
        result == expected_result
    ), "Pipeline output did not match expected multipliers"


# Backward-compatible alias for the old name. pytest does not collect it
# because it lacks the ``test_`` prefix, so the test still runs only once.
multiply_then_fan_out = test_multiply_then_fan_out


@pytest.mark.asyncio
@pytest.mark.parametrize("multiplier, delay, name", [(3, 0.1, "pipe")])
async def test_fan_in_sum(pipe_factory, multiplier, delay, name):
    """Fan-out into fan-in yields the sum of all fanned-out values."""
    # Fan-out produces ``multiplier`` scaled copies of the inputs...
    fan_out_pipe = pipe_factory(
        "fan_out", multiplier=multiplier, delay=delay, name=name + "_a"
    )
    # ...and the fan-in pipe sums every copy into a single total.
    fan_in_sum_pipe = pipe_factory("fan_in", delay=delay, name=name + "_b")
    inputs = [1, 2, 3]

    async def input_generator():
        for value in inputs:
            yield value

    pipeline = AsyncPipeline()
    for stage in (fan_out_pipe, fan_in_sum_pipe):
        pipeline.add_pipe(stage)

    result = await pipeline.run(input_generator())

    expected_result = sum(
        factor * value
        for factor in range(1, multiplier + 1)
        for value in inputs
    )
    assert (
        result[0] == expected_result
    ), "Pipeline output did not match expected sums"


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b",
    [
        (3, 0.1, "pipe_a", 2, 0.1, "pipe_b"),
        (4, 0.1, "pipe_a", 3, 0.1, "pipe_b"),
    ],
)
async def test_fan_out_then_multiply(
    pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b
):
    """Three-stage pipeline: multiply, then fan out, then fan in (sum).

    Despite the test's name, the multiplier runs first. Its outputs are
    fanned out ``multiplier_b`` times, and the fan-in pipe sums everything.
    """
    multiply_stage = pipe_factory(
        "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a
    )
    fan_out_stage = pipe_factory(
        "fan_out", multiplier=multiplier_b, delay=delay_b, name=name_b
    )
    fan_in_stage = pipe_factory("fan_in", delay=0.1, name="pipe_c")
    inputs = [1, 2, 3]

    async def input_generator():
        for value in inputs:
            yield value

    pipeline = AsyncPipeline()
    for stage in (multiply_stage, fan_out_stage, fan_in_stage):
        pipeline.add_pipe(stage)

    result = await pipeline.run(input_generator())

    expected_result = sum(
        value * factor * multiplier_a
        for factor in range(1, multiplier_b + 1)
        for value in inputs
    )
    assert (
        result[0] == expected_result
    ), "Pipeline output did not match expected multipliers"