"""Provide a context manager to easily manage several streamers running
concurrently.
"""
from __future__ import annotations

import asyncio
from .aiter_utils import AsyncExitStack

from .aiter_utils import anext
from .core import streamcontext
from typing import (
    TYPE_CHECKING,
    Awaitable,
    List,
    Set,
    Tuple,
    Generic,
    TypeVar,
    Any,
    Type,
    AsyncIterable,
)
from types import TracebackType

if TYPE_CHECKING:
    from asyncio import Task
    from aiostream.core import Streamer

# Generic value type shared by the helpers below: it parameterizes the
# awaitables/tasks of `TaskGroup.create_task` and the streamers and tasks
# managed by `StreamerManager`.
T = TypeVar("T")


class TaskGroup:
    """Track asyncio tasks and cancel whatever is still pending on exit.

    Tasks are registered through :meth:`create_task`. They stop being tracked
    once :meth:`wait_any`/:meth:`wait_all` report them as finished, or once
    :meth:`cancel_task` has disposed of them. Leaving the ``async with`` block
    cancels and awaits every task that is still tracked.
    """

    def __init__(self) -> None:
        # Tasks created through this group that have not been collected yet.
        self._pending: set[Task[Any]] = set()

    async def __aenter__(self) -> TaskGroup:
        return self

    async def __aexit__(
        self,
        typ: Type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Cancel and await every still-pending task.

        Returns None, so an in-flight exception is never suppressed.
        """
        # Pop one task at a time rather than iterating over the set:
        # `cancel_task` awaits and also discards from `self._pending`.
        while self._pending:
            task = self._pending.pop()
            await self.cancel_task(task)

    def create_task(self, coro: Awaitable[T]) -> Task[T]:
        """Schedule *coro* with ``asyncio.ensure_future`` and track the task."""
        task = asyncio.ensure_future(coro)
        self._pending.add(task)
        return task

    async def wait_any(self, tasks: List[Task[T]]) -> Set[Task[T]]:
        """Wait until at least one of *tasks* finishes; return the finished ones.

        The returned tasks are no longer tracked as pending. ``asyncio.wait``
        raises ``ValueError`` when *tasks* is empty.
        """
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        self._pending -= done
        return done

    async def wait_all(self, tasks: List[Task[T]]) -> Set[Task[T]]:
        """Wait until every task in *tasks* finishes; return them as a set.

        An empty *tasks* list returns an empty set immediately, because
        ``asyncio.wait`` rejects empty input. The returned tasks are no longer
        tracked as pending.
        """
        if not tasks:
            return set()
        done, _ = await asyncio.wait(tasks)
        self._pending -= done
        return done

    async def cancel_task(self, task: Task[Any]) -> None:
        """Dispose of *task* without propagating its outcome, then stop tracking it.

        - A task that is already cancelled is left alone.
        - A finished task has its exception (if any) retrieved and discarded.
        - A running task is cancelled and awaited. The resulting
          ``CancelledError`` and any ``Exception`` are silenced.
        """
        try:
            # The task is already cancelled
            if task.cancelled():
                pass
            # The task is already finished
            elif task.done():
                # Discard the pending exception (if any).
                # This makes sense since we don't know in which context the exception
                # was meant to be processed. For instance, a `StopAsyncIteration`
                # might be raised to notify that the end of a streamer has been reached.
                task.exception()
            # The task needs to be cancelled and awaited
            else:
                task.cancel()
                try:
                    await task
                # NOTE(review): if the task running this method is itself
                # cancelled while awaiting `task`, that CancelledError is also
                # swallowed here. Confirm this is acceptable for callers.
                except asyncio.CancelledError:
                    pass
                # Silence any exception raised while cancelling the task.
                # This might happen if the `CancelledError` is silenced, and the
                # corresponding async generator returns, causing the `anext` call
                # to raise a `StopAsyncIteration`.
                except Exception:
                    pass
        finally:
            self._pending.discard(task)


class StreamerManager(Generic[T]):
    """Enter several streamers and drive their ``anext`` calls as concurrent tasks.

    Each managed streamer has at most one pending ``anext`` task at a time,
    stored in ``self.tasks``. Leaving the context cancels those tasks and exits
    every streamer that is still managed, through a shared exit stack.
    """

    def __init__(self) -> None:
        # Pending `anext(streamer)` task for each streamer that currently has one.
        self.tasks: dict[Streamer[T], Task[T]] = {}
        # Streamers that were entered and not cleaned yet, in entry order.
        self.streamers: list[Streamer[T]] = []
        self.group: TaskGroup = TaskGroup()
        self.stack = AsyncExitStack()

    async def __aenter__(self) -> StreamerManager[T]:
        await self.stack.__aenter__()
        # Entered first, so exited last: any task still tracked by the group
        # after the streamers' exits have run is cancelled there.
        await self.stack.enter_async_context(self.group)
        return self

    async def __aexit__(
        self,
        typ: Type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Cancel pending tasks and exit every remaining streamer.

        Returns the exit stack's own result, which is presumably true when the
        in-flight exception was suppressed (contextlib semantics). Confirm
        against `aiter_utils.AsyncExitStack`.
        """
        # NOTE(review): the exit stack unwinds LIFO. For each streamer, its
        # exit therefore runs *before* its pending task is cancelled, which is
        # the opposite order from `clean_streamer`. Confirm this is intended.
        for streamer in self.streamers:
            task = self.tasks.pop(streamer, None)
            if task is not None:
                self.stack.push_async_callback(self.group.cancel_task, task)
            self.stack.push_async_exit(streamer)
        self.tasks.clear()
        self.streamers.clear()
        return await self.stack.__aexit__(typ, value, traceback)

    async def enter_and_create_task(self, aiter: AsyncIterable[T]) -> Streamer[T]:
        """Wrap *aiter* in a streamer, enter it, manage it and start its first task.

        The streamer's exit is registered later: on the exit stack in
        `__aexit__`, or run directly by `clean_streamer`.
        """
        streamer = streamcontext(aiter)
        await streamer.__aenter__()
        self.streamers.append(streamer)
        self.create_task(streamer)
        return streamer

    def create_task(self, streamer: Streamer[T]) -> None:
        """Start an ``anext(streamer)`` task for a managed streamer that has none."""
        assert streamer in self.streamers
        assert streamer not in self.tasks
        self.tasks[streamer] = self.group.create_task(anext(streamer))

    async def wait_single_event(
        self, filters: list[Streamer[T]]
    ) -> Tuple[Streamer[T], Task[T]]:
        """Wait for one of the *filters* streamers to finish its pending ``anext``.

        Every streamer in *filters* must have a pending task (otherwise
        ``KeyError``). Returns the first streamer, in *filters* order, whose
        task finished, together with that task. The task is removed from
        ``self.tasks``.
        """
        tasks = [self.tasks[streamer] for streamer in filters]
        done = await self.group.wait_any(tasks)
        for streamer in filters:
            if self.tasks.get(streamer) in done:
                return streamer, self.tasks.pop(streamer)
        # Only reachable if `self.tasks` was modified while waiting. Raise
        # explicitly: a bare `assert False` is stripped under `python -O`,
        # which would make this method silently return None.
        raise AssertionError("no finished task matches the given streamers")

    async def clean_streamer(self, streamer: Streamer[T]) -> None:
        """Cancel the streamer's pending task (if any), close it and stop managing it.

        If ``aclose()`` raises, the streamer stays in ``self.streamers``.
        """
        task = self.tasks.pop(streamer, None)
        if task is not None:
            await self.group.cancel_task(task)
        await streamer.aclose()
        self.streamers.remove(streamer)

    async def clean_streamers(self, streamers: list[Streamer[T]]) -> None:
        """Clean *streamers* concurrently, then re-raise the first failure (if any).

        Every cleanup task's outcome is retrieved, so one failure does not leave
        the other tasks' exceptions unretrieved. Failures are examined in the
        order of *streamers*, which makes the re-raised error deterministic.
        """
        tasks = [
            self.group.create_task(self.clean_streamer(streamer))
            for streamer in streamers
        ]
        # `wait_all` returns only once every task is done.
        await self.group.wait_all(tasks)
        first_error: BaseException | None = None
        for task in tasks:
            try:
                task.result()
            except BaseException as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error