aboutsummaryrefslogtreecommitdiff
"""Utilities for testing stream operators."""
from __future__ import annotations

import asyncio
from unittest.mock import Mock
from contextlib import contextmanager

import pytest

from .core import StreamEmpty, streamcontext, pipable_operator
from typing import TYPE_CHECKING, Any, Callable, List

if TYPE_CHECKING:
    from _pytest.fixtures import SubRequest
    from aiostream.core import Stream

# Public helpers for the test suite: the resource-simulating operator and
# the two pytest fixtures.
__all__ = [
    "add_resource",
    "assert_run",
    "event_loop",
]


@pipable_operator
async def add_resource(source, cleanup_time):
    """Simulate an open resource in a stream operator.

    Every item of *source* is forwarded unchanged. When the stream starts,
    the running loop's ``resources`` and ``open_resources`` counters are
    incremented. When it stops (exhausted, failed or closed early), the
    operator sleeps for *cleanup_time* to mimic an asynchronous release,
    then decrements ``open_resources``.

    The running loop must expose those counters, as the test loop provided
    by the ``event_loop`` fixture does.

    Args:
        source: The stream whose items are forwarded.
        cleanup_time: Delay passed to ``asyncio.sleep`` for the release step.
    """
    # Acquire the loop and register the resource *before* entering the
    # ``try`` block. If any of these steps fails, there is nothing to
    # release, and the ``finally`` clause must not run against an unbound
    # or half-initialised state. ``open_resources`` is incremented last so
    # it is only raised when the matching decrement is guaranteed to run.
    loop = asyncio.get_running_loop()
    loop.resources += 1
    loop.open_resources += 1
    try:
        async with streamcontext(source) as streamer:
            async for item in streamer:
                yield item
    finally:
        try:
            await asyncio.sleep(cleanup_time)
        finally:
            # Mark the resource released even if the cleanup sleep is
            # interrupted (e.g. cancelled).
            loop.open_resources -= 1


def compare_exceptions(
    exc1: Exception,
    exc2: Exception,
) -> bool:
    """Tell whether two exceptions should be treated as the same error.

    They match when they compare equal. Otherwise, they also match when they
    are instances of exactly the same class built with identical ``args``.
    """
    if exc1 == exc2:
        return True
    if exc1.__class__ != exc2.__class__:
        return False
    return exc1.args == exc2.args


async def assert_aiter(
    source: Stream,
    values: List[Any],
    exception: Exception | None = None,
) -> None:
    """Check a stream by iterating over it inside a streamcontext.

    All items produced by *source* are collected and must equal *values*.
    When *exception* is given, iteration must stop with an exception of
    that type which ``compare_exceptions`` considers equal to it. When it
    is ``None``, iteration must finish without raising.
    """
    collected = []
    # ``except ()`` matches nothing, so when no exception is expected any
    # error raised by the stream propagates to the test untouched.
    expected_types = () if not exception else (type(exception),)
    try:
        async with streamcontext(source) as streamer:
            async for element in streamer:
                collected.append(element)
    except expected_types as caught:
        assert exception is not None
        assert compare_exceptions(caught, exception)
    else:
        assert exception is None
    assert collected == values


async def assert_await(
    source: Stream,
    values: List[Any],
    exception: Exception | None = None,
) -> None:
    """Check the result of a stream by awaiting it.

    On success, the awaited result is expected to be the last of *values*.
    If the stream raises ``StreamEmpty``, *values* must be empty and no
    exception must be expected. If *exception* is given, the await must
    raise an exception of that type that ``compare_exceptions`` considers
    equal to it.
    """
    # ``except ()`` matches nothing, so unexpected errors propagate.
    exception_type = (type(exception),) if exception else ()
    try:
        result = await source
    except StreamEmpty:
        assert values == []
        assert exception is None
    except exception_type as exc:
        assert exception is not None
        assert compare_exceptions(exc, exception)
    else:
        # Check the exception expectation first, so an expected error that
        # never happened is reported as such. Then guard against an empty
        # *values* before indexing it; otherwise the failure would surface
        # as an IndexError instead of an assertion.
        assert exception is None
        assert values, "stream produced a result but no values were expected"
        assert result == values[-1]


@pytest.fixture(params=[assert_aiter, assert_await], ids=["aiter", "await"])
def assert_run(request: SubRequest) -> Callable:
    """Provide a stream-checking coroutine function, once per strategy.

    Tests using this fixture run twice: with ``assert_aiter`` (id "aiter")
    and with ``assert_await`` (id "await").
    """
    runner = request.param
    return runner


@pytest.fixture
def event_loop():
    """Fixture providing a test event loop with a simulated clock.

    Instead of actually sleeping, the loop jumps its clock forward to the
    next scheduled timer when no callback is ready, or when it has spun
    too long without the clock moving. Every jump is recorded in
    ``event_loop.steps``.

    It also tracks simulated resources (see ``add_resource``) and asserts
    they have all been released when the test finishes. The loop is closed
    during teardown even if that assertion fails.
    """

    class TimeTrackingTestLoop(asyncio.BaseEventLoop):
        # Number of consecutive loop iterations without a clock jump after
        # which the clock is advanced even though callbacks are still ready.
        stuck_threshold = 100

        def __init__(self):
            super().__init__()
            # Virtual clock returned by ``time()``.
            self._time = 0
            # Deadlines registered through ``call_at``; pruned to future
            # deadlines after every loop iteration.
            self._timers = []
            # No real I/O is performed; a mock stands in for the selector.
            self._selector = Mock()
            self.clear()

        # Loop internals

        def _run_once(self):
            super()._run_once()
            # Update internals
            self.busy_count += 1
            now = self.time()
            # Keep only deadlines still ahead of the clock, earliest first.
            self._timers = sorted(when for when in self._timers if when > now)
            # Time advance: jump straight to the earliest pending deadline.
            if self.time_to_go:
                when = self._timers.pop(0)
                step = when - self.time()
                self.steps.append(step)
                self.advance_time(step)
                self.busy_count = 0

        def _process_events(self, event_list):
            # Required by BaseEventLoop; there are no real I/O events here.
            return

        def _write_to_self(self):
            # Required by BaseEventLoop; nothing needs waking without real I/O.
            return

        # Time management

        def time(self):
            return self._time

        def advance_time(self, advance):
            """Move the virtual clock forward by *advance* (ignored if falsy)."""
            if advance:
                self._time += advance

        def call_at(self, when, callback, *args, **kwargs):
            # Remember the deadline so ``_run_once`` knows where to jump.
            self._timers.append(when)
            return super().call_at(when, callback, *args, **kwargs)

        @property
        def stuck(self):
            """True once the loop has iterated too often without a clock jump."""
            return self.busy_count > self.stuck_threshold

        @property
        def time_to_go(self):
            """Truthy when a deadline is pending and the clock should jump.

            That is the case when no callback is ready, or the loop is stuck.
            """
            return self._timers and (self.stuck or not self._ready)

        # Resource management

        def clear(self):
            """Reset the recorded steps and all tracking counters."""
            self.steps = []
            self.open_resources = 0
            self.resources = 0
            self.busy_count = 0

        @contextmanager
        def assert_cleanup(self):
            """Reset counters, hand out the loop, then check for leaks.

            The leak check runs only if the ``with`` body exits normally.
            """
            self.clear()
            yield self
            assert self.open_resources == 0
            self.clear()

    loop = TimeTrackingTestLoop()
    asyncio.set_event_loop(loop)
    try:
        with loop.assert_cleanup():
            yield loop
    finally:
        # Close the loop even if the leaked-resource assertion fails.
        loop.close()