aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/aiostream/stream/advanced.py
blob: 106f0f1d2e7a25a8fd5dba9b2682ae3fd81802cd (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""Advanced operators (to deal with streams of higher order)."""
from __future__ import annotations

from typing import AsyncIterator, AsyncIterable, TypeVar, Union, cast
from typing_extensions import ParamSpec

from . import combine

from ..core import Streamer, pipable_operator
from ..manager import StreamerManager


__all__ = ["concat", "flatten", "switch", "concatmap", "flatmap", "switchmap"]


T = TypeVar("T")
U = TypeVar("U")
P = ParamSpec("P")


# Helper to manage stream of higher order


@pipable_operator
async def base_combine(
    source: AsyncIterable[AsyncIterable[T]],
    switch: bool = False,
    ordered: bool = False,
    task_limit: int | None = None,
) -> AsyncIterator[T]:
    """Base operator for managing an asynchronous sequence of sequences.

    The sequences are awaited concurrently, although it's possible to limit
    the amount of running sequences using the `task_limit` argument.

    The ``switch`` argument enables the switch mechanism, which causes the
    previous subsequence to be discarded when a new one is created.

    The items can either be generated in order or as soon as they're received,
    depending on the ``ordered`` argument.

    Raises ValueError if ``task_limit`` is neither None nor a positive int.
    """

    # Task limit
    if task_limit is not None and not task_limit > 0:
        raise ValueError("The task limit must be None or greater than 0")

    # Safe context: the manager owns every streamer and guarantees cleanup
    # on exit, even if this generator is closed or an error propagates.
    async with StreamerManager[Union[AsyncIterable[T], T]]() as manager:
        # The main streamer yields subsequences; it is set to None once the
        # outer source is exhausted (see the StopAsyncIteration branch below).
        main_streamer: Streamer[
            AsyncIterable[T] | T
        ] | None = await manager.enter_and_create_task(source)

        # Loop over events
        while manager.tasks:
            # Extract streamer groups: index 0 is always the main streamer,
            # the rest are the currently running subsequences.
            substreamers = manager.streamers[1:]
            mainstreamers = [main_streamer] if main_streamer in manager.tasks else []

            # Switch - use the main streamer then the substreamer
            if switch:
                filters = mainstreamers + substreamers
            # Concat - use the first substreamer then the main streamer
            elif ordered:
                filters = substreamers[:1] + mainstreamers
            # Flat - use the substreamers then the main streamer
            else:
                filters = substreamers + mainstreamers

            # Wait for next event on any of the selected streamers
            streamer, task = await manager.wait_single_event(filters)

            # Get result
            try:
                result = task.result()

            # End of stream
            except StopAsyncIteration:
                # Main streamer is finished
                if streamer is main_streamer:
                    main_streamer = None

                # A substreamer is finished
                else:
                    await manager.clean_streamer(streamer)

                    # Re-schedule the main streamer if necessary
                    # (it may have been paused by the task limit below)
                    if main_streamer is not None and main_streamer not in manager.tasks:
                        manager.create_task(main_streamer)

            # Process result
            else:
                # Switch mechanism: a new subsequence supersedes and closes
                # all previously running ones
                if switch and streamer is main_streamer:
                    await manager.clean_streamers(substreamers)

                # Setup a new source: the main streamer produced a
                # subsequence, register it with the manager
                if streamer is main_streamer:
                    assert isinstance(result, AsyncIterable)
                    await manager.enter_and_create_task(result)

                    # Re-schedule the main streamer if task limit allows it;
                    # otherwise it resumes when a substreamer finishes
                    if task_limit is None or task_limit > len(manager.tasks):
                        manager.create_task(streamer)

                # Yield the result: a substreamer produced an item
                else:
                    item = cast("T", result)
                    yield item

                    # Re-schedule the streamer
                    manager.create_task(streamer)


# Advanced operators (for streams of higher order)


@pipable_operator
def concat(
    source: AsyncIterable[AsyncIterable[T]], task_limit: int | None = None
) -> AsyncIterator[T]:
    """Iterate over an asynchronous sequence of sequences, producing the
    items of each inner sequence in order.

    Inner sequences are still awaited concurrently; the ``task_limit``
    argument can bound how many of them run at the same time.

    Errors raised in the source or an element sequence are propagated.
    """
    return base_combine.raw(
        source, ordered=True, switch=False, task_limit=task_limit
    )


@pipable_operator
def flatten(
    source: AsyncIterable[AsyncIterable[T]], task_limit: int | None = None
) -> AsyncIterator[T]:
    """Iterate over an asynchronous sequence of sequences, producing items
    as soon as any inner sequence yields them.

    Inner sequences are awaited concurrently; the ``task_limit`` argument
    can bound how many of them run at the same time.

    Errors raised in the source or an element sequence are propagated.
    """
    return base_combine.raw(
        source, ordered=False, switch=False, task_limit=task_limit
    )


@pipable_operator
def switch(source: AsyncIterable[AsyncIterable[T]]) -> AsyncIterator[T]:
    """Iterate over an asynchronous sequence of sequences, producing items
    from the most recent inner sequence only.

    Element sequences are generated eagerly, and closed once they are
    superseded by a more recent sequence. Once the main sequence is finished,
    the last subsequence will be exhausted completely.

    Errors raised in the source or an element sequence (that was not already
    closed) are propagated.
    """
    switched = base_combine.raw(source, switch=True)
    return switched


# Advanced *-map operators


@pipable_operator
def concatmap(
    source: AsyncIterable[T],
    func: combine.SmapCallable[T, AsyncIterable[U]],
    *more_sources: AsyncIterable[T],
    task_limit: int | None = None,
) -> AsyncIterator[U]:
    """Map the items of one or several asynchronous sequences to sequences,
    then produce the elements of those sequences in order.

    The function is applied as described in `map`, and must return an
    asynchronous sequence. The returned sequences are awaited concurrently,
    although it's possible to limit the amount of running sequences using
    the `task_limit` argument.
    """
    return concat.raw(
        combine.smap.raw(source, func, *more_sources), task_limit=task_limit
    )


@pipable_operator
def flatmap(
    source: AsyncIterable[T],
    func: combine.SmapCallable[T, AsyncIterable[U]],
    *more_sources: AsyncIterable[T],
    task_limit: int | None = None,
) -> AsyncIterator[U]:
    """Map the items of one or several asynchronous sequences to sequences,
    then produce their elements as soon as they arrive.

    The function is applied as described in `map`, and must return an
    asynchronous sequence. The returned sequences are awaited concurrently,
    although it's possible to limit the amount of running sequences using
    the `task_limit` argument.

    Errors raised in a source or output sequence are propagated.
    """
    return flatten.raw(
        combine.smap.raw(source, func, *more_sources), task_limit=task_limit
    )


@pipable_operator
def switchmap(
    source: AsyncIterable[T],
    func: combine.SmapCallable[T, AsyncIterable[U]],
    *more_sources: AsyncIterable[T],
) -> AsyncIterator[U]:
    """Map the items of one or several asynchronous sequences to sequences,
    then produce the elements of the most recently created sequence.

    The function is applied as described in `map`, and must return an
    asynchronous sequence. Errors raised in a source or output sequence (that
    was not already closed) are propagated.
    """
    return switch.raw(combine.smap.raw(source, func, *more_sources))