# NOTE: stray repository-browser navigation text was removed from this line.
"""Non-pipable creation operators."""
from __future__ import annotations

import sys
import asyncio
import inspect
import builtins
import itertools

from typing import (
    AsyncIterable,
    Awaitable,
    Iterable,
    Protocol,
    TypeVar,
    AsyncIterator,
    cast,
)
from typing_extensions import ParamSpec

from ..stream import time
from ..core import operator, streamcontext

# Names exported by a star-import of this module.
# `from_iterable` and `from_async_iterable` are defined in this module
# but are not listed here.
__all__ = [
    "iterate",
    "preserve",
    "just",
    "call",
    "throw",
    "empty",
    "never",
    "repeat",
    "range",
    "count",
]

# Generic type of the items handled by the operators in this module.
T = TypeVar("T")
# Parameter specification used to type the arguments forwarded by `call`.
P = ParamSpec("P")

# Hack for python 3.8 compatibility: on interpreters older than 3.9,
# `P` is rebound to a plain `TypeVar`.
# NOTE(review): presumably because `Protocol[P, Y]` does not accept a
# `ParamSpec` on 3.8 -- confirm.
if sys.version_info < (3, 9):
    P = TypeVar("P")

# Convert regular and asynchronous iterables


@operator
async def from_iterable(it: Iterable[T]) -> AsyncIterator[T]:
    """Asynchronously generate the items of a regular iterable.

    A zero-delay sleep is awaited before each item is produced.
    """
    for value in it:
        # Give the event loop a chance to run before emitting the item.
        await asyncio.sleep(0)
        yield value


@operator
def from_async_iterable(ait: AsyncIterable[T]) -> AsyncIterator[T]:
    """Generate values from an asynchronous iterable.

    The iterable is wrapped in a ``streamcontext``.

    Note: the corresponding iterator will be explicitly closed
    when leaving the context manager.
    """
    context = streamcontext(ait)
    return context


@operator
def iterate(it: AsyncIterable[T] | Iterable[T]) -> AsyncIterator[T]:
    """Generate values from a synchronous or asynchronous iterable.

    An object that is both an asynchronous and a regular iterable is
    treated as asynchronous, since that check is performed first.

    Raises:
        TypeError: if ``it`` is neither an iterable nor an async iterable.
    """
    if isinstance(it, AsyncIterable):
        source = from_async_iterable.raw(it)
    elif isinstance(it, Iterable):
        source = from_iterable.raw(it)
    else:
        raise TypeError(f"{type(it).__name__!r} object is not (async) iterable")
    return source


@operator
async def preserve(ait: AsyncIterable[T]) -> AsyncIterator[T]:
    """Relay every value of an asynchronous iterable.

    The underlying iterator is only consumed here; it is not
    explicitly closed by this operator.
    """
    async for value in ait:
        yield value


# Simple operators


@operator
async def just(value: T) -> AsyncIterator[T]:
    """Generate a single value, awaiting it first if it is awaitable."""
    # Awaitables are resolved; anything else is passed through untouched.
    result = await value if inspect.isawaitable(value) else value
    yield result


# Covariant type variable for the result type of the callable protocols.
Y = TypeVar("Y", covariant=True)


class SyncCallable(Protocol[P, Y]):
    """Structural type of a callable taking ``P`` and returning ``Y``."""

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Y:
        """Return the result of the call directly."""
        ...


class AsyncCallable(Protocol[P, Y]):
    """Structural type of a callable taking ``P`` and returning an
    awaitable that resolves to ``Y``."""

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[Y]:
        """Return an awaitable wrapping the result of the call."""
        ...


@operator
async def call(
    func: SyncCallable[P, T] | AsyncCallable[P, T], *args: P.args, **kwargs: P.kwargs
) -> AsyncIterator[T]:
    """Invoke ``func`` with the given arguments and generate its result.

    The object returned by ``func`` is awaited before being generated
    when ``asyncio.iscoroutinefunction`` reports ``func`` as a coroutine
    function; otherwise it is generated as is.
    """
    # The kind of function is determined before it is actually called.
    is_async = asyncio.iscoroutinefunction(func)
    outcome = func(*args, **kwargs)
    if is_async:
        outcome = await cast("Awaitable[T]", outcome)
    yield cast("T", outcome)


@operator
async def throw(exc: Exception) -> AsyncIterator[None]:
    """Raise the given exception without generating any value."""
    raise exc
    # Never reached: the bare ``yield`` only serves to make this
    # function an asynchronous generator.
    yield


@operator
async def empty() -> AsyncIterator[None]:
    """Finish immediately without generating any value."""
    return
    # Never reached: the bare ``yield`` only serves to make this
    # function an asynchronous generator.
    yield


@operator
async def never() -> AsyncIterator[None]:
    """Hang forever without generating any value.

    The generator waits on a future that is never resolved, so it only
    stops when it is cancelled or closed from the outside. The pending
    future is cancelled on the way out.
    """
    # Never executed: the ``yield`` only serves to make this function
    # an asynchronous generator.
    if False:
        yield
    # Ask the running loop for the future rather than instantiating
    # ``asyncio.Future`` directly, so that alternative event loops can
    # provide their own future implementation.
    future: asyncio.Future[None] = asyncio.get_running_loop().create_future()
    try:
        await future
    finally:
        # Do not leave a pending future behind once the wait is interrupted.
        future.cancel()


@operator
def repeat(
    value: T, times: int | None = None, *, interval: float = 0.0
) -> AsyncIterator[T]:
    """Generate ``value`` over and over.

    When ``times`` is ``None`` the repetition is unbounded; otherwise
    ``times`` is handed to ``itertools.repeat`` as the repetition count.

    A non-zero ``interval`` spaces the values out using
    ``time.spaceout``; with the default of ``0.0`` no spacing is applied.
    """
    if times is None:
        source = itertools.repeat(value)
    else:
        source = itertools.repeat(value, times)
    stream = from_iterable.raw(source)
    if interval:
        return time.spaceout.raw(stream, interval)
    return stream


# Counting operators


@operator
def range(*args: int, interval: float = 0.0) -> AsyncIterator[int]:
    """Generate the numbers of a range.

    The positional arguments are forwarded unchanged to the builtin
    ``range``, so the same call forms are accepted.

    A non-zero ``interval`` spaces the values out using
    ``time.spaceout``; with the default of ``0.0`` no spacing is applied.
    """
    # This function shadows the builtin, hence the explicit ``builtins``.
    numbers = builtins.range(*args)
    stream = from_iterable.raw(numbers)
    if interval:
        return time.spaceout.raw(stream, interval)
    return stream


@operator
def count(
    start: int = 0, step: int = 1, *, interval: float = 0.0
) -> AsyncIterator[int]:
    """Generate an endless arithmetic sequence of numbers.

    The sequence begins at ``start`` (default ``0``) and grows by
    ``step`` (default ``1``) for each new value.

    A non-zero ``interval`` spaces the values out using
    ``time.spaceout``; with the default of ``0.0`` no spacing is applied.
    """
    counter = itertools.count(start, step)
    stream = from_iterable.raw(counter)
    if interval:
        return time.spaceout.raw(stream, interval)
    return stream