about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/starlette/testclient.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/starlette/testclient.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/testclient.py')
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/testclient.py731
1 files changed, 731 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/testclient.py b/.venv/lib/python3.12/site-packages/starlette/testclient.py
new file mode 100644
index 00000000..d54025e5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/testclient.py
@@ -0,0 +1,731 @@
+from __future__ import annotations
+
+import contextlib
+import inspect
+import io
+import json
+import math
+import sys
+import typing
+import warnings
+from concurrent.futures import Future
+from types import GeneratorType
+from urllib.parse import unquote, urljoin
+
+import anyio
+import anyio.abc
+import anyio.from_thread
+from anyio.streams.stapled import StapledObjectStream
+
+from starlette._utils import is_async_callable
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from starlette.websockets import WebSocketDisconnect
+
+if sys.version_info >= (3, 10):  # pragma: no cover
+    from typing import TypeGuard
+else:  # pragma: no cover
+    from typing_extensions import TypeGuard
+
+try:
+    import httpx
+except ModuleNotFoundError:  # pragma: no cover
+    raise RuntimeError(
+        "The starlette.testclient module requires the httpx package to be installed.\n"
+        "You can install this with:\n"
+        "    $ pip install httpx\n"
+    )
+_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]
+
+ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
+ASGI2App = typing.Callable[[Scope], ASGIInstance]
+ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
+
+
+_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]]
+
+
+def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
+    if inspect.isclass(app):
+        return hasattr(app, "__await__")
+    return is_async_callable(app)
+
+
+class _WrapASGI2:
+    """
+    Provide an ASGI3 interface onto an ASGI2 app.
+    """
+
+    def __init__(self, app: ASGI2App) -> None:
+        self.app = app
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        instance = self.app(scope)
+        await instance(receive, send)
+
+
+class _AsyncBackend(typing.TypedDict):
+    backend: str
+    backend_options: dict[str, typing.Any]
+
+
+class _Upgrade(Exception):
+    def __init__(self, session: WebSocketTestSession) -> None:
+        self.session = session
+
+
+class WebSocketDenialResponse(  # type: ignore[misc]
+    httpx.Response,
+    WebSocketDisconnect,
+):
+    """
+    A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
+    `WebSocket` is closed before being accepted with a `send_denial_response()`.
+    """
+
+
+class WebSocketTestSession:
+    def __init__(
+        self,
+        app: ASGI3App,
+        scope: Scope,
+        portal_factory: _PortalFactoryType,
+    ) -> None:
+        self.app = app
+        self.scope = scope
+        self.accepted_subprotocol = None
+        self.portal_factory = portal_factory
+        self.extra_headers = None
+
+    def __enter__(self) -> WebSocketTestSession:
+        with contextlib.ExitStack() as stack:
+            self.portal = portal = stack.enter_context(self.portal_factory())
+            fut, cs = portal.start_task(self._run)
+            stack.callback(fut.result)
+            stack.callback(portal.call, cs.cancel)
+            self.send({"type": "websocket.connect"})
+            message = self.receive()
+            self._raise_on_close(message)
+            self.accepted_subprotocol = message.get("subprotocol", None)
+            self.extra_headers = message.get("headers", None)
+            stack.callback(self.close, 1000)
+            self.exit_stack = stack.pop_all()
+            return self
+
+    def __exit__(self, *args: typing.Any) -> bool | None:
+        return self.exit_stack.__exit__(*args)
+
+    async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
+        """
+        The sub-thread in which the websocket session runs.
+        """
+        send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
+        send_tx, send_rx = send
+        receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
+        receive_tx, receive_rx = receive
+        with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
+            self._receive_tx = receive_tx
+            self._send_rx = send_rx
+            task_status.started(cs)
+            await self.app(self.scope, receive_rx.receive, send_tx.send)
+
+            # wait for cs.cancel to be called before closing streams
+            await anyio.sleep_forever()
+
+    def _raise_on_close(self, message: Message) -> None:
+        if message["type"] == "websocket.close":
+            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
+        elif message["type"] == "websocket.http.response.start":
+            status_code: int = message["status"]
+            headers: list[tuple[bytes, bytes]] = message["headers"]
+            body: list[bytes] = []
+            while True:
+                message = self.receive()
+                assert message["type"] == "websocket.http.response.body"
+                body.append(message["body"])
+                if not message.get("more_body", False):
+                    break
+            raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
+
+    def send(self, message: Message) -> None:
+        self.portal.call(self._receive_tx.send, message)
+
+    def send_text(self, data: str) -> None:
+        self.send({"type": "websocket.receive", "text": data})
+
+    def send_bytes(self, data: bytes) -> None:
+        self.send({"type": "websocket.receive", "bytes": data})
+
+    def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
+        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
+        if mode == "text":
+            self.send({"type": "websocket.receive", "text": text})
+        else:
+            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
+
+    def close(self, code: int = 1000, reason: str | None = None) -> None:
+        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
+
+    def receive(self) -> Message:
+        return self.portal.call(self._send_rx.receive)
+
+    def receive_text(self) -> str:
+        message = self.receive()
+        self._raise_on_close(message)
+        return typing.cast(str, message["text"])
+
+    def receive_bytes(self) -> bytes:
+        message = self.receive()
+        self._raise_on_close(message)
+        return typing.cast(bytes, message["bytes"])
+
+    def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
+        message = self.receive()
+        self._raise_on_close(message)
+        if mode == "text":
+            text = message["text"]
+        else:
+            text = message["bytes"].decode("utf-8")
+        return json.loads(text)
+
+
+class _TestClientTransport(httpx.BaseTransport):
+    def __init__(
+        self,
+        app: ASGI3App,
+        portal_factory: _PortalFactoryType,
+        raise_server_exceptions: bool = True,
+        root_path: str = "",
+        *,
+        client: tuple[str, int],
+        app_state: dict[str, typing.Any],
+    ) -> None:
+        self.app = app
+        self.raise_server_exceptions = raise_server_exceptions
+        self.root_path = root_path
+        self.portal_factory = portal_factory
+        self.app_state = app_state
+        self.client = client
+
+    def handle_request(self, request: httpx.Request) -> httpx.Response:
+        scheme = request.url.scheme
+        netloc = request.url.netloc.decode(encoding="ascii")
+        path = request.url.path
+        raw_path = request.url.raw_path
+        query = request.url.query.decode(encoding="ascii")
+
+        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
+
+        if ":" in netloc:
+            host, port_string = netloc.split(":", 1)
+            port = int(port_string)
+        else:
+            host = netloc
+            port = default_port
+
+        # Include the 'host' header.
+        if "host" in request.headers:
+            headers: list[tuple[bytes, bytes]] = []
+        elif port == default_port:  # pragma: no cover
+            headers = [(b"host", host.encode())]
+        else:  # pragma: no cover
+            headers = [(b"host", (f"{host}:{port}").encode())]
+
+        # Include other request headers.
+        headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
+
+        scope: dict[str, typing.Any]
+
+        if scheme in {"ws", "wss"}:
+            subprotocol = request.headers.get("sec-websocket-protocol", None)
+            if subprotocol is None:
+                subprotocols: typing.Sequence[str] = []
+            else:
+                subprotocols = [value.strip() for value in subprotocol.split(",")]
+            scope = {
+                "type": "websocket",
+                "path": unquote(path),
+                "raw_path": raw_path.split(b"?", 1)[0],
+                "root_path": self.root_path,
+                "scheme": scheme,
+                "query_string": query.encode(),
+                "headers": headers,
+                "client": self.client,
+                "server": [host, port],
+                "subprotocols": subprotocols,
+                "state": self.app_state.copy(),
+                "extensions": {"websocket.http.response": {}},
+            }
+            session = WebSocketTestSession(self.app, scope, self.portal_factory)
+            raise _Upgrade(session)
+
+        scope = {
+            "type": "http",
+            "http_version": "1.1",
+            "method": request.method,
+            "path": unquote(path),
+            "raw_path": raw_path.split(b"?", 1)[0],
+            "root_path": self.root_path,
+            "scheme": scheme,
+            "query_string": query.encode(),
+            "headers": headers,
+            "client": self.client,
+            "server": [host, port],
+            "extensions": {"http.response.debug": {}},
+            "state": self.app_state.copy(),
+        }
+
+        request_complete = False
+        response_started = False
+        response_complete: anyio.Event
+        raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
+        template = None
+        context = None
+
+        async def receive() -> Message:
+            nonlocal request_complete
+
+            if request_complete:
+                if not response_complete.is_set():
+                    await response_complete.wait()
+                return {"type": "http.disconnect"}
+
+            body = request.read()
+            if isinstance(body, str):
+                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
+            elif body is None:
+                body_bytes = b""  # pragma: no cover
+            elif isinstance(body, GeneratorType):
+                try:  # pragma: no cover
+                    chunk = body.send(None)
+                    if isinstance(chunk, str):
+                        chunk = chunk.encode("utf-8")
+                    return {"type": "http.request", "body": chunk, "more_body": True}
+                except StopIteration:  # pragma: no cover
+                    request_complete = True
+                    return {"type": "http.request", "body": b""}
+            else:
+                body_bytes = body
+
+            request_complete = True
+            return {"type": "http.request", "body": body_bytes}
+
+        async def send(message: Message) -> None:
+            nonlocal raw_kwargs, response_started, template, context
+
+            if message["type"] == "http.response.start":
+                assert not response_started, 'Received multiple "http.response.start" messages.'
+                raw_kwargs["status_code"] = message["status"]
+                raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
+                response_started = True
+            elif message["type"] == "http.response.body":
+                assert response_started, 'Received "http.response.body" without "http.response.start".'
+                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
+                body = message.get("body", b"")
+                more_body = message.get("more_body", False)
+                if request.method != "HEAD":
+                    raw_kwargs["stream"].write(body)
+                if not more_body:
+                    raw_kwargs["stream"].seek(0)
+                    response_complete.set()
+            elif message["type"] == "http.response.debug":
+                template = message["info"]["template"]
+                context = message["info"]["context"]
+
+        try:
+            with self.portal_factory() as portal:
+                response_complete = portal.call(anyio.Event)
+                portal.call(self.app, scope, receive, send)
+        except BaseException as exc:
+            if self.raise_server_exceptions:
+                raise exc
+
+        if self.raise_server_exceptions:
+            assert response_started, "TestClient did not receive any response."
+        elif not response_started:
+            raw_kwargs = {
+                "status_code": 500,
+                "headers": [],
+                "stream": io.BytesIO(),
+            }
+
+        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())
+
+        response = httpx.Response(**raw_kwargs, request=request)
+        if template is not None:
+            response.template = template  # type: ignore[attr-defined]
+            response.context = context  # type: ignore[attr-defined]
+        return response
+
+
+class TestClient(httpx.Client):
+    __test__ = False
+    task: Future[None]
+    portal: anyio.abc.BlockingPortal | None = None
+
+    def __init__(
+        self,
+        app: ASGIApp,
+        base_url: str = "http://testserver",
+        raise_server_exceptions: bool = True,
+        root_path: str = "",
+        backend: typing.Literal["asyncio", "trio"] = "asyncio",
+        backend_options: dict[str, typing.Any] | None = None,
+        cookies: httpx._types.CookieTypes | None = None,
+        headers: dict[str, str] | None = None,
+        follow_redirects: bool = True,
+        client: tuple[str, int] = ("testclient", 50000),
+    ) -> None:
+        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
+        if _is_asgi3(app):
+            asgi_app = app
+        else:
+            app = typing.cast(ASGI2App, app)  # type: ignore[assignment]
+            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
+        self.app = asgi_app
+        self.app_state: dict[str, typing.Any] = {}
+        transport = _TestClientTransport(
+            self.app,
+            portal_factory=self._portal_factory,
+            raise_server_exceptions=raise_server_exceptions,
+            root_path=root_path,
+            app_state=self.app_state,
+            client=client,
+        )
+        if headers is None:
+            headers = {}
+        headers.setdefault("user-agent", "testclient")
+        super().__init__(
+            base_url=base_url,
+            headers=headers,
+            transport=transport,
+            follow_redirects=follow_redirects,
+            cookies=cookies,
+        )
+
+    @contextlib.contextmanager
+    def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
+        if self.portal is not None:
+            yield self.portal
+        else:
+            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
+                yield portal
+
+    def request(  # type: ignore[override]
+        self,
+        method: str,
+        url: httpx._types.URLTypes,
+        *,
+        content: httpx._types.RequestContent | None = None,
+        data: _RequestData | None = None,
+        files: httpx._types.RequestFiles | None = None,
+        json: typing.Any = None,
+        params: httpx._types.QueryParamTypes | None = None,
+        headers: httpx._types.HeaderTypes | None = None,
+        cookies: httpx._types.CookieTypes | None = None,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict[str, typing.Any] | None = None,
+    ) -> httpx.Response:
+        if timeout is not httpx.USE_CLIENT_DEFAULT:
+            warnings.warn(
+                "You should not use the 'timeout' argument with the TestClient. "
+                "See https://github.com/encode/starlette/issues/1108 for more information.",
+                DeprecationWarning,
+            )
+        url = self._merge_url(url)
+        return super().request(
+            method,
+            url,
+            content=content,
+            data=data,
+            files=files,
+            json=json,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=follow_redirects,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def get(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        params: httpx._types.QueryParamTypes | None = None,
+        headers: httpx._types.HeaderTypes | None = None,
+        cookies: httpx._types.CookieTypes | None = None,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict[str, typing.Any] | None = None,
+    ) -> httpx.Response:
+        return super().get(
+            url,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=follow_redirects,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def options(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        params: httpx._types.QueryParamTypes | None = None,
+        headers: httpx._types.HeaderTypes | None = None,
+        cookies: httpx._types.CookieTypes | None = None,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict[str, typing.Any] | None = None,
+    ) -> httpx.Response:
+        return super().options(
+            url,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=follow_redirects,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def head(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        params: httpx._types.QueryParamTypes | None = None,
+        headers: httpx._types.HeaderTypes | None = None,
+        cookies: httpx._types.CookieTypes | None = None,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict[str, typing.Any] | None = None,
+    ) -> httpx.Response:
+        return super().head(
+            url,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=follow_redirects,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def post(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        content: httpx._types.RequestContent | None = None,
+        data: _RequestData | None = None,
+        files: httpx._types.RequestFiles | None = None,
+        json: typing.Any = None,
+        params: httpx._types.QueryParamTypes | None = None,
+        headers: httpx._types.HeaderTypes | None = None,
+        cookies: httpx._types.CookieTypes | None = None,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict[str, typing.Any] | None = None,
+    ) -> httpx.Response:
+        return super().post(
+            url,
+            content=content,
+            data=data,
+            files=files,
+            json=json,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=follow_redirects,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def put(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        content: httpx._types.RequestContent | None = None,
+        data: _RequestData | None = None,
+        files: httpx._types.RequestFiles | None = None,
+        json: typing.Any = None,
+        params: httpx._types.QueryParamTypes | None = None,
+        headers: httpx._types.HeaderTypes | None = None,
+        cookies: httpx._types.CookieTypes | None = None,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict[str, typing.Any] | None = None,
+    ) -> httpx.Response:
+        return super().put(
+            url,
+            content=content,
+            data=data,
+            files=files,
+            json=json,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=follow_redirects,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def patch(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        content: httpx._types.RequestContent | None = None,
+        data: _RequestData | None = None,
+        files: httpx._types.RequestFiles | None = None,
+        json: typing.Any = None,
+        params: httpx._types.QueryParamTypes | None = None,
+        headers: httpx._types.HeaderTypes | None = None,
+        cookies: httpx._types.CookieTypes | None = None,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict[str, typing.Any] | None = None,
+    ) -> httpx.Response:
+        return super().patch(
+            url,
+            content=content,
+            data=data,
+            files=files,
+            json=json,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=follow_redirects,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def delete(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        params: httpx._types.QueryParamTypes | None = None,
+        headers: httpx._types.HeaderTypes | None = None,
+        cookies: httpx._types.CookieTypes | None = None,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict[str, typing.Any] | None = None,
+    ) -> httpx.Response:
+        return super().delete(
+            url,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=follow_redirects,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def websocket_connect(
+        self,
+        url: str,
+        subprotocols: typing.Sequence[str] | None = None,
+        **kwargs: typing.Any,
+    ) -> WebSocketTestSession:
+        url = urljoin("ws://testserver", url)
+        headers = kwargs.get("headers", {})
+        headers.setdefault("connection", "upgrade")
+        headers.setdefault("sec-websocket-key", "testserver==")
+        headers.setdefault("sec-websocket-version", "13")
+        if subprotocols is not None:
+            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
+        kwargs["headers"] = headers
+        try:
+            super().request("GET", url, **kwargs)
+        except _Upgrade as exc:
+            session = exc.session
+        else:
+            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover
+
+        return session
+
+    def __enter__(self) -> TestClient:
+        with contextlib.ExitStack() as stack:
+            self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
+
+            @stack.callback
+            def reset_portal() -> None:
+                self.portal = None
+
+            send: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None] = (
+                anyio.create_memory_object_stream(math.inf)
+            )
+            receive: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]] = (
+                anyio.create_memory_object_stream(math.inf)
+            )
+            for channel in (*send, *receive):
+                stack.callback(channel.close)
+            self.stream_send = StapledObjectStream(*send)
+            self.stream_receive = StapledObjectStream(*receive)
+            self.task = portal.start_task_soon(self.lifespan)
+            portal.call(self.wait_startup)
+
+            @stack.callback
+            def wait_shutdown() -> None:
+                portal.call(self.wait_shutdown)
+
+            self.exit_stack = stack.pop_all()
+
+        return self
+
+    def __exit__(self, *args: typing.Any) -> None:
+        self.exit_stack.close()
+
+    async def lifespan(self) -> None:
+        scope = {"type": "lifespan", "state": self.app_state}
+        try:
+            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
+        finally:
+            await self.stream_send.send(None)
+
+    async def wait_startup(self) -> None:
+        await self.stream_receive.send({"type": "lifespan.startup"})
+
+        async def receive() -> typing.Any:
+            message = await self.stream_send.receive()
+            if message is None:
+                self.task.result()
+            return message
+
+        message = await receive()
+        assert message["type"] in (
+            "lifespan.startup.complete",
+            "lifespan.startup.failed",
+        )
+        if message["type"] == "lifespan.startup.failed":
+            await receive()
+
+    async def wait_shutdown(self) -> None:
+        async def receive() -> typing.Any:
+            message = await self.stream_send.receive()
+            if message is None:
+                self.task.result()
+            return message
+
+        await self.stream_receive.send({"type": "lifespan.shutdown"})
+        message = await receive()
+        assert message["type"] in (
+            "lifespan.shutdown.complete",
+            "lifespan.shutdown.failed",
+        )
+        if message["type"] == "lifespan.shutdown.failed":
+            await receive()