about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/aiostream/manager.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiostream/manager.py')
-rw-r--r--.venv/lib/python3.12/site-packages/aiostream/manager.py159
1 files changed, 159 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiostream/manager.py b/.venv/lib/python3.12/site-packages/aiostream/manager.py
new file mode 100644
index 00000000..bab224a5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/aiostream/manager.py
@@ -0,0 +1,159 @@
+"""Provide a context to easily manage several streamers running
+concurrently.
+"""
+from __future__ import annotations
+
+import asyncio
+from .aiter_utils import AsyncExitStack
+
+from .aiter_utils import anext
+from .core import streamcontext
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    List,
+    Set,
+    Tuple,
+    Generic,
+    TypeVar,
+    Any,
+    Type,
+    AsyncIterable,
+)
+from types import TracebackType
+
+if TYPE_CHECKING:
+    from asyncio import Task
+    from aiostream.core import Streamer
+
+T = TypeVar("T")
+
+
+class TaskGroup:
+    def __init__(self) -> None:
+        self._pending: set[Task[Any]] = set()
+
+    async def __aenter__(self) -> TaskGroup:
+        return self
+
+    async def __aexit__(
+        self,
+        typ: Type[BaseException] | None,
+        value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> None:
+        while self._pending:
+            task = self._pending.pop()
+            await self.cancel_task(task)
+
+    def create_task(self, coro: Awaitable[T]) -> Task[T]:
+        task = asyncio.ensure_future(coro)
+        self._pending.add(task)
+        return task
+
+    async def wait_any(self, tasks: List[Task[T]]) -> Set[Task[T]]:
+        done, _ = await asyncio.wait(tasks, return_when="FIRST_COMPLETED")
+        self._pending -= done
+        return done
+
+    async def wait_all(self, tasks: List[Task[T]]) -> Set[Task[T]]:
+        if not tasks:
+            return set()
+        done, _ = await asyncio.wait(tasks)
+        self._pending -= done
+        return done
+
+    async def cancel_task(self, task: Task[Any]) -> None:
+        try:
+            # The task is already cancelled
+            if task.cancelled():
+                pass
+            # The task is already finished
+            elif task.done():
+                # Discard the pending exception (if any).
+                # This makes sense since we don't know in which context the exception
+                # was meant to be processed. For instance, a `StopAsyncIteration`
+                # might be raised to notify that the end of a streamer has been reached.
+                task.exception()
+            # The task needs to be cancelled and awaited
+            else:
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+                # Silence any exception raised while cancelling the task.
+                # This might happen if the `CancelledError` is silenced, and the
+                # corresponding async generator returns, causing the `anext` call
+                # to raise a `StopAsyncIteration`.
+                except Exception:
+                    pass
+        finally:
+            self._pending.discard(task)
+
+
+class StreamerManager(Generic[T]):
+    def __init__(self) -> None:
+        self.tasks: dict[Streamer[T], Task[T]] = {}
+        self.streamers: list[Streamer[T]] = []
+        self.group: TaskGroup = TaskGroup()
+        self.stack = AsyncExitStack()
+
+    async def __aenter__(self) -> StreamerManager[T]:
+        await self.stack.__aenter__()
+        await self.stack.enter_async_context(self.group)
+        return self
+
+    async def __aexit__(
+        self,
+        typ: Type[BaseException] | None,
+        value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> bool:
+        for streamer in self.streamers:
+            task = self.tasks.pop(streamer, None)
+            if task is not None:
+                self.stack.push_async_callback(self.group.cancel_task, task)
+            self.stack.push_async_exit(streamer)
+        self.tasks.clear()
+        self.streamers.clear()
+        return await self.stack.__aexit__(typ, value, traceback)
+
+    async def enter_and_create_task(self, aiter: AsyncIterable[T]) -> Streamer[T]:
+        streamer = streamcontext(aiter)
+        await streamer.__aenter__()
+        self.streamers.append(streamer)
+        self.create_task(streamer)
+        return streamer
+
+    def create_task(self, streamer: Streamer[T]) -> None:
+        assert streamer in self.streamers
+        assert streamer not in self.tasks
+        self.tasks[streamer] = self.group.create_task(anext(streamer))
+
+    async def wait_single_event(
+        self, filters: list[Streamer[T]]
+    ) -> Tuple[Streamer[T], Task[T]]:
+        tasks = [self.tasks[streamer] for streamer in filters]
+        done = await self.group.wait_any(tasks)
+        for streamer in filters:
+            if self.tasks.get(streamer) in done:
+                return streamer, self.tasks.pop(streamer)
+        assert False
+
+    async def clean_streamer(self, streamer: Streamer[T]) -> None:
+        task = self.tasks.pop(streamer, None)
+        if task is not None:
+            await self.group.cancel_task(task)
+        await streamer.aclose()
+        self.streamers.remove(streamer)
+
+    async def clean_streamers(self, streamers: list[Streamer[T]]) -> None:
+        tasks = [
+            self.group.create_task(self.clean_streamer(streamer))
+            for streamer in streamers
+        ]
+        done = await self.group.wait_all(tasks)
+        # Raise exception if any
+        for task in done:
+            task.result()