aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/aiostream/manager.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiostream/manager.py')
-rw-r--r--.venv/lib/python3.12/site-packages/aiostream/manager.py159
1 files changed, 159 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiostream/manager.py b/.venv/lib/python3.12/site-packages/aiostream/manager.py
new file mode 100644
index 00000000..bab224a5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/aiostream/manager.py
@@ -0,0 +1,159 @@
+"""Provide a context to easily manage several streamers running
+concurrently.
+"""
+from __future__ import annotations
+
+import asyncio
+from .aiter_utils import AsyncExitStack
+
+from .aiter_utils import anext
+from .core import streamcontext
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ List,
+ Set,
+ Tuple,
+ Generic,
+ TypeVar,
+ Any,
+ Type,
+ AsyncIterable,
+)
+from types import TracebackType
+
+if TYPE_CHECKING:
+ from asyncio import Task
+ from aiostream.core import Streamer
+
+T = TypeVar("T")
+
+
+class TaskGroup:
+ def __init__(self) -> None:
+ self._pending: set[Task[Any]] = set()
+
+ async def __aenter__(self) -> TaskGroup:
+ return self
+
+ async def __aexit__(
+ self,
+ typ: Type[BaseException] | None,
+ value: BaseException | None,
+ traceback: TracebackType | None,
+ ) -> None:
+ while self._pending:
+ task = self._pending.pop()
+ await self.cancel_task(task)
+
+ def create_task(self, coro: Awaitable[T]) -> Task[T]:
+ task = asyncio.ensure_future(coro)
+ self._pending.add(task)
+ return task
+
+ async def wait_any(self, tasks: List[Task[T]]) -> Set[Task[T]]:
+ done, _ = await asyncio.wait(tasks, return_when="FIRST_COMPLETED")
+ self._pending -= done
+ return done
+
+ async def wait_all(self, tasks: List[Task[T]]) -> Set[Task[T]]:
+ if not tasks:
+ return set()
+ done, _ = await asyncio.wait(tasks)
+ self._pending -= done
+ return done
+
+ async def cancel_task(self, task: Task[Any]) -> None:
+ try:
+ # The task is already cancelled
+ if task.cancelled():
+ pass
+ # The task is already finished
+ elif task.done():
+ # Discard the pending exception (if any).
+ # This makes sense since we don't know in which context the exception
+ # was meant to be processed. For instance, a `StopAsyncIteration`
+ # might be raised to notify that the end of a streamer has been reached.
+ task.exception()
+ # The task needs to be cancelled and awaited
+ else:
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+ # Silence any exception raised while cancelling the task.
+ # This might happen if the `CancelledError` is silenced, and the
+ # corresponding async generator returns, causing the `anext` call
+ # to raise a `StopAsyncIteration`.
+ except Exception:
+ pass
+ finally:
+ self._pending.discard(task)
+
+
+class StreamerManager(Generic[T]):
+ def __init__(self) -> None:
+ self.tasks: dict[Streamer[T], Task[T]] = {}
+ self.streamers: list[Streamer[T]] = []
+ self.group: TaskGroup = TaskGroup()
+ self.stack = AsyncExitStack()
+
+ async def __aenter__(self) -> StreamerManager[T]:
+ await self.stack.__aenter__()
+ await self.stack.enter_async_context(self.group)
+ return self
+
+ async def __aexit__(
+ self,
+ typ: Type[BaseException] | None,
+ value: BaseException | None,
+ traceback: TracebackType | None,
+ ) -> bool:
+ for streamer in self.streamers:
+ task = self.tasks.pop(streamer, None)
+ if task is not None:
+ self.stack.push_async_callback(self.group.cancel_task, task)
+ self.stack.push_async_exit(streamer)
+ self.tasks.clear()
+ self.streamers.clear()
+ return await self.stack.__aexit__(typ, value, traceback)
+
+ async def enter_and_create_task(self, aiter: AsyncIterable[T]) -> Streamer[T]:
+ streamer = streamcontext(aiter)
+ await streamer.__aenter__()
+ self.streamers.append(streamer)
+ self.create_task(streamer)
+ return streamer
+
+ def create_task(self, streamer: Streamer[T]) -> None:
+ assert streamer in self.streamers
+ assert streamer not in self.tasks
+ self.tasks[streamer] = self.group.create_task(anext(streamer))
+
+ async def wait_single_event(
+ self, filters: list[Streamer[T]]
+ ) -> Tuple[Streamer[T], Task[T]]:
+ tasks = [self.tasks[streamer] for streamer in filters]
+ done = await self.group.wait_any(tasks)
+ for streamer in filters:
+ if self.tasks.get(streamer) in done:
+ return streamer, self.tasks.pop(streamer)
+ assert False
+
+ async def clean_streamer(self, streamer: Streamer[T]) -> None:
+ task = self.tasks.pop(streamer, None)
+ if task is not None:
+ await self.group.cancel_task(task)
+ await streamer.aclose()
+ self.streamers.remove(streamer)
+
+ async def clean_streamers(self, streamers: list[Streamer[T]]) -> None:
+ tasks = [
+ self.group.create_task(self.clean_streamer(streamer))
+ for streamer in streamers
+ ]
+ done = await self.group.wait_all(tasks)
+ # Raise exception if any
+ for task in done:
+ task.result()