"""Utilities for testing stream operators.""" from __future__ import annotations import asyncio from unittest.mock import Mock from contextlib import contextmanager import pytest from .core import StreamEmpty, streamcontext, pipable_operator from typing import TYPE_CHECKING, Any, Callable, List if TYPE_CHECKING: from _pytest.fixtures import SubRequest from aiostream.core import Stream __all__ = ["add_resource", "assert_run", "event_loop"] @pipable_operator async def add_resource(source, cleanup_time): """Simulate an open resource in a stream operator.""" try: loop = asyncio.get_event_loop() loop.open_resources += 1 loop.resources += 1 async with streamcontext(source) as streamer: async for item in streamer: yield item finally: try: await asyncio.sleep(cleanup_time) finally: loop.open_resources -= 1 def compare_exceptions( exc1: Exception, exc2: Exception, ) -> bool: """Compare two exceptions together.""" return exc1 == exc2 or exc1.__class__ == exc2.__class__ and exc1.args == exc2.args async def assert_aiter( source: Stream, values: List[Any], exception: Exception | None = None, ) -> None: """Check the results of a stream using a streamcontext.""" results = [] exception_type = (type(exception),) if exception else () try: async with streamcontext(source) as streamer: async for item in streamer: results.append(item) except exception_type as exc: assert exception is not None assert compare_exceptions(exc, exception) else: assert exception is None assert results == values async def assert_await( source: Stream, values: List[Any], exception: Exception | None = None, ) -> None: """Check the results of a stream using by awaiting it.""" exception_type = (type(exception),) if exception else () try: result = await source except StreamEmpty: assert values == [] assert exception is None except exception_type as exc: assert exception is not None assert compare_exceptions(exc, exception) else: assert result == values[-1] assert exception is None 
@pytest.fixture(params=[assert_aiter, assert_await], ids=["aiter", "await"])
def assert_run(request: SubRequest) -> Callable:
    """Parametrized fixture returning a stream runner.

    Each test using this fixture runs once per runner: ``assert_aiter``
    (id ``aiter``) and ``assert_await`` (id ``await``).
    """
    return request.param


@pytest.fixture
def event_loop():
    """Fixture providing a test event loop.

    The event loop simulates and records the sleep operations,
    available as ``event_loop.steps``.

    It also tracks simulated resources and makes sure they are
    all released before the loop is closed.
    """

    class TimeTrackingTestLoop(asyncio.BaseEventLoop):
        """Event loop with a simulated clock driven by its pending timers."""

        # Number of consecutive busy iterations after which the loop is
        # considered stuck and the simulated time is forced forward
        stuck_threshold = 100

        def __init__(self):
            super().__init__()
            self._time = 0
            self._timers = []
            # The base loop polls ``self._selector``; a mock stands in for it
            self._selector = Mock()
            self.clear()

        # Loop internals

        def _run_once(self):
            super()._run_once()
            # Update internals
            self.busy_count += 1
            # Drop the deadlines already reached, keep the rest sorted
            self._timers = sorted(
                when for when in self._timers if when > self.time()
            )
            # Time advance
            if self.time_to_go:
                when = self._timers.pop(0)
                step = when - self.time()
                self.steps.append(step)
                self.advance_time(step)
                self.busy_count = 0

        def _process_events(self, event_list):
            return

        def _write_to_self(self):
            return

        # Time management

        def time(self):
            return self._time

        def advance_time(self, advance):
            if advance:
                self._time += advance

        def call_at(self, when, callback, *args, **kwargs):
            # Record the deadline so ``_run_once`` can jump to it
            self._timers.append(when)
            return super().call_at(when, callback, *args, **kwargs)

        @property
        def stuck(self):
            return self.busy_count > self.stuck_threshold

        @property
        def time_to_go(self):
            # Advance the clock only when timers are pending and either
            # nothing is ready to run or the loop appears to be stuck
            return self._timers and (self.stuck or not self._ready)

        # Resource management

        def clear(self):
            self.steps = []
            self.open_resources = 0
            self.resources = 0
            self.busy_count = 0

        @contextmanager
        def assert_cleanup(self):
            self.clear()
            yield self
            assert self.open_resources == 0
            self.clear()

    loop = TimeTrackingTestLoop()
    asyncio.set_event_loop(loop)
    try:
        with loop.assert_cleanup():
            yield loop
    finally:
        # Close the loop even when the cleanup assertion fails
        loop.close()