"""Provide a context to easily manage several streamers running
concurrently.
"""
from __future__ import annotations
import asyncio
from .aiter_utils import AsyncExitStack
from .aiter_utils import anext
from .core import streamcontext
from typing import (
TYPE_CHECKING,
Awaitable,
List,
Set,
Tuple,
Generic,
TypeVar,
Any,
Type,
AsyncIterable,
)
from types import TracebackType
if TYPE_CHECKING:
    # Only needed for annotations: TYPE_CHECKING is False at runtime, and
    # ``from __future__ import annotations`` keeps annotations unevaluated.
    from asyncio import Task
    from aiostream.core import Streamer

# Type variable shared by the generic annotations in this module
# (e.g. ``Awaitable[T]``, ``Task[T]``, ``Streamer[T]``).
T = TypeVar("T")
class TaskGroup:
    """Track a set of asyncio tasks and cancel whatever is left on exit.

    Tasks created through :meth:`create_task` are recorded as pending until
    :meth:`wait_any` / :meth:`wait_all` report them as done, or until they
    are handled by :meth:`cancel_task`. Leaving the ``async with`` block
    cancels and awaits every task that is still recorded as pending.
    """

    def __init__(self) -> None:
        # Tasks created by this group that have not yet been reported done
        # by wait_any/wait_all nor handled by cancel_task.
        self._pending: set[Task[Any]] = set()

    async def __aenter__(self) -> TaskGroup:
        return self

    async def __aexit__(
        self,
        typ: Type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Cancel and await every pending task.

        Returns ``None``, so any exception propagating out of the block is
        never suppressed.
        """
        # Pop one task at a time instead of iterating over the set:
        # cancel_task awaits and mutates ``_pending`` (discard).
        while self._pending:
            task = self._pending.pop()
            await self.cancel_task(task)

    def create_task(self, coro: Awaitable[T]) -> Task[T]:
        """Schedule *coro* (via ``asyncio.ensure_future``) and track it."""
        task = asyncio.ensure_future(coro)
        self._pending.add(task)
        return task

    async def wait_any(self, tasks: List[Task[T]]) -> Set[Task[T]]:
        """Wait until at least one of *tasks* is done; return the done ones.

        The returned tasks are no longer tracked as pending. Raises
        ``ValueError`` (from :func:`asyncio.wait`) if *tasks* is empty.
        """
        done, _ = await asyncio.wait(
            tasks, return_when=asyncio.FIRST_COMPLETED
        )
        self._pending -= done
        return done

    async def wait_all(self, tasks: List[Task[T]]) -> Set[Task[T]]:
        """Wait until all of *tasks* are done; return them.

        An empty *tasks* list returns an empty set immediately (unlike
        :meth:`wait_any`). The returned tasks are no longer tracked as pending.
        """
        if not tasks:
            return set()
        done, _ = await asyncio.wait(tasks)
        self._pending -= done
        return done

    async def cancel_task(self, task: Task[Any]) -> None:
        """Cancel *task* if needed, wait for it, and stop tracking it.

        The task's outcome (result, exception or cancellation) is discarded;
        exceptions derived from ``Exception`` raised while cancelling are
        silenced.
        """
        try:
            if task.cancelled():
                # The task is already cancelled: nothing to do.
                pass
            elif task.done():
                # The task is already finished.
                # Discard the pending exception (if any).
                # This makes sense since we don't know in which context the
                # exception was meant to be processed. For instance, a
                # `StopAsyncIteration` might be raised to notify that the end
                # of a streamer has been reached.
                task.exception()
            else:
                # The task needs to be cancelled and awaited.
                task.cancel()
                try:
                    await task
                # NOTE(review): a cancellation of the *calling* task that
                # arrives while awaiting here is swallowed by this clause as
                # well; confirm this is acceptable for all callers.
                except asyncio.CancelledError:
                    pass
                # Silence any exception raised while cancelling the task.
                # This might happen if the `CancelledError` is silenced, and
                # the corresponding async generator returns, causing the
                # `anext` call to raise a `StopAsyncIteration`.
                except Exception:
                    pass
        finally:
            self._pending.discard(task)
class StreamerManager(Generic[T]):
    """Run ``anext`` on several streamers concurrently.

    Each streamer entered through :meth:`enter_and_create_task` has at most
    one outstanding ``anext`` task (tracked in :attr:`tasks`).
    :meth:`wait_single_event` waits for one of those tasks to complete, and
    :meth:`clean_streamer` / :meth:`clean_streamers` cancel the task and
    close the streamer. Leaving the ``async with`` block cancels outstanding
    tasks, exits the remaining streamers and exits the :class:`TaskGroup`.
    """

    def __init__(self) -> None:
        # Outstanding ``anext`` task per streamer (absent when none is running).
        self.tasks: dict[Streamer[T], Task[T]] = {}
        # Streamers that have been entered and not cleaned yet.
        self.streamers: list[Streamer[T]] = []
        self.group: TaskGroup = TaskGroup()
        self.stack = AsyncExitStack()

    async def __aenter__(self) -> StreamerManager[T]:
        await self.stack.__aenter__()
        # Entered first, hence exited last by the exit stack.
        await self.stack.enter_async_context(self.group)
        return self

    async def __aexit__(
        self,
        typ: Type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Cancel outstanding tasks and exit every remaining streamer.

        Returns whatever the exit stack returns, i.e. ``True`` when one of the
        registered exit callbacks suppressed the exception.
        """
        # The exit stack unwinds in LIFO order. Streamers are exited from the
        # most recently entered one, and each streamer's exit runs *before*
        # its outstanding task is cancelled.
        # NOTE(review): clean_streamer uses the opposite order (cancel the
        # task, then close the streamer); confirm this order is intended.
        for streamer in self.streamers:
            task = self.tasks.pop(streamer, None)
            if task is not None:
                self.stack.push_async_callback(self.group.cancel_task, task)
            self.stack.push_async_exit(streamer)
        self.tasks.clear()
        self.streamers.clear()
        return await self.stack.__aexit__(typ, value, traceback)

    async def enter_and_create_task(self, aiter: AsyncIterable[T]) -> Streamer[T]:
        """Wrap *aiter* in a streamer, enter it, register it and schedule
        its first ``anext`` task. Returns the streamer."""
        streamer = streamcontext(aiter)
        await streamer.__aenter__()
        self.streamers.append(streamer)
        self.create_task(streamer)
        return streamer

    def create_task(self, streamer: Streamer[T]) -> None:
        """Schedule ``anext(streamer)`` in the task group.

        The streamer must be registered and must not already have an
        outstanding task (both checked with ``assert``).
        """
        assert streamer in self.streamers
        assert streamer not in self.tasks
        self.tasks[streamer] = self.group.create_task(anext(streamer))

    async def wait_single_event(
        self, filters: list[Streamer[T]]
    ) -> Tuple[Streamer[T], Task[T]]:
        """Wait for the ``anext`` task of one of *filters* to complete.

        Returns the streamer and its completed task. The task is removed from
        :attr:`tasks`, so :meth:`create_task` may be called again for that
        streamer. If several tasks complete together, only the first matching
        streamer in *filters* order is returned; the others stay in
        :attr:`tasks`. Every streamer in *filters* must have an outstanding
        task (``KeyError`` otherwise), and *filters* must not be empty
        (``ValueError`` from the task group otherwise).
        """
        tasks = [self.tasks[streamer] for streamer in filters]
        done = await self.group.wait_any(tasks)
        for streamer in filters:
            if self.tasks.get(streamer) in done:
                return streamer, self.tasks.pop(streamer)
        # Every done task comes from ``tasks``, so a match is expected.
        # Raise explicitly rather than ``assert False``: an assert statement
        # is stripped under ``python -O`` and the method would return None.
        raise AssertionError("no completed task matches the filtered streamers")

    async def clean_streamer(self, streamer: Streamer[T]) -> None:
        """Cancel the outstanding task of *streamer* (if any), close the
        streamer and unregister it."""
        task = self.tasks.pop(streamer, None)
        if task is not None:
            await self.group.cancel_task(task)
        await streamer.aclose()
        self.streamers.remove(streamer)

    async def clean_streamers(self, streamers: list[Streamer[T]]) -> None:
        """Clean *streamers* concurrently, one task per streamer.

        Once all cleanups have finished, an exception raised by one of them
        (if any) is re-raised.
        """
        tasks = [
            self.group.create_task(self.clean_streamer(streamer))
            for streamer in streamers
        ]
        done = await self.group.wait_all(tasks)
        # Mark every outcome as retrieved first. Otherwise, when several
        # cleanups fail, the exceptions that are not re-raised below would be
        # logged by asyncio as "Task exception was never retrieved".
        for task in done:
            if not task.cancelled():
                task.exception()
        # Raise exception if any
        for task in done:
            task.result()