Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/middleware/base.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/starlette/middleware/base.py  220
1 file changed, 220 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/base.py b/.venv/lib/python3.12/site-packages/starlette/middleware/base.py
new file mode 100644
index 00000000..2a59337e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/base.py
@@ -0,0 +1,220 @@
+from __future__ import annotations
+
+import typing
+
+import anyio
+
+from starlette._utils import collapse_excgroups
+from starlette.requests import ClientDisconnect, Request
+from starlette.responses import AsyncContentStream, Response
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+
+RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
+DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
+T = typing.TypeVar("T")
+
+
+class _CachedRequest(Request):
+ """
+    If the user calls Request.body() from their dispatch function,
+    we cache the entire request body in memory and pass that on to
+    downstream middleware; if they consume Request.stream() instead,
+    we send an empty body downstream so that apps don't hang forever.
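+
+    For example, a dispatch function that calls await request.body() still
+    lets the downstream app receive the full body via wrapped_receive(),
+    whereas fully consuming Request.stream() means the downstream app gets
+    only an empty-body http.request message.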
+ """
+
+ def __init__(self, scope: Scope, receive: Receive):
+ super().__init__(scope, receive)
+ self._wrapped_rcv_disconnected = False
+ self._wrapped_rcv_consumed = False
+ self._wrapped_rc_stream = self.stream()
+
+ async def wrapped_receive(self) -> Message:
+ # wrapped_rcv state 1: disconnected
+ if self._wrapped_rcv_disconnected:
+ # we've already sent a disconnect to the downstream app
+ # we don't need to wait to get another one
+ # (although most ASGI servers will just keep sending it)
+ return {"type": "http.disconnect"}
+        # wrapped_rcv state 2: consumed but not yet disconnected
+        if self._wrapped_rcv_consumed:
+            # since the downstream app has consumed the body, all that is
+            # left is to send it a disconnect
+ if self._is_disconnected:
+ # the middleware has already seen the disconnect
+ # since we know the client is disconnected no need to wait
+ # for the message
+ self._wrapped_rcv_disconnected = True
+ return {"type": "http.disconnect"}
+ # we don't know yet if the client is disconnected or not
+ # so we'll wait until we get that message
+ msg = await self.receive()
+ if msg["type"] != "http.disconnect": # pragma: no cover
+ # at this point a disconnect is all that we should be receiving
+ # if we get something else, things went wrong somewhere
+ raise RuntimeError(f"Unexpected message received: {msg['type']}")
+ self._wrapped_rcv_disconnected = True
+ return msg
+
+ # wrapped_rcv state 3: not yet consumed
+ if getattr(self, "_body", None) is not None:
+ # body() was called, we return it even if the client disconnected
+ self._wrapped_rcv_consumed = True
+ return {
+ "type": "http.request",
+ "body": self._body,
+ "more_body": False,
+ }
+ elif self._stream_consumed:
+ # stream() was called to completion
+ # return an empty body so that downstream apps don't hang
+ # waiting for a disconnect
+ self._wrapped_rcv_consumed = True
+ return {
+ "type": "http.request",
+ "body": b"",
+ "more_body": False,
+ }
+ else:
+ # body() was never called and stream() wasn't consumed
+ try:
+ stream = self.stream()
+ chunk = await stream.__anext__()
+ self._wrapped_rcv_consumed = self._stream_consumed
+ return {
+ "type": "http.request",
+ "body": chunk,
+ "more_body": not self._stream_consumed,
+ }
+ except ClientDisconnect:
+ self._wrapped_rcv_disconnected = True
+ return {"type": "http.disconnect"}
+
+
+class BaseHTTPMiddleware:
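+    """
+    Base class for writing HTTP middleware as a single dispatch coroutine
+    instead of a raw ASGI callable.
+
+    A minimal usage sketch (CustomHeaderMiddleware and the header name are
+    illustrative, not part of this module):
+
+        class CustomHeaderMiddleware(BaseHTTPMiddleware):
+            async def dispatch(self, request, call_next):
+                response = await call_next(request)
+                response.headers["X-Custom"] = "example"
+                return response
+    """
+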
+ def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
+ self.app = app
+ self.dispatch_func = self.dispatch if dispatch is None else dispatch
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] != "http":
+ await self.app(scope, receive, send)
+ return
+
+ request = _CachedRequest(scope, receive)
+ wrapped_receive = request.wrapped_receive
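+        # set once the response has been fully sent to the client; used to
+        # unblock receive_or_disconnect() below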
+ response_sent = anyio.Event()
+ app_exc: Exception | None = None
+
+ async def call_next(request: Request) -> Response:
+ async def receive_or_disconnect() -> Message:
+ if response_sent.is_set():
+ return {"type": "http.disconnect"}
+
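+            # race response_sent.wait() against wrapped_receive: wrap()
+            # cancels the task group as soon as either completes, so we never
+            # block on a receive after the response has already been sent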
+ async with anyio.create_task_group() as task_group:
+
+ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
+ result = await func()
+ task_group.cancel_scope.cancel()
+ return result
+
+ task_group.start_soon(wrap, response_sent.wait)
+ message = await wrap(wrapped_receive)
+
+ if response_sent.is_set():
+ return {"type": "http.disconnect"}
+
+ return message
+
+ async def send_no_error(message: Message) -> None:
+ try:
+ await send_stream.send(message)
+ except anyio.BrokenResourceError:
+ # recv_stream has been closed, i.e. response_sent has been set.
+ return
+
+ async def coro() -> None:
+ nonlocal app_exc
+
+ with send_stream:
+ try:
+ await self.app(scope, receive_or_disconnect, send_no_error)
+ except Exception as exc:
+ app_exc = exc
+
+ task_group.start_soon(coro)
+
+ try:
+ message = await recv_stream.receive()
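+                # _StreamingResponse may emit an http.response.debug message
+                # before the real response start; if so, keep its info and
+                # read the next message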
+ info = message.get("info", None)
+ if message["type"] == "http.response.debug" and info is not None:
+ message = await recv_stream.receive()
+ except anyio.EndOfStream:
+ if app_exc is not None:
+ raise app_exc
+ raise RuntimeError("No response returned.")
+
+ assert message["type"] == "http.response.start"
+
+ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
+ async for message in recv_stream:
+ assert message["type"] == "http.response.body"
+ body = message.get("body", b"")
+ if body:
+ yield body
+ if not message.get("more_body", False):
+ break
+
+ response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
+ response.raw_headers = message["headers"]
+ return response
+
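+        # the wrapped app's sends flow through this zero-buffer memory
+        # channel and are consumed one message at a time by call_next() above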
+ streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
+ send_stream, recv_stream = streams
+ with recv_stream, send_stream, collapse_excgroups():
+ async with anyio.create_task_group() as task_group:
+ response = await self.dispatch_func(request, call_next)
+ await response(scope, wrapped_receive, send)
+ response_sent.set()
+ recv_stream.close()
+
+ if app_exc is not None:
+ raise app_exc
+
+ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
+ raise NotImplementedError() # pragma: no cover
+
+
+class _StreamingResponse(Response):
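+    """
+    Response that streams its body from an async iterator; when info is set,
+    it first re-emits an http.response.debug message so the debug payload
+    survives the middleware round-trip.
+    """
+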
+ def __init__(
+ self,
+ content: AsyncContentStream,
+ status_code: int = 200,
+ headers: typing.Mapping[str, str] | None = None,
+ media_type: str | None = None,
+ info: typing.Mapping[str, typing.Any] | None = None,
+ ) -> None:
+ self.info = info
+ self.body_iterator = content
+ self.status_code = status_code
+ self.media_type = media_type
+ self.init_headers(headers)
+ self.background = None
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if self.info is not None:
+ await send({"type": "http.response.debug", "info": self.info})
+ await send(
+ {
+ "type": "http.response.start",
+ "status": self.status_code,
+ "headers": self.raw_headers,
+ }
+ )
+
+ async for chunk in self.body_iterator:
+ await send({"type": "http.response.body", "body": chunk, "more_body": True})
+
+ await send({"type": "http.response.body", "body": b"", "more_body": False})
+
+ if self.background:
+ await self.background()