"""Selection operators.""" from __future__ import annotations import asyncio import builtins import collections from typing import Awaitable, Callable, TypeVar, AsyncIterable, AsyncIterator from . import transform from ..aiter_utils import aiter, anext from ..core import streamcontext, pipable_operator __all__ = [ "take", "takelast", "skip", "skiplast", "getitem", "filter", "until", "dropwhile", "takewhile", ] T = TypeVar("T") @pipable_operator async def take(source: AsyncIterable[T], n: int) -> AsyncIterator[T]: """Forward the first ``n`` elements from an asynchronous sequence. If ``n`` is negative, it simply terminates before iterating the source. """ enumerated = transform.enumerate.raw(source) async with streamcontext(enumerated) as streamer: if n <= 0: return async for i, item in streamer: yield item if i >= n - 1: return @pipable_operator async def takelast(source: AsyncIterable[T], n: int) -> AsyncIterator[T]: """Forward the last ``n`` elements from an asynchronous sequence. If ``n`` is negative, it simply terminates after iterating the source. Note: it is required to reach the end of the source before the first element is generated. """ queue: collections.deque[T] = collections.deque(maxlen=n if n > 0 else 0) async with streamcontext(source) as streamer: async for item in streamer: queue.append(item) for item in queue: yield item @pipable_operator async def skip(source: AsyncIterable[T], n: int) -> AsyncIterator[T]: """Forward an asynchronous sequence, skipping the first ``n`` elements. If ``n`` is negative, no elements are skipped. """ enumerated = transform.enumerate.raw(source) async with streamcontext(enumerated) as streamer: async for i, item in streamer: if i >= n: yield item @pipable_operator async def skiplast(source: AsyncIterable[T], n: int) -> AsyncIterator[T]: """Forward an asynchronous sequence, skipping the last ``n`` elements. If ``n`` is negative, no elements are skipped. Note: it is required to reach the ``n+1`` th element of the source before the first element is generated. """ queue: collections.deque[T] = collections.deque(maxlen=n if n > 0 else 0) async with streamcontext(source) as streamer: async for item in streamer: if n <= 0: yield item continue if len(queue) == n: yield queue[0] queue.append(item) @pipable_operator async def filterindex( source: AsyncIterable[T], func: Callable[[int], bool] ) -> AsyncIterator[T]: """Filter an asynchronous sequence using the index of the elements. The given function is synchronous, takes the index as an argument, and returns ``True`` if the corresponding should be forwarded, ``False`` otherwise. """ enumerated = transform.enumerate.raw(source) async with streamcontext(enumerated) as streamer: async for i, item in streamer: if func(i): yield item @pipable_operator def slice(source: AsyncIterable[T], *args: int) -> AsyncIterator[T]: """Slice an asynchronous sequence. The arguments are the same as the builtin type slice. There are two limitations compare to regular slices: - Positive stop index with negative start index is not supported - Negative step is not supported """ s = builtins.slice(*args) start, stop, step = s.start or 0, s.stop, s.step or 1 aiterator = aiter(source) # Filter the first items if start < 0: aiterator = takelast.raw(aiterator, abs(start)) elif start > 0: aiterator = skip.raw(aiterator, start) # Filter the last items if stop is not None: if stop >= 0 and start < 0: raise ValueError("Positive stop with negative start is not supported") elif stop >= 0: aiterator = take.raw(aiterator, stop - start) else: aiterator = skiplast.raw(aiterator, abs(stop)) # Filter step items if step is not None: if step > 1: aiterator = filterindex.raw(aiterator, lambda i: i % step == 0) elif step < 0: raise ValueError("Negative step not supported") # Return return aiterator @pipable_operator async def item(source: AsyncIterable[T], index: int) -> AsyncIterator[T]: """Forward the ``n``th element of an asynchronous sequence. The index can be negative and works like regular indexing. If the index is out of range, and ``IndexError`` is raised. """ # Prepare if index >= 0: source = skip.raw(source, index) else: source = takelast(source, abs(index)) async with streamcontext(source) as streamer: # Get first item try: result = await anext(streamer) except StopAsyncIteration: raise IndexError("Index out of range") # Check length if index < 0: count = 1 async for _ in streamer: count += 1 if count != abs(index): raise IndexError("Index out of range") # Yield result yield result @pipable_operator def getitem(source: AsyncIterable[T], index: int | builtins.slice) -> AsyncIterator[T]: """Forward one or several items from an asynchronous sequence. The argument can either be a slice or an integer. See the slice and item operators for more information. """ if isinstance(index, builtins.slice): return slice.raw(source, index.start, index.stop, index.step) if isinstance(index, int): return item.raw(source, index) raise TypeError("Not a valid index (int or slice)") @pipable_operator async def filter( source: AsyncIterable[T], func: Callable[[T], bool | Awaitable[bool]] ) -> AsyncIterator[T]: """Filter an asynchronous sequence using an arbitrary function. The function takes the item as an argument and returns ``True`` if it should be forwarded, ``False`` otherwise. The function can either be synchronous or asynchronous. """ iscorofunc = asyncio.iscoroutinefunction(func) async with streamcontext(source) as streamer: async for item in streamer: result = func(item) if iscorofunc: assert isinstance(result, Awaitable) result = await result if result: yield item @pipable_operator async def until( source: AsyncIterable[T], func: Callable[[T], bool | Awaitable[bool]] ) -> AsyncIterator[T]: """Forward an asynchronous sequence until a condition is met. Contrary to the ``takewhile`` operator, the last tested element is included in the sequence. The given function takes the item as an argument and returns a boolean corresponding to the condition to meet. The function can either be synchronous or asynchronous. """ iscorofunc = asyncio.iscoroutinefunction(func) async with streamcontext(source) as streamer: async for item in streamer: result = func(item) if iscorofunc: assert isinstance(result, Awaitable) result = await result yield item if result: return @pipable_operator async def takewhile( source: AsyncIterable[T], func: Callable[[T], bool | Awaitable[bool]] ) -> AsyncIterator[T]: """Forward an asynchronous sequence while a condition is met. Contrary to the ``until`` operator, the last tested element is not included in the sequence. The given function takes the item as an argument and returns a boolean corresponding to the condition to meet. The function can either be synchronous or asynchronous. """ iscorofunc = asyncio.iscoroutinefunction(func) async with streamcontext(source) as streamer: async for item in streamer: result = func(item) if iscorofunc: assert isinstance(result, Awaitable) result = await result if not result: return yield item @pipable_operator async def dropwhile( source: AsyncIterable[T], func: Callable[[T], bool | Awaitable[bool]] ) -> AsyncIterator[T]: """Discard the elements from an asynchronous sequence while a condition is met. The given function takes the item as an argument and returns a boolean corresponding to the condition to meet. The function can either be synchronous or asynchronous. """ iscorofunc = asyncio.iscoroutinefunction(func) async with streamcontext(source) as streamer: async for item in streamer: result = func(item) if iscorofunc: assert isinstance(result, Awaitable) result = await result if not result: yield item break async for item in streamer: yield item