Diffstat (limited to '.venv/lib/python3.12/site-packages/aiostream/stream/combine.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/aiostream/stream/combine.py  282
1 file changed, 282 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiostream/stream/combine.py b/.venv/lib/python3.12/site-packages/aiostream/stream/combine.py
new file mode 100644
index 00000000..a782a730
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/aiostream/stream/combine.py
@@ -0,0 +1,282 @@
+"""Combination operators."""
+from __future__ import annotations
+
+import asyncio
+import builtins
+
+from typing import (
+ Awaitable,
+ Protocol,
+ TypeVar,
+ AsyncIterable,
+ AsyncIterator,
+ Callable,
+ cast,
+)
+from typing_extensions import ParamSpec
+
+from ..aiter_utils import AsyncExitStack, anext
+from ..core import streamcontext, pipable_operator
+
+from . import create
+from . import select
+from . import advanced
+from . import aggregate
+
+__all__ = ["chain", "zip", "map", "merge", "ziplatest", "amap", "smap"]
+
+T = TypeVar("T")
+U = TypeVar("U")
+K = TypeVar("K")
+P = ParamSpec("P")
+
+
+@pipable_operator
+async def chain(
+ source: AsyncIterable[T], *more_sources: AsyncIterable[T]
+) -> AsyncIterator[T]:
+ """Chain asynchronous sequences together, in the order they are given.
+
+    Note: the sequences are not iterated until required, so if the
+    operation is interrupted, the remaining sequences will be left
+    untouched.
+ """
+ sources = source, *more_sources
+ for source in sources:
+ async with streamcontext(source) as streamer:
+ async for item in streamer:
+ yield item
+
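+# Illustrative usage (a minimal sketch, not part of this module): ``chain``
+# exhausts the first source before moving on to the next. The demo name
+# ``demo_chain`` and the literal ranges are hypothetical.
+#
+#     import asyncio
+#     from aiostream import stream
+#
+#     async def demo_chain() -> None:
+#         xs = stream.range(3)          # 0, 1, 2
+#         ys = stream.range(10, 13)     # 10, 11, 12
+#         print(await stream.list(stream.chain(xs, ys)))
+#         # [0, 1, 2, 10, 11, 12]
+#
+#     asyncio.run(demo_chain())
+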
+
+@pipable_operator
+async def zip(
+ source: AsyncIterable[T], *more_sources: AsyncIterable[T]
+) -> AsyncIterator[tuple[T, ...]]:
+ """Combine and forward the elements of several asynchronous sequences.
+
+ Each generated value is a tuple of elements, using the same order as
+ their respective sources. The generation continues until the shortest
+ sequence is exhausted.
+
+    Note: the different sequences are awaited in parallel, so that their
+ waiting times don't add up.
+ """
+ sources = source, *more_sources
+
+    # One source
+ if len(sources) == 1:
+ (source,) = sources
+ async with streamcontext(source) as streamer:
+ async for item in streamer:
+ yield (item,)
+ return
+
+ # N sources
+ async with AsyncExitStack() as stack:
+ # Handle resources
+ streamers = [
+ await stack.enter_async_context(streamcontext(source)) for source in sources
+ ]
+ # Loop over items
+ while True:
+ try:
+ coros = builtins.map(anext, streamers)
+ items = await asyncio.gather(*coros)
+ except StopAsyncIteration:
+ break
+ else:
+ yield tuple(items)
+
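+# Illustrative usage (a minimal sketch, not part of this module): ``zip``
+# stops as soon as the shortest source is exhausted. The demo name is
+# hypothetical.
+#
+#     import asyncio
+#     from aiostream import stream
+#
+#     async def demo_zip() -> None:
+#         xs = stream.range(3)           # 0, 1, 2
+#         ys = stream.range(100, 102)    # 100, 101
+#         print(await stream.list(stream.zip(xs, ys)))
+#         # [(0, 100), (1, 101)]
+#
+#     asyncio.run(demo_zip())
+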
+
+X = TypeVar("X", contravariant=True)
+Y = TypeVar("Y", covariant=True)
+
+
+class SmapCallable(Protocol[X, Y]):
+ def __call__(self, arg: X, /, *args: X) -> Y:
+ ...
+
+
+class AmapCallable(Protocol[X, Y]):
+ async def __call__(self, arg: X, /, *args: X) -> Y:
+ ...
+
+
+class MapCallable(Protocol[X, Y]):
+ def __call__(self, arg: X, /, *args: X) -> Awaitable[Y] | Y:
+ ...
+
+
+@pipable_operator
+async def smap(
+ source: AsyncIterable[T],
+ func: SmapCallable[T, U],
+ *more_sources: AsyncIterable[T],
+) -> AsyncIterator[U]:
+ """Apply a given function to the elements of one or several
+ asynchronous sequences.
+
+ Each element is used as a positional argument, using the same order as
+ their respective sources. The generation continues until the shortest
+    sequence is exhausted. The function is called synchronously.
+
+ Note: if more than one sequence is provided, they're awaited concurrently
+ so that their waiting times don't add up.
+ """
+ stream = zip(source, *more_sources)
+ async with streamcontext(stream) as streamer:
+ async for item in streamer:
+ yield func(*item)
+
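+# Illustrative usage (a minimal sketch, not part of this module): with two
+# sources, the synchronous callback receives one positional argument per
+# source, taken from the zipped items. The demo name is hypothetical.
+#
+#     import asyncio
+#     from aiostream import stream
+#
+#     async def demo_smap() -> None:
+#         xs = stream.range(3)           # 0, 1, 2
+#         ys = stream.range(10, 13)      # 10, 11, 12
+#         zs = stream.smap(xs, lambda x, y: x + y, ys)
+#         print(await stream.list(zs))   # [10, 12, 14]
+#
+#     asyncio.run(demo_smap())
+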
+
+@pipable_operator
+def amap(
+ source: AsyncIterable[T],
+ corofn: AmapCallable[T, U],
+ *more_sources: AsyncIterable[T],
+ ordered: bool = True,
+ task_limit: int | None = None,
+) -> AsyncIterator[U]:
+ """Apply a given coroutine function to the elements of one or several
+ asynchronous sequences.
+
+ Each element is used as a positional argument, using the same order as
+ their respective sources. The generation continues until the shortest
+ sequence is exhausted.
+
+ The results can either be returned in or out of order, depending on
+ the corresponding ``ordered`` argument.
+
+    The coroutines run concurrently, but their number can be limited using
+ the ``task_limit`` argument. A value of ``1`` will cause the coroutines
+ to run sequentially.
+
+ If more than one sequence is provided, they're also awaited concurrently,
+ so that their waiting times don't add up.
+ """
+
+ async def func(arg: T, *args: T) -> AsyncIterable[U]:
+ yield await corofn(arg, *args)
+
+ if ordered:
+ return advanced.concatmap.raw(
+ source, func, *more_sources, task_limit=task_limit
+ )
+ return advanced.flatmap.raw(source, func, *more_sources, task_limit=task_limit)
+
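+# Illustrative usage (a minimal sketch, not part of this module): with
+# ``ordered=False`` results are yielded as they complete, and ``task_limit``
+# caps the number of coroutines running at once. The demo names are
+# hypothetical.
+#
+#     import asyncio
+#     from aiostream import stream
+#
+#     async def slow_double(x: int) -> int:
+#         await asyncio.sleep(0.01 * x)
+#         return 2 * x
+#
+#     async def demo_amap() -> None:
+#         xs = stream.range(5)
+#         ys = stream.amap(xs, slow_double, ordered=False, task_limit=2)
+#         print(await stream.list(ys))  # order depends on completion times
+#
+#     asyncio.run(demo_amap())
+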
+
+@pipable_operator
+def map(
+ source: AsyncIterable[T],
+ func: MapCallable[T, U],
+ *more_sources: AsyncIterable[T],
+ ordered: bool = True,
+ task_limit: int | None = None,
+) -> AsyncIterator[U]:
+ """Apply a given function to the elements of one or several
+ asynchronous sequences.
+
+ Each element is used as a positional argument, using the same order as
+ their respective sources. The generation continues until the shortest
+ sequence is exhausted. The function can either be synchronous or
+ asynchronous (coroutine function).
+
+ The results can either be returned in or out of order, depending on
+ the corresponding ``ordered`` argument. This argument is ignored if the
+ provided function is synchronous.
+
+    The coroutines run concurrently, but their number can be limited using
+ the ``task_limit`` argument. A value of ``1`` will cause the coroutines
+ to run sequentially. This argument is ignored if the provided function
+ is synchronous.
+
+ If more than one sequence is provided, they're also awaited concurrently,
+ so that their waiting times don't add up.
+
+ It might happen that the provided function returns a coroutine but is not
+ a coroutine function per se. In this case, one can wrap the function with
+ ``aiostream.async_`` in order to force ``map`` to await the resulting
+    coroutine. The following example illustrates the use of ``async_`` with a
+ lambda function::
+
+ from aiostream import stream, async_
+ ...
+ ys = stream.map(xs, async_(lambda ms: asyncio.sleep(ms / 1000)))
+ """
+ if asyncio.iscoroutinefunction(func):
+ return amap.raw(
+ source, func, *more_sources, ordered=ordered, task_limit=task_limit
+ )
+ sync_func = cast("SmapCallable[T, U]", func)
+ return smap.raw(source, sync_func, *more_sources)
+
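+# Illustrative usage (a minimal sketch, not part of this module): ``map``
+# dispatches to ``smap`` for a plain function and to ``amap`` for a
+# coroutine function. It is also commonly used in pipe form; the demo name
+# is hypothetical.
+#
+#     import asyncio
+#     from aiostream import stream, pipe
+#
+#     async def demo_map() -> None:
+#         xs = stream.range(4)
+#         ys = xs | pipe.map(lambda x: x * x)  # synchronous callback
+#         print(await stream.list(ys))         # [0, 1, 4, 9]
+#
+#     asyncio.run(demo_map())
+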
+
+@pipable_operator
+def merge(
+ source: AsyncIterable[T], *more_sources: AsyncIterable[T]
+) -> AsyncIterator[T]:
+ """Merge several asynchronous sequences together.
+
+ All the sequences are iterated simultaneously and their elements
+ are forwarded as soon as they're available. The generation continues
+ until all the sequences are exhausted.
+ """
+ sources = [source, *more_sources]
+ source_stream: AsyncIterable[AsyncIterable[T]] = create.iterate.raw(sources)
+ return advanced.flatten.raw(source_stream)
+
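+# Illustrative usage (a minimal sketch, not part of this module): ``merge``
+# forwards items from all sources as soon as they arrive, so a slow source
+# does not block a fast one. The demo name and intervals are hypothetical.
+#
+#     import asyncio
+#     from aiostream import stream
+#
+#     async def demo_merge() -> None:
+#         fast = stream.range(3, interval=0.01)
+#         slow = stream.range(10, 13, interval=0.03)
+#         async with stream.merge(fast, slow).stream() as s:
+#             async for item in s:
+#                 print(item)  # items interleave by arrival time
+#
+#     asyncio.run(demo_merge())
+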
+
+@pipable_operator
+def ziplatest(
+ source: AsyncIterable[T],
+ *more_sources: AsyncIterable[T],
+ partial: bool = True,
+ default: T | None = None,
+) -> AsyncIterator[tuple[T | None, ...]]:
+ """Combine several asynchronous sequences together, producing a tuple with
+    the latest element of each sequence whenever a new element is received.
+
+    The value to use when a sequence has not produced any element yet is given
+ by the ``default`` argument (defaulting to ``None``).
+
+    The production of partial results can be disabled by setting the optional
+ argument ``partial`` to ``False``.
+
+ All the sequences are iterated simultaneously and their elements
+ are forwarded as soon as they're available. The generation continues
+ until all the sequences are exhausted.
+ """
+ sources = source, *more_sources
+ n = len(sources)
+
+ # Custom getter
+ def getter(dct: dict[int, T]) -> Callable[[int], T | None]:
+ return lambda key: dct.get(key, default)
+
+ # Add source index to the items
+ def make_func(i: int) -> SmapCallable[T, dict[int, T]]:
+ def func(x: T, *_: object) -> dict[int, T]:
+ return {i: x}
+
+ return func
+
+ new_sources = [smap.raw(source, make_func(i)) for i, source in enumerate(sources)]
+
+ # Merge the sources
+ merged = merge.raw(*new_sources)
+
+ # Accumulate the current state in a dict
+ accumulated = aggregate.accumulate.raw(merged, lambda x, e: {**x, **e})
+
+ # Filter partial result
+ filtered = (
+ accumulated
+ if partial
+ else select.filter.raw(accumulated, lambda x: len(x) == n)
+ )
+
+ # Convert the state dict to a tuple
+ def dict_to_tuple(x: dict[int, T], *_: object) -> tuple[T | None, ...]:
+ return tuple(builtins.map(getter(x), range(n)))
+
+ return smap.raw(filtered, dict_to_tuple)
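+
+# Illustrative usage (a minimal sketch, not part of this module):
+# ``ziplatest`` emits a new tuple whenever any source produces an item,
+# repeating the latest value from the other sources; with ``partial=True``
+# (the default), missing values are filled with ``default``. The demo name
+# and intervals are hypothetical, and the exact output depends on timing.
+#
+#     import asyncio
+#     from aiostream import stream
+#
+#     async def demo_ziplatest() -> None:
+#         xs = stream.range(0, 3, interval=0.02)
+#         ys = stream.range(10, 12, interval=0.05)
+#         zs = stream.ziplatest(xs, ys)
+#         print(await stream.list(zs))
+#         # e.g. [(0, None), (0, 10), (1, 10), (2, 10), (2, 11)]
+#
+#     asyncio.run(demo_ziplatest())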