about summary refs log tree commit diff
path: root/R2R/tests/test_pipeline.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/tests/test_pipeline.py')
-rwxr-xr-xR2R/tests/test_pipeline.py291
1 files changed, 291 insertions, 0 deletions
diff --git a/R2R/tests/test_pipeline.py b/R2R/tests/test_pipeline.py
new file mode 100755
index 00000000..1811de0f
--- /dev/null
+++ b/R2R/tests/test_pipeline.py
@@ -0,0 +1,291 @@
+import asyncio
+from typing import Any, AsyncGenerator
+
+import pytest
+
+from r2r import AsyncPipe, AsyncPipeline, PipeType
+
+
+class MultiplierPipe(AsyncPipe):
+    def __init__(self, multiplier=1, delay=0, name="multiplier_pipe"):
+        super().__init__(
+            type=PipeType.OTHER,
+            config=self.PipeConfig(name=name),
+        )
+        self.multiplier = multiplier
+        self.delay = delay
+
+    async def _run_logic(
+        self,
+        input: AsyncGenerator[Any, None],
+        state,
+        run_id=None,
+        *args,
+        **kwargs,
+    ) -> AsyncGenerator[Any, None]:
+        async for item in input.message:
+            if self.delay > 0:
+                await asyncio.sleep(self.delay)  # Simulate processing delay
+            if isinstance(item, list):
+                processed = [x * self.multiplier for x in item]
+            elif isinstance(item, int):
+                processed = item * self.multiplier
+            else:
+                raise ValueError(f"Unsupported type: {type(item)}")
+            yield processed
+
+
+class FanOutPipe(AsyncPipe):
+    def __init__(self, multiplier=1, delay=0, name="fan_out_pipe"):
+        super().__init__(
+            type=PipeType.OTHER,
+            config=self.PipeConfig(name=name),
+        )
+        self.multiplier = multiplier
+        self.delay = delay
+
+    async def _run_logic(
+        self,
+        input: AsyncGenerator[Any, None],
+        state,
+        run_id=None,
+        *args,
+        **kwargs,
+    ) -> AsyncGenerator[Any, None]:
+        inputs = []
+        async for item in input.message:
+            inputs.append(item)
+        for it in range(self.multiplier):
+            if self.delay > 0:
+                await asyncio.sleep(self.delay)
+            yield [(it + 1) * ele for ele in inputs]
+
+
+class FanInPipe(AsyncPipe):
+    def __init__(self, delay=0, name="fan_in_pipe"):
+        super().__init__(
+            type=PipeType.OTHER,
+            config=self.PipeConfig(name=name),
+        )
+        self.delay = delay
+
+    async def _run_logic(
+        self,
+        input: AsyncGenerator[Any, None],
+        state,
+        run_id=None,
+        *args,
+        **kwargs,
+    ) -> AsyncGenerator[Any, None]:
+        total_sum = 0
+        async for batch in input.message:
+            if self.delay > 0:
+                await asyncio.sleep(self.delay)  # Simulate processing delay
+            total_sum += sum(
+                batch
+            )  # Assuming batch is iterable and contains numeric values
+        yield total_sum
+
+
+@pytest.fixture
+def pipe_factory():
+    def create_pipe(type, **kwargs):
+        if type == "multiplier":
+            return MultiplierPipe(**kwargs)
+        elif type == "fan_out":
+            return FanOutPipe(**kwargs)
+        elif type == "fan_in":
+            return FanInPipe(**kwargs)
+        else:
+            raise ValueError("Unsupported pipe type")
+
+    return create_pipe
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("multiplier, delay, name", [(2, 0.1, "pipe")])
+async def test_single_multiplier(pipe_factory, multiplier, delay, name):
+    pipe = pipe_factory(
+        "multiplier", multiplier=multiplier, delay=delay, name=name
+    )
+
+    async def input_generator():
+        for i in [1, 2, 3]:
+            yield i
+
+    pipeline = AsyncPipeline()
+    pipeline.add_pipe(pipe)
+
+    result = []
+    for output in await pipeline.run(input_generator()):
+        result.append(output)
+
+    expected_result = [i * multiplier for i in [1, 2, 3]]
+    assert (
+        result == expected_result
+    ), "Pipeline output did not match expected multipliers"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b",
+    [(2, 0.1, "pipe_a", 2, 0.1, "pipe_b")],
+)
+async def test_double_multiplier(
+    pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b
+):
+    pipe_a = pipe_factory(
+        "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a
+    )
+    pipe_b = pipe_factory(
+        "multiplier", multiplier=multiplier_b, delay=delay_b, name=name_b
+    )
+
+    async def input_generator():
+        for i in [1, 2, 3]:
+            yield i
+
+    pipeline = AsyncPipeline()
+    pipeline.add_pipe(pipe_a)
+    pipeline.add_pipe(pipe_b)
+
+    result = []
+    for output in await pipeline.run(input_generator()):
+        result.append(output)
+
+    expected_result = [i * multiplier_a * multiplier_b for i in [1, 2, 3]]
+    assert (
+        result == expected_result
+    ), "Pipeline output did not match expected multipliers"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("multiplier, delay, name", [(3, 0.1, "pipe")])
+async def test_fan_out(pipe_factory, multiplier, delay, name):
+    pipe = pipe_factory(
+        "fan_out", multiplier=multiplier, delay=delay, name=name
+    )
+
+    async def input_generator():
+        for i in [1, 2, 3]:
+            yield i
+
+    pipeline = AsyncPipeline()
+    pipeline.add_pipe(pipe)
+
+    result = []
+    for output in await pipeline.run(input_generator()):
+        result.append(output)
+
+    expected_result = [
+        [i + 1, 2 * (i + 1), 3 * (i + 1)] for i in range(multiplier)
+    ]
+    assert (
+        result == expected_result
+    ), "Pipeline output did not match expected multipliers"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b",
+    [
+        (2, 0.1, "pipe_a", 2, 0.1, "pipe_b"),
+        (4, 0.1, "pipe_a", 3, 0.1, "pipe_b"),
+    ],
+)
+async def multiply_then_fan_out(
+    pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b
+):
+    pipe_a = pipe_factory(
+        "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a
+    )
+    pipe_b = pipe_factory(
+        "fan_out", multiplier=multiplier_b, delay=delay_b, name=name_b
+    )
+
+    async def input_generator():
+        for i in [1, 2, 3]:
+            yield i
+
+    pipeline = AsyncPipeline()
+    pipeline.add_pipe(pipe_a)
+    pipeline.add_pipe(pipe_b)
+
+    result = []
+    async for output in await pipeline.run(input_generator()):
+        result.append(output)
+
+    expected_result = [[i * multiplier_a] async for i in input_generator()]
+    assert (
+        result[0] == expected_result
+    ), "Pipeline output did not match expected multipliers"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("multiplier, delay, name", [(3, 0.1, "pipe")])
+async def test_fan_in_sum(pipe_factory, multiplier, delay, name):
+    # Create fan-out to generate multiple streams
+    fan_out_pipe = pipe_factory(
+        "fan_out", multiplier=multiplier, delay=delay, name=name + "_a"
+    )
+    # Summing fan-in pipe
+    fan_in_sum_pipe = pipe_factory("fan_in", delay=delay, name=name + "_b")
+
+    async def input_generator():
+        for i in [1, 2, 3]:
+            yield i
+
+    pipeline = AsyncPipeline()
+    pipeline.add_pipe(fan_out_pipe)
+    pipeline.add_pipe(fan_in_sum_pipe)
+
+    result = await pipeline.run(input_generator())
+
+    # Calculate expected results based on the multiplier and the sum of inputs
+    expected_result = sum(
+        [sum([j * i for j in [1, 2, 3]]) for i in range(1, multiplier + 1)]
+    )
+    assert (
+        result[0] == expected_result
+    ), "Pipeline output did not match expected sums"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b",
+    [
+        (3, 0.1, "pipe_a", 2, 0.1, "pipe_b"),
+        (4, 0.1, "pipe_a", 3, 0.1, "pipe_b"),
+    ],
+)
+async def test_fan_out_then_multiply(
+    pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b
+):
+    pipe_a = pipe_factory(
+        "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a
+    )
+    pipe_b = pipe_factory(
+        "fan_out", multiplier=multiplier_b, delay=delay_b, name=name_b
+    )
+    pipe_c = pipe_factory("fan_in", delay=0.1, name="pipe_c")
+
+    async def input_generator():
+        for i in [1, 2, 3]:
+            yield i
+
+    pipeline = AsyncPipeline()
+    pipeline.add_pipe(pipe_a)
+    pipeline.add_pipe(pipe_b)
+    pipeline.add_pipe(pipe_c)
+
+    result = await pipeline.run(input_generator())
+
+    expected_result = sum(
+        [
+            sum([j * i * multiplier_a for j in [1, 2, 3]])
+            for i in range(1, multiplier_b + 1)
+        ]
+    )
+    assert (
+        result[0] == expected_result
+    ), "Pipeline output did not match expected multipliers"