aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/aiostream/test_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiostream/test_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/aiostream/test_utils.py173
1 files changed, 173 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiostream/test_utils.py b/.venv/lib/python3.12/site-packages/aiostream/test_utils.py
new file mode 100644
index 00000000..56761d2e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/aiostream/test_utils.py
@@ -0,0 +1,173 @@
+"""Utilities for testing stream operators."""
+from __future__ import annotations
+
+import asyncio
+from unittest.mock import Mock
+from contextlib import contextmanager
+
+import pytest
+
+from .core import StreamEmpty, streamcontext, pipable_operator
+from typing import TYPE_CHECKING, Any, Callable, List
+
+if TYPE_CHECKING:
+ from _pytest.fixtures import SubRequest
+ from aiostream.core import Stream
+
+__all__ = ["add_resource", "assert_run", "event_loop"]
+
+
+@pipable_operator
+async def add_resource(source, cleanup_time):
+ """Simulate an open resource in a stream operator."""
+ try:
+ loop = asyncio.get_event_loop()
+ loop.open_resources += 1
+ loop.resources += 1
+ async with streamcontext(source) as streamer:
+ async for item in streamer:
+ yield item
+ finally:
+ try:
+ await asyncio.sleep(cleanup_time)
+ finally:
+ loop.open_resources -= 1
+
+
+def compare_exceptions(
+ exc1: Exception,
+ exc2: Exception,
+) -> bool:
+ """Compare two exceptions together."""
+ return exc1 == exc2 or exc1.__class__ == exc2.__class__ and exc1.args == exc2.args
+
+
+async def assert_aiter(
+ source: Stream,
+ values: List[Any],
+ exception: Exception | None = None,
+) -> None:
+ """Check the results of a stream using a streamcontext."""
+ results = []
+ exception_type = (type(exception),) if exception else ()
+ try:
+ async with streamcontext(source) as streamer:
+ async for item in streamer:
+ results.append(item)
+ except exception_type as exc:
+ assert exception is not None
+ assert compare_exceptions(exc, exception)
+ else:
+ assert exception is None
+ assert results == values
+
+
+async def assert_await(
+ source: Stream,
+ values: List[Any],
+ exception: Exception | None = None,
+) -> None:
+ """Check the results of a stream using by awaiting it."""
+ exception_type = (type(exception),) if exception else ()
+ try:
+ result = await source
+ except StreamEmpty:
+ assert values == []
+ assert exception is None
+ except exception_type as exc:
+ assert exception is not None
+ assert compare_exceptions(exc, exception)
+ else:
+ assert result == values[-1]
+ assert exception is None
+
+
+@pytest.fixture(params=[assert_aiter, assert_await], ids=["aiter", "await"])
+def assert_run(request: SubRequest) -> Callable:
+ """Parametrized fixture returning a stream runner."""
+ return request.param
+
+
+@pytest.fixture
+def event_loop():
+ """Fixture providing a test event loop.
+
+ The event loop simulate and records the sleep operation,
+ available as event_loop.steps
+
+ It also tracks simulated resources and make sure they are
+ all released before the loop is closed.
+ """
+
+ class TimeTrackingTestLoop(asyncio.BaseEventLoop):
+ stuck_threshold = 100
+
+ def __init__(self):
+ super().__init__()
+ self._time = 0
+ self._timers = []
+ self._selector = Mock()
+ self.clear()
+
+ # Loop internals
+
+ def _run_once(self):
+ super()._run_once()
+ # Update internals
+ self.busy_count += 1
+ self._timers = sorted(when for when in self._timers if when > loop.time())
+ # Time advance
+ if self.time_to_go:
+ when = self._timers.pop(0)
+ step = when - loop.time()
+ self.steps.append(step)
+ self.advance_time(step)
+ self.busy_count = 0
+
+ def _process_events(self, event_list):
+ return
+
+ def _write_to_self(self):
+ return
+
+ # Time management
+
+ def time(self):
+ return self._time
+
+ def advance_time(self, advance):
+ if advance:
+ self._time += advance
+
+ def call_at(self, when, callback, *args, **kwargs):
+ self._timers.append(when)
+ return super().call_at(when, callback, *args, **kwargs)
+
+ @property
+ def stuck(self):
+ return self.busy_count > self.stuck_threshold
+
+ @property
+ def time_to_go(self):
+ return self._timers and (self.stuck or not self._ready)
+
+ # Resource management
+
+ def clear(self):
+ self.steps = []
+ self.open_resources = 0
+ self.resources = 0
+ self.busy_count = 0
+
+ @contextmanager
+ def assert_cleanup(self):
+ self.clear()
+ yield self
+ assert self.open_resources == 0
+ self.clear()
+
+ loop = TimeTrackingTestLoop()
+ asyncio.set_event_loop(loop)
+ with loop.assert_cleanup():
+ yield loop
+ loop.close()