"""Non-pipable creation operators.""" from __future__ import annotations import sys import asyncio import inspect import builtins import itertools from typing import ( AsyncIterable, Awaitable, Iterable, Protocol, TypeVar, AsyncIterator, cast, ) from typing_extensions import ParamSpec from ..stream import time from ..core import operator, streamcontext __all__ = [ "iterate", "preserve", "just", "call", "throw", "empty", "never", "repeat", "range", "count", ] T = TypeVar("T") P = ParamSpec("P") # Hack for python 3.8 compatibility if sys.version_info < (3, 9): P = TypeVar("P") # Convert regular iterables @operator async def from_iterable(it: Iterable[T]) -> AsyncIterator[T]: """Generate values from a regular iterable.""" for item in it: await asyncio.sleep(0) yield item @operator def from_async_iterable(ait: AsyncIterable[T]) -> AsyncIterator[T]: """Generate values from an asynchronous iterable. Note: the corresponding iterator will be explicitely closed when leaving the context manager.""" return streamcontext(ait) @operator def iterate(it: AsyncIterable[T] | Iterable[T]) -> AsyncIterator[T]: """Generate values from a sychronous or asynchronous iterable.""" if isinstance(it, AsyncIterable): return from_async_iterable.raw(it) if isinstance(it, Iterable): return from_iterable.raw(it) raise TypeError(f"{type(it).__name__!r} object is not (async) iterable") @operator async def preserve(ait: AsyncIterable[T]) -> AsyncIterator[T]: """Generate values from an asynchronous iterable without explicitly closing the corresponding iterator.""" async for item in ait: yield item # Simple operators @operator async def just(value: T) -> AsyncIterator[T]: """Await if possible, and generate a single value.""" if inspect.isawaitable(value): yield await value else: yield value Y = TypeVar("Y", covariant=True) class SyncCallable(Protocol[P, Y]): def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Y: ... 
class AsyncCallable(Protocol[P, Y]):
    """Structural type of a callable whose result must be awaited.

    Called with the parameters described by ``P``, it returns an awaitable
    resolving to ``Y``.
    """

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[Y]: ...


@operator
async def call(
    func: SyncCallable[P, T] | AsyncCallable[P, T], *args: P.args, **kwargs: P.kwargs
) -> AsyncIterator[T]:
    """Call the given function and generate a single value.

    Await if the provided function is asynchronous.

    ``func`` is called with ``*args`` and ``**kwargs``. Only functions
    recognized by ``asyncio.iscoroutinefunction`` are awaited. The result of
    any other callable is generated as-is, even if it happens to be awaitable.
    """
    if asyncio.iscoroutinefunction(func):
        async_func = cast("AsyncCallable[P, T]", func)
        yield await async_func(*args, **kwargs)
    else:
        sync_func = cast("SyncCallable[P, T]", func)
        yield sync_func(*args, **kwargs)


@operator
async def throw(exc: Exception) -> AsyncIterator[None]:
    """Throw an exception without generating any value.

    Raises:
        The given ``exc``, once the stream is first iterated.
    """
    # The unreachable ``yield`` makes this function an async generator.
    if False:
        yield
    raise exc


@operator
async def empty() -> AsyncIterator[None]:
    """Terminate without generating any value."""
    # The unreachable ``yield`` makes this function an async generator.
    if False:
        yield


@operator
async def never() -> AsyncIterator[None]:
    """Hang forever without generating any value."""
    # The unreachable ``yield`` makes this function an async generator.
    if False:
        yield
    # Create the future through the running loop, as the asyncio docs
    # recommend, rather than instantiating ``asyncio.Future`` directly.
    future: asyncio.Future[None] = asyncio.get_running_loop().create_future()
    try:
        # Nothing ever resolves this local future, so the wait only ends
        # when it is interrupted (e.g. the awaiting task is cancelled).
        await future
    finally:
        # Make sure the future is not left pending.
        future.cancel()


@operator
def repeat(
    value: T, times: int | None = None, *, interval: float = 0.0
) -> AsyncIterator[T]:
    """Generate the same value a given number of times.

    If ``times`` is ``None``, the value is repeated indefinitely.
    An optional interval can be given to space the values out.
    """
    # ``itertools.repeat`` repeats forever when no count is passed, and it
    # rejects ``None`` as a count, so the count is only forwarded when given.
    args = () if times is None else (times,)
    it = itertools.repeat(value, *args)
    agen = from_iterable.raw(it)
    # A zero (falsy) interval skips the spacing step entirely.
    return time.spaceout.raw(agen, interval) if interval else agen


# Counting operators


@operator
def range(*args: int, interval: float = 0.0) -> AsyncIterator[int]:
    """Generate a given range of numbers.

    It supports the same arguments as the builtin function.
    An optional interval can be given to space the values out.
    """
    # This operator shadows the ``range`` builtin, hence the explicit
    # ``builtins`` lookup.
    agen = from_iterable.raw(builtins.range(*args))
    return time.spaceout.raw(agen, interval) if interval else agen


@operator
def count(
    start: int = 0, step: int = 1, *, interval: float = 0.0
) -> AsyncIterator[int]:
    """Generate consecutive numbers indefinitely.

    Optional starting point and increment can be defined,
    respectively defaulting to ``0`` and ``1``.

    An optional interval can be given to space the values out.
    """
    agen = from_iterable.raw(itertools.count(start, step))
    return time.spaceout.raw(agen, interval) if interval else agen