diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiostream/manager.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/aiostream/manager.py | 159 |
1 files changed, 159 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiostream/manager.py b/.venv/lib/python3.12/site-packages/aiostream/manager.py new file mode 100644 index 00000000..bab224a5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiostream/manager.py @@ -0,0 +1,159 @@ +"""Provide a context to easily manage several streamers running +concurrently. +""" +from __future__ import annotations + +import asyncio +from .aiter_utils import AsyncExitStack + +from .aiter_utils import anext +from .core import streamcontext +from typing import ( + TYPE_CHECKING, + Awaitable, + List, + Set, + Tuple, + Generic, + TypeVar, + Any, + Type, + AsyncIterable, +) +from types import TracebackType + +if TYPE_CHECKING: + from asyncio import Task + from aiostream.core import Streamer + +T = TypeVar("T") + + +class TaskGroup: + def __init__(self) -> None: + self._pending: set[Task[Any]] = set() + + async def __aenter__(self) -> TaskGroup: + return self + + async def __aexit__( + self, + typ: Type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + while self._pending: + task = self._pending.pop() + await self.cancel_task(task) + + def create_task(self, coro: Awaitable[T]) -> Task[T]: + task = asyncio.ensure_future(coro) + self._pending.add(task) + return task + + async def wait_any(self, tasks: List[Task[T]]) -> Set[Task[T]]: + done, _ = await asyncio.wait(tasks, return_when="FIRST_COMPLETED") + self._pending -= done + return done + + async def wait_all(self, tasks: List[Task[T]]) -> Set[Task[T]]: + if not tasks: + return set() + done, _ = await asyncio.wait(tasks) + self._pending -= done + return done + + async def cancel_task(self, task: Task[Any]) -> None: + try: + # The task is already cancelled + if task.cancelled(): + pass + # The task is already finished + elif task.done(): + # Discard the pending exception (if any). + # This makes sense since we don't know in which context the exception + # was meant to be processed. For instance, a `StopAsyncIteration` + # might be raised to notify that the end of a streamer has been reached. + task.exception() + # The task needs to be cancelled and awaited + else: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + # Silence any exception raised while cancelling the task. + # This might happen if the `CancelledError` is silenced, and the + # corresponding async generator returns, causing the `anext` call + # to raise a `StopAsyncIteration`. + except Exception: + pass + finally: + self._pending.discard(task) + + +class StreamerManager(Generic[T]): + def __init__(self) -> None: + self.tasks: dict[Streamer[T], Task[T]] = {} + self.streamers: list[Streamer[T]] = [] + self.group: TaskGroup = TaskGroup() + self.stack = AsyncExitStack() + + async def __aenter__(self) -> StreamerManager[T]: + await self.stack.__aenter__() + await self.stack.enter_async_context(self.group) + return self + + async def __aexit__( + self, + typ: Type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> bool: + for streamer in self.streamers: + task = self.tasks.pop(streamer, None) + if task is not None: + self.stack.push_async_callback(self.group.cancel_task, task) + self.stack.push_async_exit(streamer) + self.tasks.clear() + self.streamers.clear() + return await self.stack.__aexit__(typ, value, traceback) + + async def enter_and_create_task(self, aiter: AsyncIterable[T]) -> Streamer[T]: + streamer = streamcontext(aiter) + await streamer.__aenter__() + self.streamers.append(streamer) + self.create_task(streamer) + return streamer + + def create_task(self, streamer: Streamer[T]) -> None: + assert streamer in self.streamers + assert streamer not in self.tasks + self.tasks[streamer] = self.group.create_task(anext(streamer)) + + async def wait_single_event( + self, filters: list[Streamer[T]] + ) -> Tuple[Streamer[T], Task[T]]: + tasks = [self.tasks[streamer] for streamer in filters] + done = await self.group.wait_any(tasks) + for streamer in filters: + if self.tasks.get(streamer) in done: + return streamer, self.tasks.pop(streamer) + assert False + + async def clean_streamer(self, streamer: Streamer[T]) -> None: + task = self.tasks.pop(streamer, None) + if task is not None: + await self.group.cancel_task(task) + await streamer.aclose() + self.streamers.remove(streamer) + + async def clean_streamers(self, streamers: list[Streamer[T]]) -> None: + tasks = [ + self.group.create_task(self.clean_streamer(streamer)) + for streamer in streamers + ] + done = await self.group.wait_all(tasks) + # Raise exception if any + for task in done: + task.result() |