aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/aiostream/test_utils.py
blob: 56761d2e603ebbf490fbb61c3b3793385d0c496e (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""Utilities for testing stream operators."""
from __future__ import annotations

import asyncio
from unittest.mock import Mock
from contextlib import contextmanager

import pytest

from .core import StreamEmpty, streamcontext, pipable_operator
from typing import TYPE_CHECKING, Any, Callable, List

if TYPE_CHECKING:
    from _pytest.fixtures import SubRequest
    from aiostream.core import Stream

# Public API of this module: the ``add_resource`` stream operator and the
# ``assert_run`` / ``event_loop`` pytest fixtures.
__all__ = ["add_resource", "assert_run", "event_loop"]


@pipable_operator
async def add_resource(source, cleanup_time):
    """Simulate an open resource in a stream operator.

    Forwards every item of ``source`` unchanged. When iteration starts, the
    running event loop's ``resources`` and ``open_resources`` counters are
    incremented. Once the stream is exhausted, fails or is closed, the
    operator awaits ``asyncio.sleep(cleanup_time)`` to simulate a slow
    cleanup and then decrements ``open_resources``.

    NOTE(review): the running loop must expose ``resources`` and
    ``open_resources`` counters, presumably the loop provided by this
    module's ``event_loop`` fixture.

    Args:
        source: the stream whose items are forwarded.
        cleanup_time: simulated cleanup duration passed to ``asyncio.sleep``.
    """
    # Bind the loop and register the resource *before* entering ``try``: the
    # ``finally`` clause relies on ``loop`` being bound and on
    # ``open_resources`` having actually been incremented. ``resources`` is a
    # cumulative counter that is never decremented, so it is bumped first.
    loop = asyncio.get_running_loop()
    loop.resources += 1
    loop.open_resources += 1
    try:
        async with streamcontext(source) as streamer:
            async for item in streamer:
                yield item
    finally:
        try:
            await asyncio.sleep(cleanup_time)
        finally:
            # Release the resource even if the simulated cleanup is cancelled.
            loop.open_resources -= 1


def compare_exceptions(
    exc1: Exception,
    exc2: Exception,
) -> bool:
    """Tell whether two exceptions should be considered equivalent.

    Two exceptions match when they compare equal (by default, when they are
    the same object), or when they are of exactly the same class and were
    built with the same ``args``.
    """
    equal = exc1 == exc2
    if equal:
        return equal
    same_class = exc1.__class__ == exc2.__class__
    return same_class and exc1.args == exc2.args


async def assert_aiter(
    source: Stream,
    values: List[Any],
    exception: Exception | None = None,
) -> None:
    """Check the results of a stream by iterating it within a streamcontext.

    Args:
        source: the stream under test.
        values: every item the stream is expected to produce, in order
            (including the items produced before an expected failure).
        exception: the exception the stream is expected to raise after
            producing ``values``, or ``None`` if it should complete normally.

    Raises:
        AssertionError: if the produced items or the raised exception do not
            match the expectations. An exception that is not an instance of
            ``type(exception)`` propagates unchanged.
    """
    results = []
    # ``except ()`` matches nothing, so when no exception is expected every
    # error raised by the stream propagates to the caller.
    exception_type = (type(exception),) if exception is not None else ()
    try:
        async with streamcontext(source) as streamer:
            async for item in streamer:
                results.append(item)
    except exception_type as exc:
        assert exception is not None
        assert compare_exceptions(exc, exception)
    else:
        assert exception is None
    assert results == values


async def assert_await(
    source: Stream,
    values: List[Any],
    exception: Exception | None = None,
) -> None:
    """Check the results of a stream by awaiting it.

    Awaiting a stream only produces its last item, so the awaited result is
    compared against ``values[-1]``. An empty stream is expected to raise
    ``StreamEmpty`` when awaited.

    Args:
        source: the stream under test.
        values: every item the stream is expected to produce, in order.
        exception: the exception awaiting the stream is expected to raise,
            or ``None`` if it should complete normally.

    Raises:
        AssertionError: if the awaited outcome does not match the
            expectations. An exception that is neither ``StreamEmpty`` nor an
            instance of ``type(exception)`` propagates unchanged.
    """
    # ``except ()`` matches nothing, so when no exception is expected every
    # error raised by the stream propagates to the caller.
    exception_type = (type(exception),) if exception is not None else ()
    try:
        result = await source
    except StreamEmpty:
        assert values == []
        assert exception is None
    except exception_type as exc:
        assert exception is not None
        assert compare_exceptions(exc, exception)
    else:
        # Fail with an AssertionError rather than an IndexError when the
        # stream produced a value although none was expected.
        assert values, f"expected an empty stream, got {result!r}"
        assert result == values[-1]
        assert exception is None


@pytest.fixture(params=[assert_aiter, assert_await], ids=["aiter", "await"])
def assert_run(request: SubRequest) -> Callable:
    """Provide each stream-checking coroutine in turn.

    The fixture is parametrized, so a test requesting it runs once with
    ``assert_aiter`` (id ``aiter``) and once with ``assert_await``
    (id ``await``).
    """
    runner: Callable = request.param
    return runner


@pytest.fixture
def event_loop():
    """Fixture providing a test event loop.

    The event loop simulates time: instead of really sleeping, it jumps its
    clock forward to the next scheduled timer and records every jump in
    ``event_loop.steps``.

    It also tracks simulated resources (``open_resources`` and ``resources``)
    and asserts that all of them have been released before the loop is
    closed. The loop is closed even if that assertion fails.
    """

    class TimeTrackingTestLoop(asyncio.BaseEventLoop):
        # Once more than this many consecutive iterations ran without a time
        # jump, the loop is considered stuck and time is advanced even though
        # ready callbacks are still pending.
        stuck_threshold = 100

        def __init__(self):
            super().__init__()
            self._time = 0
            # Deadlines of every timer scheduled through ``call_at``.
            self._timers = []
            # NOTE(review): stands in for the selector that
            # BaseEventLoop._run_once polls; a Mock never blocks.
            self._selector = Mock()
            self.clear()

        # Loop internals

        def _run_once(self):
            super()._run_once()
            # Update internals
            self.busy_count += 1
            # Drop expired deadlines and keep the rest sorted, earliest first.
            now = self.time()
            self._timers = sorted(when for when in self._timers if when > now)
            # Time advance: jump straight to the earliest pending deadline.
            if self.time_to_go:
                when = self._timers.pop(0)
                step = when - self.time()
                self.steps.append(step)
                self.advance_time(step)
                self.busy_count = 0

        # No real I/O is performed, so these hooks are no-ops.

        def _process_events(self, event_list):
            return

        def _write_to_self(self):
            return

        # Time management

        def time(self):
            return self._time

        def advance_time(self, advance):
            if advance:
                self._time += advance

        def call_at(self, when, callback, *args, **kwargs):
            # Remember the deadline so ``_run_once`` can jump to it.
            self._timers.append(when)
            return super().call_at(when, callback, *args, **kwargs)

        @property
        def stuck(self):
            return self.busy_count > self.stuck_threshold

        @property
        def time_to_go(self):
            # Advance time only if a timer is pending and either nothing is
            # ready to run or the loop has been busy for too long.
            return bool(self._timers) and (self.stuck or not self._ready)

        # Resource management

        def clear(self):
            self.steps = []
            self.open_resources = 0
            self.resources = 0
            self.busy_count = 0

        @contextmanager
        def assert_cleanup(self):
            self.clear()
            yield self
            assert self.open_resources == 0
            self.clear()

    loop = TimeTrackingTestLoop()
    asyncio.set_event_loop(loop)
    try:
        with loop.assert_cleanup():
            yield loop
    finally:
        # Close the loop even when the resource-cleanup assertion fails.
        loop.close()