# (Repository web-viewer navigation text was here; it is not part of the module.)
"""Transformation operators."""

from __future__ import annotations

import asyncio
import itertools
from typing import (
    Protocol,
    TypeVar,
    AsyncIterable,
    AsyncIterator,
    Awaitable,
    cast,
)

from ..core import streamcontext, pipable_operator

from . import select
from . import create
from . import aggregate
from .combine import map, amap, smap

# Names exported by ``from <this module> import *``. ``map`` is defined in
# ``.combine`` and re-exported here.
__all__ = ["map", "enumerate", "starmap", "cycle", "chunks"]

# map, amap and smap are also transform operators
# NOTE(review): the bare expression below only references the imported names;
# presumably it keeps linters from reporting them as unused -- confirm.
map, amap, smap

# Generic type variables used in the operator signatures below:
# ``T`` is the item type of the source sequence, ``U`` the result type of the
# function given to ``starmap``.
T = TypeVar("T")
U = TypeVar("U")


@pipable_operator
async def enumerate(
    source: AsyncIterable[T], start: int = 0, step: int = 1
) -> AsyncIterator[tuple[int, T]]:
    """Pair every item of an asynchronous sequence with a running index.

    Yields ``(index, value)`` tuples. The first index is ``start`` and each
    following index is the previous one plus ``step``; they default to
    ``0`` and ``1`` respectively.
    """
    # The counter is created before the source is entered, exactly once.
    indices = itertools.count(start, step)
    async with streamcontext(source) as streamer:
        async for value in streamer:
            # Draw the index only when an item has actually been produced.
            index = next(indices)
            yield index, value


# Variance-annotated type variables for the callable protocols:
# ``X`` is the (contravariant) argument type, ``Y`` the (covariant) result type.
X = TypeVar("X", contravariant=True)
Y = TypeVar("Y", covariant=True)


class AsyncStarmapCallable(Protocol[X, Y]):
    """Structural type of an asynchronous function accepted by ``starmap``.

    The callable takes one or more positional arguments of type ``X`` and
    returns an awaitable that resolves to a ``Y``.
    """

    def __call__(self, arg: X, /, *args: X) -> Awaitable[Y]:
        """Call with the unpacked arguments; return an awaitable result."""


class SyncStarmapCallable(Protocol[X, Y]):
    """Structural type of a synchronous function accepted by ``starmap``.

    The callable takes one or more positional arguments of type ``X`` and
    directly returns a ``Y``.
    """

    def __call__(self, arg: X, /, *args: X) -> Y:
        """Call with the unpacked arguments; return the result."""


@pipable_operator
def starmap(
    source: AsyncIterable[tuple[T, ...]],
    func: SyncStarmapCallable[T, U] | AsyncStarmapCallable[T, U],
    ordered: bool = True,
    task_limit: int | None = None,
) -> AsyncIterator[U]:
    """Map a function over an asynchronous sequence of argument tuples.

    Every element of the source is unpacked and passed to ``func`` as
    positional arguments. ``func`` may be a regular function or a
    coroutine function.

    When ``func`` is a coroutine function, ``ordered`` selects whether the
    results are produced in or out of order, and ``task_limit`` bounds the
    number of coroutines running concurrently (``1`` makes them run
    sequentially). Both arguments are ignored for a synchronous ``func``.
    """
    # Synchronous case first: delegate to the synchronous mapping operator.
    if not asyncio.iscoroutinefunction(func):
        plain_func = cast("SyncStarmapCallable[T, U]", func)

        # Extra positional arguments, if the mapping operator supplies any,
        # are ignored.
        def call_unpacked(args: tuple[T, ...], *_: object) -> U:
            return plain_func(*args)

        return smap.raw(source, call_unpacked)

    # Asynchronous case: delegate to the asynchronous mapping operator,
    # forwarding the ordering and concurrency settings.
    coro_func = cast("AsyncStarmapCallable[T, U]", func)

    async def await_unpacked(args: tuple[T, ...], *_: object) -> U:
        return await coro_func(*args)

    return amap.raw(source, await_unpacked, ordered=ordered, task_limit=task_limit)


@pipable_operator
async def cycle(source: AsyncIterable[T]) -> AsyncIterator[T]:
    """Repeat an asynchronous sequence forever.

    Note: nothing is buffered; the given sequence is simply iterated
    again on every pass. If it cannot be iterated more than once, the
    generator might keep looping without ever yielding an item.
    """
    while True:
        # Each pass opens a fresh stream context on the same source.
        async with streamcontext(source) as stream:
            async for element in stream:
                yield element
            # Hand control back to the event loop once per pass so that an
            # empty source does not turn this into a blocking loop.
            await asyncio.sleep(0)


@pipable_operator
async def chunks(source: AsyncIterable[T], n: int) -> AsyncIterator[list[T]]:
    """Group the items of an asynchronous sequence into lists of ``n``.

    Every chunk is a list; the final one might hold fewer than ``n``
    elements.
    """
    async with streamcontext(source) as streamer:
        async for head in streamer:
            # ``head`` is already consumed, so gather at most ``n - 1`` more
            # items from the same streamer to complete the chunk.
            # NOTE(review): ``create.preserve`` presumably keeps the streamer
            # usable after the sub-stream ends -- confirm.
            remainder = select.take(create.preserve(streamer), n - 1)
            tail = await aggregate.list(remainder)
            yield [head] + tail