about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/aiostream/stream/select.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiostream/stream/select.py')
-rw-r--r--.venv/lib/python3.12/site-packages/aiostream/stream/select.py284
1 files changed, 284 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiostream/stream/select.py b/.venv/lib/python3.12/site-packages/aiostream/stream/select.py
new file mode 100644
index 00000000..9390f464
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/aiostream/stream/select.py
@@ -0,0 +1,284 @@
+"""Selection operators."""
+from __future__ import annotations
+
+import asyncio
+import builtins
+import collections
+
+from typing import Awaitable, Callable, TypeVar, AsyncIterable, AsyncIterator
+
+from . import transform
+from ..aiter_utils import aiter, anext
+from ..core import streamcontext, pipable_operator
+
+__all__ = [
+    "take",
+    "takelast",
+    "skip",
+    "skiplast",
+    "getitem",
+    "filter",
+    "until",
+    "dropwhile",
+    "takewhile",
+]
+
+T = TypeVar("T")
+
+
+@pipable_operator
+async def take(source: AsyncIterable[T], n: int) -> AsyncIterator[T]:
+    """Forward the first ``n`` elements from an asynchronous sequence.
+
+    If ``n`` is negative, it simply terminates before iterating the source.
+    """
+    enumerated = transform.enumerate.raw(source)
+    async with streamcontext(enumerated) as streamer:
+        if n <= 0:
+            return
+        async for i, item in streamer:
+            yield item
+            if i >= n - 1:
+                return
+
+
+@pipable_operator
+async def takelast(source: AsyncIterable[T], n: int) -> AsyncIterator[T]:
+    """Forward the last ``n`` elements from an asynchronous sequence.
+
+    If ``n`` is negative, it simply terminates after iterating the source.
+
+    Note: it is required to reach the end of the source before the first
+    element is generated.
+    """
+    queue: collections.deque[T] = collections.deque(maxlen=n if n > 0 else 0)
+    async with streamcontext(source) as streamer:
+        async for item in streamer:
+            queue.append(item)
+        for item in queue:
+            yield item
+
+
+@pipable_operator
+async def skip(source: AsyncIterable[T], n: int) -> AsyncIterator[T]:
+    """Forward an asynchronous sequence, skipping the first ``n`` elements.
+
+    If ``n`` is negative, no elements are skipped.
+    """
+    enumerated = transform.enumerate.raw(source)
+    async with streamcontext(enumerated) as streamer:
+        async for i, item in streamer:
+            if i >= n:
+                yield item
+
+
+@pipable_operator
+async def skiplast(source: AsyncIterable[T], n: int) -> AsyncIterator[T]:
+    """Forward an asynchronous sequence, skipping the last ``n`` elements.
+
+    If ``n`` is negative, no elements are skipped.
+
+    Note: it is required to reach the ``n+1`` th element of the source
+    before the first element is generated.
+    """
+    queue: collections.deque[T] = collections.deque(maxlen=n if n > 0 else 0)
+    async with streamcontext(source) as streamer:
+        async for item in streamer:
+            if n <= 0:
+                yield item
+                continue
+            if len(queue) == n:
+                yield queue[0]
+            queue.append(item)
+
+
+@pipable_operator
+async def filterindex(
+    source: AsyncIterable[T], func: Callable[[int], bool]
+) -> AsyncIterator[T]:
+    """Filter an asynchronous sequence using the index of the elements.
+
+    The given function is synchronous, takes the index as an argument,
+    and returns ``True`` if the corresponding should be forwarded,
+    ``False`` otherwise.
+    """
+    enumerated = transform.enumerate.raw(source)
+    async with streamcontext(enumerated) as streamer:
+        async for i, item in streamer:
+            if func(i):
+                yield item
+
+
+@pipable_operator
+def slice(source: AsyncIterable[T], *args: int) -> AsyncIterator[T]:
+    """Slice an asynchronous sequence.
+
+    The arguments are the same as the builtin type slice.
+
+    There are two limitations compare to regular slices:
+    - Positive stop index with negative start index is not supported
+    - Negative step is not supported
+    """
+    s = builtins.slice(*args)
+    start, stop, step = s.start or 0, s.stop, s.step or 1
+    aiterator = aiter(source)
+    # Filter the first items
+    if start < 0:
+        aiterator = takelast.raw(aiterator, abs(start))
+    elif start > 0:
+        aiterator = skip.raw(aiterator, start)
+    # Filter the last items
+    if stop is not None:
+        if stop >= 0 and start < 0:
+            raise ValueError("Positive stop with negative start is not supported")
+        elif stop >= 0:
+            aiterator = take.raw(aiterator, stop - start)
+        else:
+            aiterator = skiplast.raw(aiterator, abs(stop))
+    # Filter step items
+    if step is not None:
+        if step > 1:
+            aiterator = filterindex.raw(aiterator, lambda i: i % step == 0)
+        elif step < 0:
+            raise ValueError("Negative step not supported")
+    # Return
+    return aiterator
+
+
+@pipable_operator
+async def item(source: AsyncIterable[T], index: int) -> AsyncIterator[T]:
+    """Forward the ``n``th element of an asynchronous sequence.
+
+    The index can be negative and works like regular indexing.
+    If the index is out of range, and ``IndexError`` is raised.
+    """
+    # Prepare
+    if index >= 0:
+        source = skip.raw(source, index)
+    else:
+        source = takelast(source, abs(index))
+    async with streamcontext(source) as streamer:
+        # Get first item
+        try:
+            result = await anext(streamer)
+        except StopAsyncIteration:
+            raise IndexError("Index out of range")
+        # Check length
+        if index < 0:
+            count = 1
+            async for _ in streamer:
+                count += 1
+            if count != abs(index):
+                raise IndexError("Index out of range")
+        # Yield result
+        yield result
+
+
+@pipable_operator
+def getitem(source: AsyncIterable[T], index: int | builtins.slice) -> AsyncIterator[T]:
+    """Forward one or several items from an asynchronous sequence.
+
+    The argument can either be a slice or an integer.
+    See the slice and item operators for more information.
+    """
+    if isinstance(index, builtins.slice):
+        return slice.raw(source, index.start, index.stop, index.step)
+    if isinstance(index, int):
+        return item.raw(source, index)
+    raise TypeError("Not a valid index (int or slice)")
+
+
+@pipable_operator
+async def filter(
+    source: AsyncIterable[T], func: Callable[[T], bool | Awaitable[bool]]
+) -> AsyncIterator[T]:
+    """Filter an asynchronous sequence using an arbitrary function.
+
+    The function takes the item as an argument and returns ``True``
+    if it should be forwarded, ``False`` otherwise.
+    The function can either be synchronous or asynchronous.
+    """
+    iscorofunc = asyncio.iscoroutinefunction(func)
+    async with streamcontext(source) as streamer:
+        async for item in streamer:
+            result = func(item)
+            if iscorofunc:
+                assert isinstance(result, Awaitable)
+                result = await result
+            if result:
+                yield item
+
+
+@pipable_operator
+async def until(
+    source: AsyncIterable[T], func: Callable[[T], bool | Awaitable[bool]]
+) -> AsyncIterator[T]:
+    """Forward an asynchronous sequence until a condition is met.
+
+    Contrary to the ``takewhile`` operator, the last tested element is included
+    in the sequence.
+
+    The given function takes the item as an argument and returns a boolean
+    corresponding to the condition to meet. The function can either be
+    synchronous or asynchronous.
+    """
+    iscorofunc = asyncio.iscoroutinefunction(func)
+    async with streamcontext(source) as streamer:
+        async for item in streamer:
+            result = func(item)
+            if iscorofunc:
+                assert isinstance(result, Awaitable)
+                result = await result
+            yield item
+            if result:
+                return
+
+
+@pipable_operator
+async def takewhile(
+    source: AsyncIterable[T], func: Callable[[T], bool | Awaitable[bool]]
+) -> AsyncIterator[T]:
+    """Forward an asynchronous sequence while a condition is met.
+
+    Contrary to the ``until`` operator, the last tested element is not included
+    in the sequence.
+
+    The given function takes the item as an argument and returns a boolean
+    corresponding to the condition to meet. The function can either be
+    synchronous or asynchronous.
+    """
+    iscorofunc = asyncio.iscoroutinefunction(func)
+    async with streamcontext(source) as streamer:
+        async for item in streamer:
+            result = func(item)
+            if iscorofunc:
+                assert isinstance(result, Awaitable)
+                result = await result
+            if not result:
+                return
+            yield item
+
+
+@pipable_operator
+async def dropwhile(
+    source: AsyncIterable[T], func: Callable[[T], bool | Awaitable[bool]]
+) -> AsyncIterator[T]:
+    """Discard the elements from an asynchronous sequence
+    while a condition is met.
+
+    The given function takes the item as an argument and returns a boolean
+    corresponding to the condition to meet. The function can either be
+    synchronous or asynchronous.
+    """
+    iscorofunc = asyncio.iscoroutinefunction(func)
+    async with streamcontext(source) as streamer:
+        async for item in streamer:
+            result = func(item)
+            if iscorofunc:
+                assert isinstance(result, Awaitable)
+                result = await result
+            if not result:
+                yield item
+                break
+        async for item in streamer:
+            yield item