"""Advanced operators (to deal with streams of higher order) ."""
from __future__ import annotations

from typing import AsyncIterator, AsyncIterable, TypeVar, Union, cast
from typing_extensions import ParamSpec

from . import combine

from ..core import Streamer, pipable_operator
from ..manager import StreamerManager


__all__ = ["concat", "flatten", "switch", "concatmap", "flatmap", "switchmap"]


T = TypeVar("T")
U = TypeVar("U")
P = ParamSpec("P")


# Helper to manage streams of higher order


@pipable_operator
async def base_combine(
    source: AsyncIterable[AsyncIterable[T]],
    switch: bool = False,
    ordered: bool = False,
    task_limit: int | None = None,
) -> AsyncIterator[T]:
    """Base operator for managing an asynchronous sequence of sequences.

    The sequences are awaited concurrently, although it's possible to limit
    the amount of running sequences using the `task_limit` argument.

    The ``switch`` argument enables the switch mecanism, which cause the
    previous subsequence to be discarded when a new one is created.

    The items can either be generated in order or as soon as they're received,
    depending on the ``ordered`` argument.
    """

    # Task limit
    if task_limit is not None and task_limit <= 0:
        raise ValueError("The task limit must be None or greater than 0")

    # Safe context
    async with StreamerManager[Union[AsyncIterable[T], T]]() as manager:
        main_streamer: Streamer[
            AsyncIterable[T] | T
        ] | None = await manager.enter_and_create_task(source)

        # Loop over events
        while manager.tasks:
            # Extract streamer groups
            substreamers = manager.streamers[1:]
            mainstreamers = [main_streamer] if main_streamer in manager.tasks else []

            # Switch - use the main streamer then the substreamer
            if switch:
                filters = mainstreamers + substreamers
            # Concat - use the first substreamer then the main streamer
            elif ordered:
                filters = substreamers[:1] + mainstreamers
            # Flat - use the substreamers then the main streamer
            else:
                filters = substreamers + mainstreamers

            # Wait for next event
            streamer, task = await manager.wait_single_event(filters)

            # Get result
            try:
                result = task.result()

            # End of stream
            except StopAsyncIteration:
                # Main streamer is finished
                if streamer is main_streamer:
                    main_streamer = None

                # A substreamer is finished
                else:
                    await manager.clean_streamer(streamer)

                    # Re-schedule the main streamer if necessary
                    if main_streamer is not None and main_streamer not in manager.tasks:
                        manager.create_task(main_streamer)

            # Process result
            else:
                # Switch mechanism
                if switch and streamer is main_streamer:
                    await manager.clean_streamers(substreamers)

                # Setup a new source
                if streamer is main_streamer:
                    assert isinstance(result, AsyncIterable)
                    await manager.enter_and_create_task(result)

                    # Re-schedule the main streamer if task limit allows it
                    if task_limit is None or task_limit > len(manager.tasks):
                        manager.create_task(streamer)

                # Yield the result
                else:
                    item = cast("T", result)
                    yield item

                    # Re-schedule the streamer
                    manager.create_task(streamer)
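
# How the public operators below build on base_combine (a summary of the
# wrappers defined in this module):
#
#     concat  -> switch=False, ordered=True   (items emitted in source order)
#     flatten -> switch=False, ordered=False  (items emitted as they arrive)
#     switch  -> switch=True                  (a new sub-stream discards the previous)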


# Advanced operators (for streams of higher order)


@pipable_operator
def concat(
    source: AsyncIterable[AsyncIterable[T]], task_limit: int | None = None
) -> AsyncIterator[T]:
    """Given an asynchronous sequence of sequences, generate the elements
    of the sequences in order.

    The sequences are awaited concurrently, although it's possible to limit
    the number of running sequences using the `task_limit` argument.

    Errors raised in the source or an element sequence are propagated.
    """
    return base_combine.raw(source, task_limit=task_limit, switch=False, ordered=True)
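
# A minimal usage sketch (not part of this module), assuming the top-level
# `stream` and `pipe` entry points of aiostream:
#
#     import asyncio
#     from aiostream import stream, pipe
#
#     async def main():
#         # Two sub-streams, consumed concurrently but emitted in source
#         # order: 0, 1, 2, 10, 11, 12
#         higher = stream.iterate([stream.range(0, 3), stream.range(10, 13)])
#         print(await stream.list(higher | pipe.concat()))
#
#     asyncio.run(main())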


@pipable_operator
def flatten(
    source: AsyncIterable[AsyncIterable[T]], task_limit: int | None = None
) -> AsyncIterator[T]:
    """Given an asynchronous sequence of sequences, generate the elements
    of the sequences as soon as they're received.

    The sequences are awaited concurrently, although it's possible to limit
    the number of running sequences using the `task_limit` argument.

    Errors raised in the source or an element sequence are propagated.
    """
    return base_combine.raw(source, task_limit=task_limit, switch=False, ordered=False)
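
# A minimal usage sketch (assumptions as in the concat example above): with
# flatten, items are emitted as soon as any sub-stream produces them, so
# slow sub-streams interleave instead of blocking the others:
#
#     async def main():
#         higher = stream.iterate(
#             [stream.range(0, 3, interval=0.2), stream.range(10, 13, interval=0.3)]
#         )
#         print(await stream.list(higher | pipe.flatten()))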


@pipable_operator
def switch(source: AsyncIterable[AsyncIterable[T]]) -> AsyncIterator[T]:
    """Given an asynchronous sequence of sequences, generate the elements of
    the most recent sequence.

    Element sequences are generated eagerly, and closed once they are
    superseded by a more recent sequence. Once the main sequence is finished,
    the last subsequence will be exhausted completely.

    Errors raised in the source or an element sequence (that was not already
    closed) are propagated.
    """
    return base_combine.raw(source, switch=True)
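
# A minimal usage sketch (assumptions as above): each new sub-stream
# discards the previous one, so only the last sub-stream is exhausted
# completely:
#
#     async def main():
#         higher = stream.range(3, interval=0.1) | pipe.map(
#             lambda i: stream.range(10 * i, 10 * i + 5, interval=0.06)
#         )
#         print(await stream.list(higher | pipe.switch()))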


# Advanced *-map operators


@pipable_operator
def concatmap(
    source: AsyncIterable[T],
    func: combine.SmapCallable[T, AsyncIterable[U]],
    *more_sources: AsyncIterable[T],
    task_limit: int | None = None,
) -> AsyncIterator[U]:
    """Apply a given function that creates a sequence from the elements of one
    or several asynchronous sequences, and generate the elements of the created
    sequences in order.

    The function is applied as described in `map`, and must return an
    asynchronous sequence. The returned sequences are awaited concurrently,
    although it's possible to limit the amount of running sequences using
    the `task_limit` argument.
    """
    mapped = combine.smap.raw(source, func, *more_sources)
    return concat.raw(mapped, task_limit=task_limit)
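
# A minimal usage sketch (assumptions as above): each number is mapped to
# a sub-stream, and the sub-streams are emitted in order, giving
# 0, 0, 1, 0, 1, 2:
#
#     async def main():
#         xs = stream.range(1, 4) | pipe.concatmap(lambda i: stream.range(i))
#         print(await stream.list(xs))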


@pipable_operator
def flatmap(
    source: AsyncIterable[T],
    func: combine.SmapCallable[T, AsyncIterable[U]],
    *more_sources: AsyncIterable[T],
    task_limit: int | None = None,
) -> AsyncIterator[U]:
    """Apply a given function that creates a sequence from the elements of one
    or several asynchronous sequences, and generate the elements of the created
    sequences as soon as they arrive.

    The function is applied as described in `map`, and must return an
    asynchronous sequence. The returned sequences are awaited concurrently,
    although it's possible to limit the number of running sequences using
    the `task_limit` argument.

    Errors raised in a source or output sequence are propagated.
    """
    mapped = combine.smap.raw(source, func, *more_sources)
    return flatten.raw(mapped, task_limit=task_limit)
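
# A minimal usage sketch (assumptions as above): same mapping as the
# concatmap example, but items from different sub-streams may interleave
# as they arrive:
#
#     async def main():
#         xs = stream.range(1, 4) | pipe.flatmap(
#             lambda i: stream.range(i, interval=0.1)
#         )
#         print(await stream.list(xs))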


@pipable_operator
def switchmap(
    source: AsyncIterable[T],
    func: combine.SmapCallable[T, AsyncIterable[U]],
    *more_sources: AsyncIterable[T],
) -> AsyncIterator[U]:
    """Apply a given function that creates a sequence from the elements of one
    or several asynchronous sequences, and generate the elements of the most
    recently created sequence.

    The function is applied as described in `map`, and must return an
    asynchronous sequence. Errors raised in a source or output sequence (that
    was not already closed) are propagated.
    """
    mapped = combine.smap.raw(source, func, *more_sources)
    return switch.raw(mapped)
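
# A minimal usage sketch (assumptions as above): each new source item
# replaces the running sub-stream, so only values from the most recent
# sub-stream come through, and the final one runs to completion:
#
#     async def main():
#         xs = stream.range(3, interval=0.1) | pipe.switchmap(
#             lambda i: stream.repeat(i, times=5, interval=0.06)
#         )
#         print(await stream.list(xs))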