about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/aiostream/stream/transform.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiostream/stream/transform.py')
-rw-r--r--.venv/lib/python3.12/site-packages/aiostream/stream/transform.py128
1 files changed, 128 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiostream/stream/transform.py b/.venv/lib/python3.12/site-packages/aiostream/stream/transform.py
new file mode 100644
index 00000000..f11bffa6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/aiostream/stream/transform.py
@@ -0,0 +1,128 @@
+"""Transformation operators."""
+
+from __future__ import annotations
+
+import asyncio
+import itertools
+from typing import (
+    Protocol,
+    TypeVar,
+    AsyncIterable,
+    AsyncIterator,
+    Awaitable,
+    cast,
+)
+
+from ..core import streamcontext, pipable_operator
+
+from . import select
+from . import create
+from . import aggregate
+from .combine import map, amap, smap
+
+__all__ = ["map", "enumerate", "starmap", "cycle", "chunks"]
+
+# map, amap and smap are also transform operators
+map, amap, smap
+
+T = TypeVar("T")
+U = TypeVar("U")
+
+
+@pipable_operator
+async def enumerate(
+    source: AsyncIterable[T], start: int = 0, step: int = 1
+) -> AsyncIterator[tuple[int, T]]:
+    """Generate ``(index, value)`` tuples from an asynchronous sequence.
+
+    This index is computed using a starting point and an increment,
+    respectively defaulting to ``0`` and ``1``.
+    """
+    count = itertools.count(start, step)
+    async with streamcontext(source) as streamer:
+        async for item in streamer:
+            yield next(count), item
+
+
+X = TypeVar("X", contravariant=True)
+Y = TypeVar("Y", covariant=True)
+
+
+class AsyncStarmapCallable(Protocol[X, Y]):
+    def __call__(self, arg: X, /, *args: X) -> Awaitable[Y]:
+        ...
+
+
+class SyncStarmapCallable(Protocol[X, Y]):
+    def __call__(self, arg: X, /, *args: X) -> Y:
+        ...
+
+
+@pipable_operator
+def starmap(
+    source: AsyncIterable[tuple[T, ...]],
+    func: SyncStarmapCallable[T, U] | AsyncStarmapCallable[T, U],
+    ordered: bool = True,
+    task_limit: int | None = None,
+) -> AsyncIterator[U]:
+    """Apply a given function to the unpacked elements of
+    an asynchronous sequence.
+
+    Each element is unpacked before applying the function.
+    The given function can either be synchronous or asynchronous.
+
+    The results can either be returned in or out of order, depending on
+    the corresponding ``ordered`` argument. This argument is ignored if
+    the provided function is synchronous.
+
+    The coroutines run concurrently but their amount can be limited using
+    the ``task_limit`` argument. A value of ``1`` will cause the coroutines
+    to run sequentially. This argument is ignored if the provided function
+    is synchronous.
+    """
+    if asyncio.iscoroutinefunction(func):
+        async_func = cast("AsyncStarmapCallable[T, U]", func)
+
+        async def astarfunc(args: tuple[T, ...], *_: object) -> U:
+            awaitable = async_func(*args)
+            return await awaitable
+
+        return amap.raw(source, astarfunc, ordered=ordered, task_limit=task_limit)
+
+    else:
+        sync_func = cast("SyncStarmapCallable[T, U]", func)
+
+        def starfunc(args: tuple[T, ...], *_: object) -> U:
+            return sync_func(*args)
+
+        return smap.raw(source, starfunc)
+
+
+@pipable_operator
+async def cycle(source: AsyncIterable[T]) -> AsyncIterator[T]:
+    """Iterate indefinitely over an asynchronous sequence.
+
+    Note: it does not perform any buffering, but re-iterate over
+    the same given sequence instead. If the sequence is not
+    re-iterable, the generator might end up looping indefinitely
+    without yielding any item.
+    """
+    while True:
+        async with streamcontext(source) as streamer:
+            async for item in streamer:
+                yield item
+            # Prevent blocking while loop if the stream is empty
+            await asyncio.sleep(0)
+
+
+@pipable_operator
+async def chunks(source: AsyncIterable[T], n: int) -> AsyncIterator[list[T]]:
+    """Generate chunks of size ``n`` from an asynchronous sequence.
+
+    The chunks are lists, and the last chunk might contain less than ``n``
+    elements.
+    """
+    async with streamcontext(source) as streamer:
+        async for first in streamer:
+            xs = select.take(create.preserve(streamer), n - 1)
+            yield [first] + await aggregate.list(xs)