about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/aiohttp/test_utils.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/aiohttp/test_utils.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiohttp/test_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/aiohttp/test_utils.py770
1 files changed, 770 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiohttp/test_utils.py b/.venv/lib/python3.12/site-packages/aiohttp/test_utils.py
new file mode 100644
index 00000000..be6e9b33
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/aiohttp/test_utils.py
@@ -0,0 +1,770 @@
+"""Utilities shared by tests."""
+
+import asyncio
+import contextlib
+import gc
+import inspect
+import ipaddress
+import os
+import socket
+import sys
+import warnings
+from abc import ABC, abstractmethod
+from types import TracebackType
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    cast,
+    overload,
+)
+from unittest import IsolatedAsyncioTestCase, mock
+
+from aiosignal import Signal
+from multidict import CIMultiDict, CIMultiDictProxy
+from yarl import URL
+
+import aiohttp
+from aiohttp.client import (
+    _RequestContextManager,
+    _RequestOptions,
+    _WSRequestContextManager,
+)
+
+from . import ClientSession, hdrs
+from .abc import AbstractCookieJar
+from .client_reqrep import ClientResponse
+from .client_ws import ClientWebSocketResponse
+from .helpers import sentinel
+from .http import HttpVersion, RawRequestMessage
+from .streams import EMPTY_PAYLOAD, StreamReader
+from .typedefs import StrOrURL
+from .web import (
+    Application,
+    AppRunner,
+    BaseRequest,
+    BaseRunner,
+    Request,
+    Server,
+    ServerRunner,
+    SockSite,
+    UrlMappingMatchInfo,
+)
+from .web_protocol import _RequestHandler
+
+if TYPE_CHECKING:
+    from ssl import SSLContext
+else:
+    SSLContext = None
+
+if sys.version_info >= (3, 11) and TYPE_CHECKING:
+    from typing import Unpack
+
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    Self = Any
+
+_ApplicationNone = TypeVar("_ApplicationNone", Application, None)
+_Request = TypeVar("_Request", bound=BaseRequest)
+
+REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin"
+
+
+def get_unused_port_socket(
+    host: str, family: socket.AddressFamily = socket.AF_INET
+) -> socket.socket:
+    return get_port_socket(host, 0, family)
+
+
+def get_port_socket(
+    host: str, port: int, family: socket.AddressFamily
+) -> socket.socket:
+    s = socket.socket(family, socket.SOCK_STREAM)
+    if REUSE_ADDRESS:
+        # Windows has different semantics for SO_REUSEADDR,
+        # so don't set it. Ref:
+        # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    s.bind((host, port))
+    return s
+
+
+def unused_port() -> int:
+    """Return a port that is unused on the current host."""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("127.0.0.1", 0))
+        return cast(int, s.getsockname()[1])
+
+
+class BaseTestServer(ABC):
+    __test__ = False
+
+    def __init__(
+        self,
+        *,
+        scheme: str = "",
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+        host: str = "127.0.0.1",
+        port: Optional[int] = None,
+        skip_url_asserts: bool = False,
+        socket_factory: Callable[
+            [str, int, socket.AddressFamily], socket.socket
+        ] = get_port_socket,
+        **kwargs: Any,
+    ) -> None:
+        self._loop = loop
+        self.runner: Optional[BaseRunner] = None
+        self._root: Optional[URL] = None
+        self.host = host
+        self.port = port
+        self._closed = False
+        self.scheme = scheme
+        self.skip_url_asserts = skip_url_asserts
+        self.socket_factory = socket_factory
+
+    async def start_server(
+        self, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any
+    ) -> None:
+        if self.runner:
+            return
+        self._loop = loop
+        self._ssl = kwargs.pop("ssl", None)
+        self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
+        await self.runner.setup()
+        if not self.port:
+            self.port = 0
+        absolute_host = self.host
+        try:
+            version = ipaddress.ip_address(self.host).version
+        except ValueError:
+            version = 4
+        if version == 6:
+            absolute_host = f"[{self.host}]"
+        family = socket.AF_INET6 if version == 6 else socket.AF_INET
+        _sock = self.socket_factory(self.host, self.port, family)
+        self.host, self.port = _sock.getsockname()[:2]
+        site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
+        await site.start()
+        server = site._server
+        assert server is not None
+        sockets = server.sockets  # type: ignore[attr-defined]
+        assert sockets is not None
+        self.port = sockets[0].getsockname()[1]
+        if not self.scheme:
+            self.scheme = "https" if self._ssl else "http"
+        self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")
+
+    @abstractmethod  # pragma: no cover
+    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
+        pass
+
+    def make_url(self, path: StrOrURL) -> URL:
+        assert self._root is not None
+        url = URL(path)
+        if not self.skip_url_asserts:
+            assert not url.absolute
+            return self._root.join(url)
+        else:
+            return URL(str(self._root) + str(path))
+
+    @property
+    def started(self) -> bool:
+        return self.runner is not None
+
+    @property
+    def closed(self) -> bool:
+        return self._closed
+
+    @property
+    def handler(self) -> Server:
+        # for backward compatibility
+        # web.Server instance
+        runner = self.runner
+        assert runner is not None
+        assert runner.server is not None
+        return runner.server
+
+    async def close(self) -> None:
+        """Close all fixtures created by the test client.
+
+        After that point, the TestClient is no longer usable.
+
+        This is an idempotent function: running close multiple times
+        will not have any additional effects.
+
+        close is also run when the object is garbage collected, and on
+        exit when used as a context manager.
+
+        """
+        if self.started and not self.closed:
+            assert self.runner is not None
+            await self.runner.cleanup()
+            self._root = None
+            self.port = None
+            self._closed = True
+
+    def __enter__(self) -> None:
+        raise TypeError("Use async with instead")
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        # __exit__ should exist in pair with __enter__ but never executed
+        pass  # pragma: no cover
+
+    async def __aenter__(self) -> "BaseTestServer":
+        await self.start_server(loop=self._loop)
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        await self.close()
+
+
+class TestServer(BaseTestServer):
+    def __init__(
+        self,
+        app: Application,
+        *,
+        scheme: str = "",
+        host: str = "127.0.0.1",
+        port: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        self.app = app
+        super().__init__(scheme=scheme, host=host, port=port, **kwargs)
+
+    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
+        return AppRunner(self.app, **kwargs)
+
+
+class RawTestServer(BaseTestServer):
+    def __init__(
+        self,
+        handler: _RequestHandler,
+        *,
+        scheme: str = "",
+        host: str = "127.0.0.1",
+        port: Optional[int] = None,
+        **kwargs: Any,
+    ) -> None:
+        self._handler = handler
+        super().__init__(scheme=scheme, host=host, port=port, **kwargs)
+
+    async def _make_runner(self, debug: bool = True, **kwargs: Any) -> ServerRunner:
+        srv = Server(self._handler, loop=self._loop, debug=debug, **kwargs)
+        return ServerRunner(srv, debug=debug, **kwargs)
+
+
+class TestClient(Generic[_Request, _ApplicationNone]):
+    """
+    A test client implementation.
+
+    To write functional tests for aiohttp based servers.
+
+    """
+
+    __test__ = False
+
+    @overload
+    def __init__(
+        self: "TestClient[Request, Application]",
+        server: TestServer,
+        *,
+        cookie_jar: Optional[AbstractCookieJar] = None,
+        **kwargs: Any,
+    ) -> None: ...
+    @overload
+    def __init__(
+        self: "TestClient[_Request, None]",
+        server: BaseTestServer,
+        *,
+        cookie_jar: Optional[AbstractCookieJar] = None,
+        **kwargs: Any,
+    ) -> None: ...
+    def __init__(
+        self,
+        server: BaseTestServer,
+        *,
+        cookie_jar: Optional[AbstractCookieJar] = None,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+        **kwargs: Any,
+    ) -> None:
+        if not isinstance(server, BaseTestServer):
+            raise TypeError(
+                "server must be TestServer instance, found type: %r" % type(server)
+            )
+        self._server = server
+        self._loop = loop
+        if cookie_jar is None:
+            cookie_jar = aiohttp.CookieJar(unsafe=True, loop=loop)
+        self._session = ClientSession(loop=loop, cookie_jar=cookie_jar, **kwargs)
+        self._session._retry_connection = False
+        self._closed = False
+        self._responses: List[ClientResponse] = []
+        self._websockets: List[ClientWebSocketResponse] = []
+
+    async def start_server(self) -> None:
+        await self._server.start_server(loop=self._loop)
+
+    @property
+    def host(self) -> str:
+        return self._server.host
+
+    @property
+    def port(self) -> Optional[int]:
+        return self._server.port
+
+    @property
+    def server(self) -> BaseTestServer:
+        return self._server
+
+    @property
+    def app(self) -> _ApplicationNone:
+        return getattr(self._server, "app", None)  # type: ignore[return-value]
+
+    @property
+    def session(self) -> ClientSession:
+        """An internal aiohttp.ClientSession.
+
+        Unlike the methods on the TestClient, client session requests
+        do not automatically include the host in the url queried, and
+        will require an absolute path to the resource.
+
+        """
+        return self._session
+
+    def make_url(self, path: StrOrURL) -> URL:
+        return self._server.make_url(path)
+
+    async def _request(
+        self, method: str, path: StrOrURL, **kwargs: Any
+    ) -> ClientResponse:
+        resp = await self._session.request(method, self.make_url(path), **kwargs)
+        # save it to close later
+        self._responses.append(resp)
+        return resp
+
+    if sys.version_info >= (3, 11) and TYPE_CHECKING:
+
+        def request(
+            self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions]
+        ) -> _RequestContextManager: ...
+
+        def get(
+            self,
+            path: StrOrURL,
+            **kwargs: Unpack[_RequestOptions],
+        ) -> _RequestContextManager: ...
+
+        def options(
+            self,
+            path: StrOrURL,
+            **kwargs: Unpack[_RequestOptions],
+        ) -> _RequestContextManager: ...
+
+        def head(
+            self,
+            path: StrOrURL,
+            **kwargs: Unpack[_RequestOptions],
+        ) -> _RequestContextManager: ...
+
+        def post(
+            self,
+            path: StrOrURL,
+            **kwargs: Unpack[_RequestOptions],
+        ) -> _RequestContextManager: ...
+
+        def put(
+            self,
+            path: StrOrURL,
+            **kwargs: Unpack[_RequestOptions],
+        ) -> _RequestContextManager: ...
+
+        def patch(
+            self,
+            path: StrOrURL,
+            **kwargs: Unpack[_RequestOptions],
+        ) -> _RequestContextManager: ...
+
+        def delete(
+            self,
+            path: StrOrURL,
+            **kwargs: Unpack[_RequestOptions],
+        ) -> _RequestContextManager: ...
+
+    else:
+
+        def request(
+            self, method: str, path: StrOrURL, **kwargs: Any
+        ) -> _RequestContextManager:
+            """Routes a request to tested http server.
+
+            The interface is identical to aiohttp.ClientSession.request,
+            except the loop kwarg is overridden by the instance used by the
+            test server.
+
+            """
+            return _RequestContextManager(self._request(method, path, **kwargs))
+
+        def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+            """Perform an HTTP GET request."""
+            return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))
+
+        def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+            """Perform an HTTP POST request."""
+            return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))
+
+        def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+            """Perform an HTTP OPTIONS request."""
+            return _RequestContextManager(
+                self._request(hdrs.METH_OPTIONS, path, **kwargs)
+            )
+
+        def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+            """Perform an HTTP HEAD request."""
+            return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))
+
+        def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+            """Perform an HTTP PUT request."""
+            return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))
+
+        def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+            """Perform an HTTP PATCH request."""
+            return _RequestContextManager(
+                self._request(hdrs.METH_PATCH, path, **kwargs)
+            )
+
+        def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+            """Perform an HTTP PATCH request."""
+            return _RequestContextManager(
+                self._request(hdrs.METH_DELETE, path, **kwargs)
+            )
+
+    def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
+        """Initiate websocket connection.
+
+        The api corresponds to aiohttp.ClientSession.ws_connect.
+
+        """
+        return _WSRequestContextManager(self._ws_connect(path, **kwargs))
+
+    async def _ws_connect(
+        self, path: StrOrURL, **kwargs: Any
+    ) -> ClientWebSocketResponse:
+        ws = await self._session.ws_connect(self.make_url(path), **kwargs)
+        self._websockets.append(ws)
+        return ws
+
+    async def close(self) -> None:
+        """Close all fixtures created by the test client.
+
+        After that point, the TestClient is no longer usable.
+
+        This is an idempotent function: running close multiple times
+        will not have any additional effects.
+
+        close is also run on exit when used as a(n) (asynchronous)
+        context manager.
+
+        """
+        if not self._closed:
+            for resp in self._responses:
+                resp.close()
+            for ws in self._websockets:
+                await ws.close()
+            await self._session.close()
+            await self._server.close()
+            self._closed = True
+
+    def __enter__(self) -> None:
+        raise TypeError("Use async with instead")
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
+        # __exit__ should exist in pair with __enter__ but never executed
+        pass  # pragma: no cover
+
+    async def __aenter__(self) -> Self:
+        await self.start_server()
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
+        await self.close()
+
+
+class AioHTTPTestCase(IsolatedAsyncioTestCase):
+    """A base class to allow for unittest web applications using aiohttp.
+
+    Provides the following:
+
+    * self.client (aiohttp.test_utils.TestClient): an aiohttp test client.
+    * self.loop (asyncio.BaseEventLoop): the event loop in which the
+        application and server are running.
+    * self.app (aiohttp.web.Application): the application returned by
+        self.get_application()
+
+    Note that the TestClient's methods are asynchronous: you have to
+    execute function on the test client using asynchronous methods.
+    """
+
+    async def get_application(self) -> Application:
+        """Get application.
+
+        This method should be overridden
+        to return the aiohttp.web.Application
+        object to test.
+        """
+        return self.get_app()
+
+    def get_app(self) -> Application:
+        """Obsolete method used to constructing web application.
+
+        Use .get_application() coroutine instead.
+        """
+        raise RuntimeError("Did you forget to define get_application()?")
+
+    async def asyncSetUp(self) -> None:
+        self.loop = asyncio.get_running_loop()
+        return await self.setUpAsync()
+
+    async def setUpAsync(self) -> None:
+        self.app = await self.get_application()
+        self.server = await self.get_server(self.app)
+        self.client = await self.get_client(self.server)
+
+        await self.client.start_server()
+
+    async def asyncTearDown(self) -> None:
+        return await self.tearDownAsync()
+
+    async def tearDownAsync(self) -> None:
+        await self.client.close()
+
+    async def get_server(self, app: Application) -> TestServer:
+        """Return a TestServer instance."""
+        return TestServer(app, loop=self.loop)
+
+    async def get_client(self, server: TestServer) -> TestClient[Request, Application]:
+        """Return a TestClient instance."""
+        return TestClient(server, loop=self.loop)
+
+
+def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any:
+    """
+    A decorator dedicated to use with asynchronous AioHTTPTestCase test methods.
+
+    In 3.8+, this does nothing.
+    """
+    warnings.warn(
+        "Decorator `@unittest_run_loop` is no longer needed in aiohttp 3.8+",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    return func
+
+
+_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]
+
+
+@contextlib.contextmanager
+def loop_context(
+    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
+) -> Iterator[asyncio.AbstractEventLoop]:
+    """A contextmanager that creates an event_loop, for test purposes.
+
+    Handles the creation and cleanup of a test loop.
+    """
+    loop = setup_test_loop(loop_factory)
+    yield loop
+    teardown_test_loop(loop, fast=fast)
+
+
+def setup_test_loop(
+    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
+) -> asyncio.AbstractEventLoop:
+    """Create and return an asyncio.BaseEventLoop instance.
+
+    The caller should also call teardown_test_loop,
+    once they are done with the loop.
+    """
+    loop = loop_factory()
+    asyncio.set_event_loop(loop)
+    return loop
+
+
+def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
+    """Teardown and cleanup an event_loop created by setup_test_loop."""
+    closed = loop.is_closed()
+    if not closed:
+        loop.call_soon(loop.stop)
+        loop.run_forever()
+        loop.close()
+
+    if not fast:
+        gc.collect()
+
+    asyncio.set_event_loop(None)
+
+
+def _create_app_mock() -> mock.MagicMock:
+    def get_dict(app: Any, key: str) -> Any:
+        return app.__app_dict[key]
+
+    def set_dict(app: Any, key: str, value: Any) -> None:
+        app.__app_dict[key] = value
+
+    app = mock.MagicMock(spec=Application)
+    app.__app_dict = {}
+    app.__getitem__ = get_dict
+    app.__setitem__ = set_dict
+
+    app._debug = False
+    app.on_response_prepare = Signal(app)
+    app.on_response_prepare.freeze()
+    return app
+
+
+def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
+    transport = mock.Mock()
+
+    def get_extra_info(key: str) -> Optional[SSLContext]:
+        if key == "sslcontext":
+            return sslcontext
+        else:
+            return None
+
+    transport.get_extra_info.side_effect = get_extra_info
+    return transport
+
+
+def make_mocked_request(
+    method: str,
+    path: str,
+    headers: Any = None,
+    *,
+    match_info: Any = sentinel,
+    version: HttpVersion = HttpVersion(1, 1),
+    closing: bool = False,
+    app: Any = None,
+    writer: Any = sentinel,
+    protocol: Any = sentinel,
+    transport: Any = sentinel,
+    payload: StreamReader = EMPTY_PAYLOAD,
+    sslcontext: Optional[SSLContext] = None,
+    client_max_size: int = 1024**2,
+    loop: Any = ...,
+) -> Request:
+    """Creates mocked web.Request testing purposes.
+
+    Useful in unit tests, when spinning full web server is overkill or
+    specific conditions and errors are hard to trigger.
+    """
+    task = mock.Mock()
+    if loop is ...:
+        # no loop passed, try to get the current one if
+        # its is running as we need a real loop to create
+        # executor jobs to be able to do testing
+        # with a real executor
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = mock.Mock()
+            loop.create_future.return_value = ()
+
+    if version < HttpVersion(1, 1):
+        closing = True
+
+    if headers:
+        headers = CIMultiDictProxy(CIMultiDict(headers))
+        raw_hdrs = tuple(
+            (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
+        )
+    else:
+        headers = CIMultiDictProxy(CIMultiDict())
+        raw_hdrs = ()
+
+    chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()
+
+    message = RawRequestMessage(
+        method,
+        path,
+        version,
+        headers,
+        raw_hdrs,
+        closing,
+        None,
+        False,
+        chunked,
+        URL(path),
+    )
+    if app is None:
+        app = _create_app_mock()
+
+    if transport is sentinel:
+        transport = _create_transport(sslcontext)
+
+    if protocol is sentinel:
+        protocol = mock.Mock()
+        protocol.transport = transport
+
+    if writer is sentinel:
+        writer = mock.Mock()
+        writer.write_headers = make_mocked_coro(None)
+        writer.write = make_mocked_coro(None)
+        writer.write_eof = make_mocked_coro(None)
+        writer.drain = make_mocked_coro(None)
+        writer.transport = transport
+
+    protocol.transport = transport
+    protocol.writer = writer
+
+    req = Request(
+        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
+    )
+
+    match_info = UrlMappingMatchInfo(
+        {} if match_info is sentinel else match_info, mock.Mock()
+    )
+    match_info.add_app(app)
+    req._match_info = match_info
+
+    return req
+
+
+def make_mocked_coro(
+    return_value: Any = sentinel, raise_exception: Any = sentinel
+) -> Any:
+    """Creates a coroutine mock."""
+
+    async def mock_coro(*args: Any, **kwargs: Any) -> Any:
+        if raise_exception is not sentinel:
+            raise raise_exception
+        if not inspect.isawaitable(return_value):
+            return return_value
+        await return_value
+
+    return mock.Mock(wraps=mock_coro)