"""Advanced operators (to deal with streams of higher order) ."""
from __future__ import annotations

from typing import AsyncIterator, AsyncIterable, TypeVar, Union, cast
from typing_extensions import ParamSpec

from . import combine
from ..core import Streamer, pipable_operator
from ..manager import StreamerManager

__all__ = ["concat", "flatten", "switch", "concatmap", "flatmap", "switchmap"]

T = TypeVar("T")
U = TypeVar("U")
P = ParamSpec("P")


# Helper to manage a stream of higher order
@pipable_operator
async def base_combine(
source: AsyncIterable[AsyncIterable[T]],
switch: bool = False,
ordered: bool = False,
task_limit: int | None = None,
) -> AsyncIterator[T]:
"""Base operator for managing an asynchronous sequence of sequences.
The sequences are awaited concurrently, although it's possible to limit
the amount of running sequences using the `task_limit` argument.
The ``switch`` argument enables the switch mecanism, which cause the
previous subsequence to be discarded when a new one is created.
The items can either be generated in order or as soon as they're received,
depending on the ``ordered`` argument.
"""
# Task limit
    if task_limit is not None and task_limit <= 0:
raise ValueError("The task limit must be None or greater than 0")
# Safe context
async with StreamerManager[Union[AsyncIterable[T], T]]() as manager:
main_streamer: Streamer[
AsyncIterable[T] | T
] | None = await manager.enter_and_create_task(source)
# Loop over events
while manager.tasks:
# Extract streamer groups
substreamers = manager.streamers[1:]
mainstreamers = [main_streamer] if main_streamer in manager.tasks else []
# Switch - use the main streamer then the substreamer
if switch:
filters = mainstreamers + substreamers
# Concat - use the first substreamer then the main streamer
elif ordered:
filters = substreamers[:1] + mainstreamers
# Flat - use the substreamers then the main streamer
else:
filters = substreamers + mainstreamers
# Wait for next event
streamer, task = await manager.wait_single_event(filters)
# Get result
try:
result = task.result()
# End of stream
except StopAsyncIteration:
# Main streamer is finished
if streamer is main_streamer:
main_streamer = None
# A substreamer is finished
else:
await manager.clean_streamer(streamer)
# Re-schedule the main streamer if necessary
if main_streamer is not None and main_streamer not in manager.tasks:
manager.create_task(main_streamer)
# Process result
else:
                # Switch mechanism
if switch and streamer is main_streamer:
await manager.clean_streamers(substreamers)
# Setup a new source
if streamer is main_streamer:
assert isinstance(result, AsyncIterable)
await manager.enter_and_create_task(result)
# Re-schedule the main streamer if task limit allows it
if task_limit is None or task_limit > len(manager.tasks):
manager.create_task(streamer)
# Yield the result
else:
item = cast("T", result)
yield item
# Re-schedule the streamer
manager.create_task(streamer)
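

# The three public operators below are thin wrappers around `base_combine`:
#   ordered=True,  switch=False  ->  concat
#   ordered=False, switch=False  ->  flatten
#   switch=True                  ->  switch
#
# Minimal sketch (illustration only, not part of the module's API) driving
# the helper directly; with `ordered=True` it behaves like `concat`.
async def _base_combine_sketch() -> list[int]:
    async def ticks(n: int) -> AsyncIterator[int]:
        for i in range(n):
            yield i

    async def outer() -> AsyncIterator[AsyncIterable[int]]:
        yield ticks(2)
        yield ticks(3)

    # Expected result: [0, 1, 0, 1, 2]
    return [item async for item in base_combine.raw(outer(), ordered=True)]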


# Advanced operators (for streams of higher order)

@pipable_operator
def concat(
source: AsyncIterable[AsyncIterable[T]], task_limit: int | None = None
) -> AsyncIterator[T]:
"""Given an asynchronous sequence of sequences, generate the elements
of the sequences in order.
The sequences are awaited concurrently, although it's possible to limit
the amount of running sequences using the `task_limit` argument.
Errors raised in the source or an element sequence are propagated.
"""
return base_combine.raw(source, task_limit=task_limit, switch=False, ordered=True)
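

# Usage sketch for `concat` (illustration only, not part of the module's
# API): the subsequences are consumed concurrently, but their items are
# emitted strictly in subsequence order; passing `task_limit=1` would
# additionally disable any read-ahead on pending subsequences.
async def _concat_sketch() -> list[int]:
    async def ticks(n: int) -> AsyncIterator[int]:
        for i in range(n):
            yield i

    async def outer() -> AsyncIterator[AsyncIterable[int]]:
        yield ticks(2)
        yield ticks(3)

    # Expected result: [0, 1, 0, 1, 2]
    return [item async for item in concat.raw(outer())]

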
@pipable_operator
def flatten(
source: AsyncIterable[AsyncIterable[T]], task_limit: int | None = None
) -> AsyncIterator[T]:
"""Given an asynchronous sequence of sequences, generate the elements
of the sequences as soon as they're received.
The sequences are awaited concurrently, although it's possible to limit
the amount of running sequences using the `task_limit` argument.
Errors raised in the source or an element sequence are propagated.
"""
return base_combine.raw(source, task_limit=task_limit, switch=False, ordered=False)
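

# Usage sketch for `flatten` (illustration only): items are emitted as soon
# as any running subsequence produces them, so the exact interleaving
# depends on scheduling; `pace` is a hypothetical local helper.
async def _flatten_sketch() -> list[int]:
    import asyncio

    async def pace(delay: float, *items: int) -> AsyncIterator[int]:
        for item in items:
            await asyncio.sleep(delay)
            yield item

    async def outer() -> AsyncIterator[AsyncIterable[int]]:
        yield pace(0.2, 1, 2)
        yield pace(0.3, 10, 20)

    # Likely result (timing-dependent): [1, 10, 2, 20]
    return [item async for item in flatten.raw(outer())]

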
@pipable_operator
def switch(source: AsyncIterable[AsyncIterable[T]]) -> AsyncIterator[T]:
"""Given an asynchronous sequence of sequences, generate the elements of
the most recent sequence.
Element sequences are generated eagerly, and closed once they are
superseded by a more recent sequence. Once the main sequence is finished,
the last subsequence will be exhausted completely.
Errors raised in the source or an element sequence (that was not already
closed) are propagated.
"""
return base_combine.raw(source, switch=True)
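

# Usage sketch for `switch` (illustration only): a new subsequence closes
# the previous one, so only the most recent subsequence keeps emitting, and
# the last one is exhausted completely.
async def _switch_sketch() -> list[int]:
    import asyncio

    async def slow_ticks() -> AsyncIterator[int]:
        for i in range(5):
            yield i
            await asyncio.sleep(0.1)

    async def outer() -> AsyncIterator[AsyncIterable[int]]:
        yield slow_ticks()
        await asyncio.sleep(0.25)
        yield slow_ticks()  # supersedes and closes the first subsequence

    # Likely result (timing-dependent): [0, 1, 2, 0, 1, 2, 3, 4]
    return [item async for item in switch.raw(outer())]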


# Advanced *-map operators

@pipable_operator
def concatmap(
source: AsyncIterable[T],
func: combine.SmapCallable[T, AsyncIterable[U]],
*more_sources: AsyncIterable[T],
task_limit: int | None = None,
) -> AsyncIterator[U]:
"""Apply a given function that creates a sequence from the elements of one
or several asynchronous sequences, and generate the elements of the created
sequences in order.
The function is applied as described in `map`, and must return an
asynchronous sequence. The returned sequences are awaited concurrently,
although it's possible to limit the amount of running sequences using
the `task_limit` argument.
"""
mapped = combine.smap.raw(source, func, *more_sources)
return concat.raw(mapped, task_limit=task_limit)
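

# Usage sketch for `concatmap` (illustration only): `func` is a plain
# synchronous callable returning an asynchronous sequence; here each number
# n is expanded into the sequence 0..n-1, and the results are concatenated
# in order.
async def _concatmap_sketch() -> list[int]:
    async def ticks(n: int) -> AsyncIterator[int]:
        for i in range(n):
            yield i

    async def numbers() -> AsyncIterator[int]:
        yield 2
        yield 3

    # Expected result: [0, 1, 0, 1, 2]
    return [item async for item in concatmap.raw(numbers(), ticks)]

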
@pipable_operator
def flatmap(
source: AsyncIterable[T],
func: combine.SmapCallable[T, AsyncIterable[U]],
*more_sources: AsyncIterable[T],
task_limit: int | None = None,
) -> AsyncIterator[U]:
"""Apply a given function that creates a sequence from the elements of one
or several asynchronous sequences, and generate the elements of the created
sequences as soon as they arrive.
The function is applied as described in `map`, and must return an
asynchronous sequence. The returned sequences are awaited concurrently,
although it's possible to limit the amount of running sequences using
the `task_limit` argument.
Errors raised in a source or output sequence are propagated.
"""
mapped = combine.smap.raw(source, func, *more_sources)
return flatten.raw(mapped, task_limit=task_limit)
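

# Usage sketch for `flatmap` (illustration only): same mapping as in the
# `concatmap` sketch above, but items are emitted as soon as they arrive,
# so outputs of different subsequences may interleave (timing-dependent).
async def _flatmap_sketch() -> list[int]:
    import asyncio

    async def pace(n: int) -> AsyncIterator[int]:
        for i in range(n):
            await asyncio.sleep(0.1)
            yield i

    async def numbers() -> AsyncIterator[int]:
        yield 2
        yield 3

    # Likely result (timing-dependent): [0, 0, 1, 1, 2]
    return [item async for item in flatmap.raw(numbers(), pace)]

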
@pipable_operator
def switchmap(
source: AsyncIterable[T],
func: combine.SmapCallable[T, AsyncIterable[U]],
*more_sources: AsyncIterable[T],
) -> AsyncIterator[U]:
"""Apply a given function that creates a sequence from the elements of one
or several asynchronous sequences and generate the elements of the most
recently created sequence.
The function is applied as described in `map`, and must return an
asynchronous sequence. Errors raised in a source or output sequence (that
was not already closed) are propagated.
"""
mapped = combine.smap.raw(source, func, *more_sources)
return switch.raw(mapped)
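

# Usage sketch for `switchmap` (illustration only): each new source item
# replaces the running mapped subsequence, so earlier subsequences are
# closed as soon as a newer item arrives.
async def _switchmap_sketch() -> list[int]:
    import asyncio

    async def slow_ticks(base: int) -> AsyncIterator[int]:
        for i in range(3):
            yield base + i
            await asyncio.sleep(0.1)

    async def numbers() -> AsyncIterator[int]:
        yield 10
        await asyncio.sleep(0.15)
        yield 20  # supersedes the subsequence created for 10

    # Likely result (timing-dependent): [10, 11, 20, 21, 22]
    return [item async for item in switchmap.raw(numbers(), slow_ticks)]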