diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiostream/aiter_utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/aiostream/aiter_utils.py | 262 |
1 files changed, 262 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiostream/aiter_utils.py b/.venv/lib/python3.12/site-packages/aiostream/aiter_utils.py new file mode 100644 index 00000000..f68ea846 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiostream/aiter_utils.py @@ -0,0 +1,262 @@ +"""Utilities for asynchronous iteration.""" +from __future__ import annotations +from types import TracebackType + +import warnings +import functools +from typing import ( + TYPE_CHECKING, + AsyncContextManager, + AsyncGenerator, + AsyncIterable, + Awaitable, + Callable, + Type, + TypeVar, + AsyncIterator, + Any, +) + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + P = ParamSpec("P") + +from contextlib import AsyncExitStack + +__all__ = [ + "aiter", + "anext", + "await_", + "async_", + "is_async_iterable", + "assert_async_iterable", + "is_async_iterator", + "assert_async_iterator", + "AsyncIteratorContext", + "aitercontext", + "AsyncExitStack", +] + + +# Magic method shorcuts + + +def aiter(obj: AsyncIterable[T]) -> AsyncIterator[T]: + """Access aiter magic method.""" + assert_async_iterable(obj) + return obj.__aiter__() + + +def anext(obj: AsyncIterator[T]) -> Awaitable[T]: + """Access anext magic method.""" + assert_async_iterator(obj) + return obj.__anext__() + + +# Async / await helper functions + + +async def await_(obj: Awaitable[T]) -> T: + """Identity coroutine function.""" + return await obj + + +def async_(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: + """Wrap the given function into a coroutine function.""" + + @functools.wraps(fn) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + return await fn(*args, **kwargs) + + return wrapper + + +# Iterability helpers + + +def is_async_iterable(obj: object) -> bool: + """Check if the given object is an asynchronous iterable.""" + return hasattr(obj, "__aiter__") + + +def assert_async_iterable(obj: object) -> None: + """Raise a TypeError if the given object is not an + asynchronous iterable. + """ + if not is_async_iterable(obj): + raise TypeError(f"{type(obj).__name__!r} object is not async iterable") + + +def is_async_iterator(obj: object) -> bool: + """Check if the given object is an asynchronous iterator.""" + return hasattr(obj, "__anext__") + + +def assert_async_iterator(obj: object) -> None: + """Raise a TypeError if the given object is not an + asynchronous iterator. + """ + if not is_async_iterator(obj): + raise TypeError(f"{type(obj).__name__!r} object is not an async iterator") + + +# Async iterator context + +T = TypeVar("T") +Self = TypeVar("Self", bound="AsyncIteratorContext[Any]") + + +class AsyncIteratorContext(AsyncIterator[T], AsyncContextManager[Any]): + """Asynchronous iterator with context management. + + The context management makes sure the aclose asynchronous method + of the corresponding iterator has run before it exits. It also issues + warnings and RuntimeError if it is used incorrectly. + + Correct usage:: + + ait = some_asynchronous_iterable() + async with AsyncIteratorContext(ait) as safe_ait: + async for item in safe_ait: + <block> + + It is nonetheless not meant to use directly. + Prefer aitercontext helper instead. + """ + + _STANDBY = "STANDBY" + _RUNNING = "RUNNING" + _FINISHED = "FINISHED" + + def __init__(self, aiterator: AsyncIterator[T]): + """Initialize with an asynchrnous iterator.""" + assert_async_iterator(aiterator) + if isinstance(aiterator, AsyncIteratorContext): + raise TypeError(f"{aiterator!r} is already an AsyncIteratorContext") + self._state = self._STANDBY + self._aiterator = aiterator + + def __aiter__(self: Self) -> Self: + return self + + def __anext__(self) -> Awaitable[T]: + if self._state == self._FINISHED: + raise RuntimeError( + f"{type(self).__name__} is closed and cannot be iterated" + ) + if self._state == self._STANDBY: + warnings.warn( + f"{type(self).__name__} is iterated outside of its context", + stacklevel=2, + ) + return anext(self._aiterator) + + async def __aenter__(self: Self) -> Self: + if self._state == self._RUNNING: + raise RuntimeError(f"{type(self).__name__} has already been entered") + if self._state == self._FINISHED: + raise RuntimeError( + f"{type(self).__name__} is closed and cannot be iterated" + ) + self._state = self._RUNNING + return self + + async def __aexit__( + self, + typ: Type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> bool: + try: + if self._state == self._FINISHED: + return False + try: + # No exception to throw + if typ is None: + return False + + # Prevent GeneratorExit from being silenced + if typ is GeneratorExit: + return False + + # No method to throw + if not hasattr(self._aiterator, "athrow"): + return False + + # No frame to throw + if not getattr(self._aiterator, "ag_frame", True): + return False + + # Cannot throw at the moment + if getattr(self._aiterator, "ag_running", False): + return False + + # Throw + try: + assert isinstance(self._aiterator, AsyncGenerator) + await self._aiterator.athrow(typ, value, traceback) + raise RuntimeError("Async iterator didn't stop after athrow()") + + # Exception has been (most probably) silenced + except StopAsyncIteration as exc: + return exc is not value + + # A (possibly new) exception has been raised + except BaseException as exc: + if exc is value: + return False + raise + finally: + # Look for an aclose method + aclose = getattr(self._aiterator, "aclose", None) + + # The ag_running attribute has been introduced with python 3.8 + running = getattr(self._aiterator, "ag_running", False) + closed = not getattr(self._aiterator, "ag_frame", True) + + # A RuntimeError is raised if aiterator is running or closed + if aclose and not running and not closed: + try: + await aclose() + + # Work around bpo-35409 + except GeneratorExit: + pass # pragma: no cover + finally: + self._state = self._FINISHED + + async def aclose(self) -> None: + await self.__aexit__(None, None, None) + + async def athrow(self, exc: Exception) -> T: + if self._state == self._FINISHED: + raise RuntimeError(f"{type(self).__name__} is closed and cannot be used") + assert isinstance(self._aiterator, AsyncGenerator) + item: T = await self._aiterator.athrow(exc) + return item + + +def aitercontext( + aiterable: AsyncIterable[T], +) -> AsyncIteratorContext[T]: + """Return an asynchronous context manager from an asynchronous iterable. + + The context management makes sure the aclose asynchronous method + has run before it exits. It also issues warnings and RuntimeError + if it is used incorrectly. + + It is safe to use with any asynchronous iterable and prevent + asynchronous iterator context to be wrapped twice. + + Correct usage:: + + ait = some_asynchronous_iterable() + async with aitercontext(ait) as safe_ait: + async for item in safe_ait: + <block> + """ + aiterator = aiter(aiterable) + if isinstance(aiterator, AsyncIteratorContext): + return aiterator + return AsyncIteratorContext(aiterator) |