aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/starlette/middleware
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/middleware')
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/__init__.py42
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py52
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/base.py220
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/cors.py172
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/errors.py260
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py72
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/gzip.py141
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/httpsredirect.py19
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/sessions.py85
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py60
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py152
11 files changed, 1275 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/__init__.py b/.venv/lib/python3.12/site-packages/starlette/middleware/__init__.py
new file mode 100644
index 00000000..b99538a2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/__init__.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+import sys
+from collections.abc import Iterator
+from typing import Any, Protocol
+
+if sys.version_info >= (3, 10): # pragma: no cover
+ from typing import ParamSpec
+else: # pragma: no cover
+ from typing_extensions import ParamSpec
+
+from starlette.types import ASGIApp
+
+P = ParamSpec("P")
+
+
+class _MiddlewareFactory(Protocol[P]):
+ def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover
+
+
+class Middleware:
+ def __init__(
+ self,
+ cls: _MiddlewareFactory[P],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> None:
+ self.cls = cls
+ self.args = args
+ self.kwargs = kwargs
+
+ def __iter__(self) -> Iterator[Any]:
+ as_tuple = (self.cls, self.args, self.kwargs)
+ return iter(as_tuple)
+
+ def __repr__(self) -> str:
+ class_name = self.__class__.__name__
+ args_strings = [f"{value!r}" for value in self.args]
+ option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
+ name = getattr(self.cls, "__name__", "")
+ args_repr = ", ".join([name] + args_strings + option_strings)
+ return f"{class_name}({args_repr})"
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py b/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py
new file mode 100644
index 00000000..8555ee07
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+import typing
+
+from starlette.authentication import (
+ AuthCredentials,
+ AuthenticationBackend,
+ AuthenticationError,
+ UnauthenticatedUser,
+)
+from starlette.requests import HTTPConnection
+from starlette.responses import PlainTextResponse, Response
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+
+class AuthenticationMiddleware:
+ def __init__(
+ self,
+ app: ASGIApp,
+ backend: AuthenticationBackend,
+ on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
+ ) -> None:
+ self.app = app
+ self.backend = backend
+ self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = (
+ on_error if on_error is not None else self.default_on_error
+ )
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] not in ["http", "websocket"]:
+ await self.app(scope, receive, send)
+ return
+
+ conn = HTTPConnection(scope)
+ try:
+ auth_result = await self.backend.authenticate(conn)
+ except AuthenticationError as exc:
+ response = self.on_error(conn, exc)
+ if scope["type"] == "websocket":
+ await send({"type": "websocket.close", "code": 1000})
+ else:
+ await response(scope, receive, send)
+ return
+
+ if auth_result is None:
+ auth_result = AuthCredentials(), UnauthenticatedUser()
+ scope["auth"], scope["user"] = auth_result
+ await self.app(scope, receive, send)
+
+ @staticmethod
+ def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
+ return PlainTextResponse(str(exc), status_code=400)
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/base.py b/.venv/lib/python3.12/site-packages/starlette/middleware/base.py
new file mode 100644
index 00000000..2a59337e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/base.py
@@ -0,0 +1,220 @@
+from __future__ import annotations
+
+import typing
+
+import anyio
+
+from starlette._utils import collapse_excgroups
+from starlette.requests import ClientDisconnect, Request
+from starlette.responses import AsyncContentStream, Response
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+
+RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
+DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
+T = typing.TypeVar("T")
+
+
+class _CachedRequest(Request):
+ """
+ If the user calls Request.body() from their dispatch function
+ we cache the entire request body in memory and pass that to downstream middlewares,
+ but if they call Request.stream() then all we do is send an
+ empty body so that downstream things don't hang forever.
+ """
+
+ def __init__(self, scope: Scope, receive: Receive):
+ super().__init__(scope, receive)
+ self._wrapped_rcv_disconnected = False
+ self._wrapped_rcv_consumed = False
+ self._wrapped_rc_stream = self.stream()
+
+ async def wrapped_receive(self) -> Message:
+ # wrapped_rcv state 1: disconnected
+ if self._wrapped_rcv_disconnected:
+ # we've already sent a disconnect to the downstream app
+ # we don't need to wait to get another one
+ # (although most ASGI servers will just keep sending it)
+ return {"type": "http.disconnect"}
+ # wrapped_rcv state 1: consumed but not yet disconnected
+ if self._wrapped_rcv_consumed:
+ # since the downstream app has consumed us all that is left
+ # is to send it a disconnect
+ if self._is_disconnected:
+ # the middleware has already seen the disconnect
+ # since we know the client is disconnected no need to wait
+ # for the message
+ self._wrapped_rcv_disconnected = True
+ return {"type": "http.disconnect"}
+ # we don't know yet if the client is disconnected or not
+ # so we'll wait until we get that message
+ msg = await self.receive()
+ if msg["type"] != "http.disconnect": # pragma: no cover
+ # at this point a disconnect is all that we should be receiving
+ # if we get something else, things went wrong somewhere
+ raise RuntimeError(f"Unexpected message received: {msg['type']}")
+ self._wrapped_rcv_disconnected = True
+ return msg
+
+ # wrapped_rcv state 3: not yet consumed
+ if getattr(self, "_body", None) is not None:
+ # body() was called, we return it even if the client disconnected
+ self._wrapped_rcv_consumed = True
+ return {
+ "type": "http.request",
+ "body": self._body,
+ "more_body": False,
+ }
+ elif self._stream_consumed:
+ # stream() was called to completion
+ # return an empty body so that downstream apps don't hang
+ # waiting for a disconnect
+ self._wrapped_rcv_consumed = True
+ return {
+ "type": "http.request",
+ "body": b"",
+ "more_body": False,
+ }
+ else:
+ # body() was never called and stream() wasn't consumed
+ try:
+ stream = self.stream()
+ chunk = await stream.__anext__()
+ self._wrapped_rcv_consumed = self._stream_consumed
+ return {
+ "type": "http.request",
+ "body": chunk,
+ "more_body": not self._stream_consumed,
+ }
+ except ClientDisconnect:
+ self._wrapped_rcv_disconnected = True
+ return {"type": "http.disconnect"}
+
+
+class BaseHTTPMiddleware:
+ def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
+ self.app = app
+ self.dispatch_func = self.dispatch if dispatch is None else dispatch
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] != "http":
+ await self.app(scope, receive, send)
+ return
+
+ request = _CachedRequest(scope, receive)
+ wrapped_receive = request.wrapped_receive
+ response_sent = anyio.Event()
+ app_exc: Exception | None = None
+
+ async def call_next(request: Request) -> Response:
+ async def receive_or_disconnect() -> Message:
+ if response_sent.is_set():
+ return {"type": "http.disconnect"}
+
+ async with anyio.create_task_group() as task_group:
+
+ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
+ result = await func()
+ task_group.cancel_scope.cancel()
+ return result
+
+ task_group.start_soon(wrap, response_sent.wait)
+ message = await wrap(wrapped_receive)
+
+ if response_sent.is_set():
+ return {"type": "http.disconnect"}
+
+ return message
+
+ async def send_no_error(message: Message) -> None:
+ try:
+ await send_stream.send(message)
+ except anyio.BrokenResourceError:
+ # recv_stream has been closed, i.e. response_sent has been set.
+ return
+
+ async def coro() -> None:
+ nonlocal app_exc
+
+ with send_stream:
+ try:
+ await self.app(scope, receive_or_disconnect, send_no_error)
+ except Exception as exc:
+ app_exc = exc
+
+ task_group.start_soon(coro)
+
+ try:
+ message = await recv_stream.receive()
+ info = message.get("info", None)
+ if message["type"] == "http.response.debug" and info is not None:
+ message = await recv_stream.receive()
+ except anyio.EndOfStream:
+ if app_exc is not None:
+ raise app_exc
+ raise RuntimeError("No response returned.")
+
+ assert message["type"] == "http.response.start"
+
+ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
+ async for message in recv_stream:
+ assert message["type"] == "http.response.body"
+ body = message.get("body", b"")
+ if body:
+ yield body
+ if not message.get("more_body", False):
+ break
+
+ response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
+ response.raw_headers = message["headers"]
+ return response
+
+ streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
+ send_stream, recv_stream = streams
+ with recv_stream, send_stream, collapse_excgroups():
+ async with anyio.create_task_group() as task_group:
+ response = await self.dispatch_func(request, call_next)
+ await response(scope, wrapped_receive, send)
+ response_sent.set()
+ recv_stream.close()
+
+ if app_exc is not None:
+ raise app_exc
+
+ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
+ raise NotImplementedError() # pragma: no cover
+
+
+class _StreamingResponse(Response):
+ def __init__(
+ self,
+ content: AsyncContentStream,
+ status_code: int = 200,
+ headers: typing.Mapping[str, str] | None = None,
+ media_type: str | None = None,
+ info: typing.Mapping[str, typing.Any] | None = None,
+ ) -> None:
+ self.info = info
+ self.body_iterator = content
+ self.status_code = status_code
+ self.media_type = media_type
+ self.init_headers(headers)
+ self.background = None
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if self.info is not None:
+ await send({"type": "http.response.debug", "info": self.info})
+ await send(
+ {
+ "type": "http.response.start",
+ "status": self.status_code,
+ "headers": self.raw_headers,
+ }
+ )
+
+ async for chunk in self.body_iterator:
+ await send({"type": "http.response.body", "body": chunk, "more_body": True})
+
+ await send({"type": "http.response.body", "body": b"", "more_body": False})
+
+ if self.background:
+ await self.background()
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py b/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py
new file mode 100644
index 00000000..61502691
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py
@@ -0,0 +1,172 @@
+from __future__ import annotations
+
+import functools
+import re
+import typing
+
+from starlette.datastructures import Headers, MutableHeaders
+from starlette.responses import PlainTextResponse, Response
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+
+ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
+SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}
+
+
+class CORSMiddleware:
+ def __init__(
+ self,
+ app: ASGIApp,
+ allow_origins: typing.Sequence[str] = (),
+ allow_methods: typing.Sequence[str] = ("GET",),
+ allow_headers: typing.Sequence[str] = (),
+ allow_credentials: bool = False,
+ allow_origin_regex: str | None = None,
+ expose_headers: typing.Sequence[str] = (),
+ max_age: int = 600,
+ ) -> None:
+ if "*" in allow_methods:
+ allow_methods = ALL_METHODS
+
+ compiled_allow_origin_regex = None
+ if allow_origin_regex is not None:
+ compiled_allow_origin_regex = re.compile(allow_origin_regex)
+
+ allow_all_origins = "*" in allow_origins
+ allow_all_headers = "*" in allow_headers
+ preflight_explicit_allow_origin = not allow_all_origins or allow_credentials
+
+ simple_headers = {}
+ if allow_all_origins:
+ simple_headers["Access-Control-Allow-Origin"] = "*"
+ if allow_credentials:
+ simple_headers["Access-Control-Allow-Credentials"] = "true"
+ if expose_headers:
+ simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
+
+ preflight_headers = {}
+ if preflight_explicit_allow_origin:
+ # The origin value will be set in preflight_response() if it is allowed.
+ preflight_headers["Vary"] = "Origin"
+ else:
+ preflight_headers["Access-Control-Allow-Origin"] = "*"
+ preflight_headers.update(
+ {
+ "Access-Control-Allow-Methods": ", ".join(allow_methods),
+ "Access-Control-Max-Age": str(max_age),
+ }
+ )
+ allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
+ if allow_headers and not allow_all_headers:
+ preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
+ if allow_credentials:
+ preflight_headers["Access-Control-Allow-Credentials"] = "true"
+
+ self.app = app
+ self.allow_origins = allow_origins
+ self.allow_methods = allow_methods
+ self.allow_headers = [h.lower() for h in allow_headers]
+ self.allow_all_origins = allow_all_origins
+ self.allow_all_headers = allow_all_headers
+ self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
+ self.allow_origin_regex = compiled_allow_origin_regex
+ self.simple_headers = simple_headers
+ self.preflight_headers = preflight_headers
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] != "http": # pragma: no cover
+ await self.app(scope, receive, send)
+ return
+
+ method = scope["method"]
+ headers = Headers(scope=scope)
+ origin = headers.get("origin")
+
+ if origin is None:
+ await self.app(scope, receive, send)
+ return
+
+ if method == "OPTIONS" and "access-control-request-method" in headers:
+ response = self.preflight_response(request_headers=headers)
+ await response(scope, receive, send)
+ return
+
+ await self.simple_response(scope, receive, send, request_headers=headers)
+
+ def is_allowed_origin(self, origin: str) -> bool:
+ if self.allow_all_origins:
+ return True
+
+ if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
+ return True
+
+ return origin in self.allow_origins
+
+ def preflight_response(self, request_headers: Headers) -> Response:
+ requested_origin = request_headers["origin"]
+ requested_method = request_headers["access-control-request-method"]
+ requested_headers = request_headers.get("access-control-request-headers")
+
+ headers = dict(self.preflight_headers)
+ failures = []
+
+ if self.is_allowed_origin(origin=requested_origin):
+ if self.preflight_explicit_allow_origin:
+ # The "else" case is already accounted for in self.preflight_headers
+ # and the value would be "*".
+ headers["Access-Control-Allow-Origin"] = requested_origin
+ else:
+ failures.append("origin")
+
+ if requested_method not in self.allow_methods:
+ failures.append("method")
+
+ # If we allow all headers, then we have to mirror back any requested
+ # headers in the response.
+ if self.allow_all_headers and requested_headers is not None:
+ headers["Access-Control-Allow-Headers"] = requested_headers
+ elif requested_headers is not None:
+ for header in [h.lower() for h in requested_headers.split(",")]:
+ if header.strip() not in self.allow_headers:
+ failures.append("headers")
+ break
+
+ # We don't strictly need to use 400 responses here, since its up to
+ # the browser to enforce the CORS policy, but its more informative
+ # if we do.
+ if failures:
+ failure_text = "Disallowed CORS " + ", ".join(failures)
+ return PlainTextResponse(failure_text, status_code=400, headers=headers)
+
+ return PlainTextResponse("OK", status_code=200, headers=headers)
+
+ async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
+ send = functools.partial(self.send, send=send, request_headers=request_headers)
+ await self.app(scope, receive, send)
+
+ async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
+ if message["type"] != "http.response.start":
+ await send(message)
+ return
+
+ message.setdefault("headers", [])
+ headers = MutableHeaders(scope=message)
+ headers.update(self.simple_headers)
+ origin = request_headers["Origin"]
+ has_cookie = "cookie" in request_headers
+
+ # If request includes any cookie headers, then we must respond
+ # with the specific origin instead of '*'.
+ if self.allow_all_origins and has_cookie:
+ self.allow_explicit_origin(headers, origin)
+
+ # If we only allow specific origins, then we have to mirror back
+ # the Origin header in the response.
+ elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
+ self.allow_explicit_origin(headers, origin)
+
+ await send(message)
+
+ @staticmethod
+ def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
+ headers["Access-Control-Allow-Origin"] = origin
+ headers.add_vary_header("Origin")
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py b/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py
new file mode 100644
index 00000000..76ad776b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py
@@ -0,0 +1,260 @@
+from __future__ import annotations
+
+import html
+import inspect
+import sys
+import traceback
+import typing
+
+from starlette._utils import is_async_callable
+from starlette.concurrency import run_in_threadpool
+from starlette.requests import Request
+from starlette.responses import HTMLResponse, PlainTextResponse, Response
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+
+STYLES = """
+p {
+ color: #211c1c;
+}
+.traceback-container {
+ border: 1px solid #038BB8;
+}
+.traceback-title {
+ background-color: #038BB8;
+ color: lemonchiffon;
+ padding: 12px;
+ font-size: 20px;
+ margin-top: 0px;
+}
+.frame-line {
+ padding-left: 10px;
+ font-family: monospace;
+}
+.frame-filename {
+ font-family: monospace;
+}
+.center-line {
+ background-color: #038BB8;
+ color: #f9f6e1;
+ padding: 5px 0px 5px 5px;
+}
+.lineno {
+ margin-right: 5px;
+}
+.frame-title {
+ font-weight: unset;
+ padding: 10px 10px 10px 10px;
+ background-color: #E4F4FD;
+ margin-right: 10px;
+ color: #191f21;
+ font-size: 17px;
+ border: 1px solid #c7dce8;
+}
+.collapse-btn {
+ float: right;
+ padding: 0px 5px 1px 5px;
+ border: solid 1px #96aebb;
+ cursor: pointer;
+}
+.collapsed {
+ display: none;
+}
+.source-code {
+ font-family: courier;
+ font-size: small;
+ padding-bottom: 10px;
+}
+"""
+
+JS = """
+<script type="text/javascript">
+ function collapse(element){
+ const frameId = element.getAttribute("data-frame-id");
+ const frame = document.getElementById(frameId);
+
+ if (frame.classList.contains("collapsed")){
+ element.innerHTML = "&#8210;";
+ frame.classList.remove("collapsed");
+ } else {
+ element.innerHTML = "+";
+ frame.classList.add("collapsed");
+ }
+ }
+</script>
+"""
+
+TEMPLATE = """
+<html>
+ <head>
+ <style type='text/css'>
+ {styles}
+ </style>
+ <title>Starlette Debugger</title>
+ </head>
+ <body>
+ <h1>500 Server Error</h1>
+ <h2>{error}</h2>
+ <div class="traceback-container">
+ <p class="traceback-title">Traceback</p>
+ <div>{exc_html}</div>
+ </div>
+ {js}
+ </body>
+</html>
+"""
+
+FRAME_TEMPLATE = """
+<div>
+ <p class="frame-title">File <span class="frame-filename">{frame_filename}</span>,
+ line <i>{frame_lineno}</i>,
+ in <b>{frame_name}</b>
+ <span class="collapse-btn" data-frame-id="{frame_filename}-{frame_lineno}" onclick="collapse(this)">{collapse_button}</span>
+ </p>
+ <div id="{frame_filename}-{frame_lineno}" class="source-code {collapsed}">{code_context}</div>
+</div>
+""" # noqa: E501
+
+LINE = """
+<p><span class="frame-line">
+<span class="lineno">{lineno}.</span> {line}</span></p>
+"""
+
+CENTER_LINE = """
+<p class="center-line"><span class="frame-line center-line">
+<span class="lineno">{lineno}.</span> {line}</span></p>
+"""
+
+
+class ServerErrorMiddleware:
+ """
+ Handles returning 500 responses when a server error occurs.
+
+ If 'debug' is set, then traceback responses will be returned,
+ otherwise the designated 'handler' will be called.
+
+ This middleware class should generally be used to wrap *everything*
+ else up, so that unhandled exceptions anywhere in the stack
+ always result in an appropriate 500 response.
+ """
+
+ def __init__(
+ self,
+ app: ASGIApp,
+ handler: typing.Callable[[Request, Exception], typing.Any] | None = None,
+ debug: bool = False,
+ ) -> None:
+ self.app = app
+ self.handler = handler
+ self.debug = debug
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] != "http":
+ await self.app(scope, receive, send)
+ return
+
+ response_started = False
+
+ async def _send(message: Message) -> None:
+ nonlocal response_started, send
+
+ if message["type"] == "http.response.start":
+ response_started = True
+ await send(message)
+
+ try:
+ await self.app(scope, receive, _send)
+ except Exception as exc:
+ request = Request(scope)
+ if self.debug:
+ # In debug mode, return traceback responses.
+ response = self.debug_response(request, exc)
+ elif self.handler is None:
+ # Use our default 500 error handler.
+ response = self.error_response(request, exc)
+ else:
+ # Use an installed 500 error handler.
+ if is_async_callable(self.handler):
+ response = await self.handler(request, exc)
+ else:
+ response = await run_in_threadpool(self.handler, request, exc)
+
+ if not response_started:
+ await response(scope, receive, send)
+
+ # We always continue to raise the exception.
+ # This allows servers to log the error, or allows test clients
+ # to optionally raise the error within the test case.
+ raise exc
+
+ def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
+ values = {
+ # HTML escape - line could contain < or >
+ "line": html.escape(line).replace(" ", "&nbsp"),
+ "lineno": (frame_lineno - frame_index) + index,
+ }
+
+ if index != frame_index:
+ return LINE.format(**values)
+ return CENTER_LINE.format(**values)
+
+ def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
+ code_context = "".join(
+ self.format_line(
+ index,
+ line,
+ frame.lineno,
+ frame.index, # type: ignore[arg-type]
+ )
+ for index, line in enumerate(frame.code_context or [])
+ )
+
+ values = {
+ # HTML escape - filename could contain < or >, especially if it's a virtual
+ # file e.g. <stdin> in the REPL
+ "frame_filename": html.escape(frame.filename),
+ "frame_lineno": frame.lineno,
+ # HTML escape - if you try very hard it's possible to name a function with <
+ # or >
+ "frame_name": html.escape(frame.function),
+ "code_context": code_context,
+ "collapsed": "collapsed" if is_collapsed else "",
+ "collapse_button": "+" if is_collapsed else "&#8210;",
+ }
+ return FRAME_TEMPLATE.format(**values)
+
+ def generate_html(self, exc: Exception, limit: int = 7) -> str:
+ traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
+
+ exc_html = ""
+ is_collapsed = False
+ exc_traceback = exc.__traceback__
+ if exc_traceback is not None:
+ frames = inspect.getinnerframes(exc_traceback, limit)
+ for frame in reversed(frames):
+ exc_html += self.generate_frame_html(frame, is_collapsed)
+ is_collapsed = True
+
+ if sys.version_info >= (3, 13): # pragma: no cover
+ exc_type_str = traceback_obj.exc_type_str
+ else: # pragma: no cover
+ exc_type_str = traceback_obj.exc_type.__name__
+
+ # escape error class and text
+ error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"
+
+ return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)
+
+ def generate_plain_text(self, exc: Exception) -> str:
+ return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
+
+ def debug_response(self, request: Request, exc: Exception) -> Response:
+ accept = request.headers.get("accept", "")
+
+ if "text/html" in accept:
+ content = self.generate_html(exc)
+ return HTMLResponse(content, status_code=500)
+ content = self.generate_plain_text(exc)
+ return PlainTextResponse(content, status_code=500)
+
+ def error_response(self, request: Request, exc: Exception) -> Response:
+ return PlainTextResponse("Internal Server Error", status_code=500)
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py b/.venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py
new file mode 100644
index 00000000..981d2fca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+import typing
+
+from starlette._exception_handler import (
+ ExceptionHandlers,
+ StatusHandlers,
+ wrap_app_handling_exceptions,
+)
+from starlette.exceptions import HTTPException, WebSocketException
+from starlette.requests import Request
+from starlette.responses import PlainTextResponse, Response
+from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.websockets import WebSocket
+
+
+class ExceptionMiddleware:
+ def __init__(
+ self,
+ app: ASGIApp,
+ handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None,
+ debug: bool = False,
+ ) -> None:
+ self.app = app
+ self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
+ self._status_handlers: StatusHandlers = {}
+ self._exception_handlers: ExceptionHandlers = {
+ HTTPException: self.http_exception,
+ WebSocketException: self.websocket_exception,
+ }
+ if handlers is not None: # pragma: no branch
+ for key, value in handlers.items():
+ self.add_exception_handler(key, value)
+
+ def add_exception_handler(
+ self,
+ exc_class_or_status_code: int | type[Exception],
+ handler: typing.Callable[[Request, Exception], Response],
+ ) -> None:
+ if isinstance(exc_class_or_status_code, int):
+ self._status_handlers[exc_class_or_status_code] = handler
+ else:
+ assert issubclass(exc_class_or_status_code, Exception)
+ self._exception_handlers[exc_class_or_status_code] = handler
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] not in ("http", "websocket"):
+ await self.app(scope, receive, send)
+ return
+
+ scope["starlette.exception_handlers"] = (
+ self._exception_handlers,
+ self._status_handlers,
+ )
+
+ conn: Request | WebSocket
+ if scope["type"] == "http":
+ conn = Request(scope, receive, send)
+ else:
+ conn = WebSocket(scope, receive, send)
+
+ await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
+
+ def http_exception(self, request: Request, exc: Exception) -> Response:
+ assert isinstance(exc, HTTPException)
+ if exc.status_code in {204, 304}:
+ return Response(status_code=exc.status_code, headers=exc.headers)
+ return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
+
+ async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
+ assert isinstance(exc, WebSocketException)
+ await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/gzip.py b/.venv/lib/python3.12/site-packages/starlette/middleware/gzip.py
new file mode 100644
index 00000000..c7fd5b77
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/gzip.py
@@ -0,0 +1,141 @@
+import gzip
+import io
+import typing
+
+from starlette.datastructures import Headers, MutableHeaders
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+
+DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
+
+
+class GZipMiddleware:
+ def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
+ self.app = app
+ self.minimum_size = minimum_size
+ self.compresslevel = compresslevel
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] != "http": # pragma: no cover
+ await self.app(scope, receive, send)
+ return
+
+ headers = Headers(scope=scope)
+ responder: ASGIApp
+ if "gzip" in headers.get("Accept-Encoding", ""):
+ responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
+ else:
+ responder = IdentityResponder(self.app, self.minimum_size)
+
+ await responder(scope, receive, send)
+
+
+class IdentityResponder:
+ content_encoding: str
+
+ def __init__(self, app: ASGIApp, minimum_size: int) -> None:
+ self.app = app
+ self.minimum_size = minimum_size
+ self.send: Send = unattached_send
+ self.initial_message: Message = {}
+ self.started = False
+ self.content_encoding_set = False
+ self.content_type_is_excluded = False
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ self.send = send
+ await self.app(scope, receive, self.send_with_compression)
+
+ async def send_with_compression(self, message: Message) -> None:
+ message_type = message["type"]
+ if message_type == "http.response.start":
+ # Don't send the initial message until we've determined how to
+ # modify the outgoing headers correctly.
+ self.initial_message = message
+ headers = Headers(raw=self.initial_message["headers"])
+ self.content_encoding_set = "content-encoding" in headers
+ self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
+ elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
+ if not self.started:
+ self.started = True
+ await self.send(self.initial_message)
+ await self.send(message)
+ elif message_type == "http.response.body" and not self.started:
+ self.started = True
+ body = message.get("body", b"")
+ more_body = message.get("more_body", False)
+ if len(body) < self.minimum_size and not more_body:
+ # Don't apply compression to small outgoing responses.
+ await self.send(self.initial_message)
+ await self.send(message)
+ elif not more_body:
+ # Standard response.
+ body = self.apply_compression(body, more_body=False)
+
+ headers = MutableHeaders(raw=self.initial_message["headers"])
+ headers.add_vary_header("Accept-Encoding")
+ if body != message["body"]:
+ headers["Content-Encoding"] = self.content_encoding
+ headers["Content-Length"] = str(len(body))
+ message["body"] = body
+
+ await self.send(self.initial_message)
+ await self.send(message)
+ else:
+ # Initial body in streaming response.
+ body = self.apply_compression(body, more_body=True)
+
+ headers = MutableHeaders(raw=self.initial_message["headers"])
+ headers.add_vary_header("Accept-Encoding")
+ if body != message["body"]:
+ headers["Content-Encoding"] = self.content_encoding
+ del headers["Content-Length"]
+ message["body"] = body
+
+ await self.send(self.initial_message)
+ await self.send(message)
+ elif message_type == "http.response.body": # pragma: no branch
+ # Remaining body in streaming response.
+ body = message.get("body", b"")
+ more_body = message.get("more_body", False)
+
+ message["body"] = self.apply_compression(body, more_body=more_body)
+
+ await self.send(message)
+
+ def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
+ """Apply compression on the response body.
+
+ If more_body is False, any compression file should be closed. If it
+ isn't, it won't be closed automatically until all background tasks
+ complete.
+ """
+ return body
+
+
+class GZipResponder(IdentityResponder):
+ content_encoding = "gzip"
+
+ def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
+ super().__init__(app, minimum_size)
+
+ self.gzip_buffer = io.BytesIO()
+ self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ with self.gzip_buffer, self.gzip_file:
+ await super().__call__(scope, receive, send)
+
+ def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
+ self.gzip_file.write(body)
+ if not more_body:
+ self.gzip_file.close()
+
+ body = self.gzip_buffer.getvalue()
+ self.gzip_buffer.seek(0)
+ self.gzip_buffer.truncate()
+
+ return body
+
+
+async def unattached_send(message: Message) -> typing.NoReturn:
+ raise RuntimeError("send awaitable not set") # pragma: no cover
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/httpsredirect.py b/.venv/lib/python3.12/site-packages/starlette/middleware/httpsredirect.py
new file mode 100644
index 00000000..a8359067
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/httpsredirect.py
@@ -0,0 +1,19 @@
+from starlette.datastructures import URL
+from starlette.responses import RedirectResponse
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+
+class HTTPSRedirectMiddleware:
+ def __init__(self, app: ASGIApp) -> None:
+ self.app = app
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] in ("http", "websocket") and scope["scheme"] in ("http", "ws"):
+ url = URL(scope=scope)
+ redirect_scheme = {"http": "https", "ws": "wss"}[url.scheme]
+ netloc = url.hostname if url.port in (80, 443) else url.netloc
+ url = url.replace(scheme=redirect_scheme, netloc=netloc)
+ response = RedirectResponse(url, status_code=307)
+ await response(scope, receive, send)
+ else:
+ await self.app(scope, receive, send)
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/sessions.py b/.venv/lib/python3.12/site-packages/starlette/middleware/sessions.py
new file mode 100644
index 00000000..5f9fcd88
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/sessions.py
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+import json
+import typing
+from base64 import b64decode, b64encode
+
+import itsdangerous
+from itsdangerous.exc import BadSignature
+
+from starlette.datastructures import MutableHeaders, Secret
+from starlette.requests import HTTPConnection
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+
+
+class SessionMiddleware:
+ def __init__(
+ self,
+ app: ASGIApp,
+ secret_key: str | Secret,
+ session_cookie: str = "session",
+ max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
+ path: str = "/",
+ same_site: typing.Literal["lax", "strict", "none"] = "lax",
+ https_only: bool = False,
+ domain: str | None = None,
+ ) -> None:
+ self.app = app
+ self.signer = itsdangerous.TimestampSigner(str(secret_key))
+ self.session_cookie = session_cookie
+ self.max_age = max_age
+ self.path = path
+ self.security_flags = "httponly; samesite=" + same_site
+ if https_only: # Secure flag can be used with HTTPS only
+ self.security_flags += "; secure"
+ if domain is not None:
+ self.security_flags += f"; domain={domain}"
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] not in ("http", "websocket"): # pragma: no cover
+ await self.app(scope, receive, send)
+ return
+
+ connection = HTTPConnection(scope)
+ initial_session_was_empty = True
+
+ if self.session_cookie in connection.cookies:
+ data = connection.cookies[self.session_cookie].encode("utf-8")
+ try:
+ data = self.signer.unsign(data, max_age=self.max_age)
+ scope["session"] = json.loads(b64decode(data))
+ initial_session_was_empty = False
+ except BadSignature:
+ scope["session"] = {}
+ else:
+ scope["session"] = {}
+
+ async def send_wrapper(message: Message) -> None:
+ if message["type"] == "http.response.start":
+ if scope["session"]:
+ # We have session data to persist.
+ data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
+ data = self.signer.sign(data)
+ headers = MutableHeaders(scope=message)
+ header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
+ session_cookie=self.session_cookie,
+ data=data.decode("utf-8"),
+ path=self.path,
+ max_age=f"Max-Age={self.max_age}; " if self.max_age else "",
+ security_flags=self.security_flags,
+ )
+ headers.append("Set-Cookie", header_value)
+ elif not initial_session_was_empty:
+ # The session has been cleared.
+ headers = MutableHeaders(scope=message)
+ header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
+ session_cookie=self.session_cookie,
+ data="null",
+ path=self.path,
+ expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ",
+ security_flags=self.security_flags,
+ )
+ headers.append("Set-Cookie", header_value)
+ await send(message)
+
+ await self.app(scope, receive, send_wrapper)
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py
new file mode 100644
index 00000000..2d1c999e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+import typing
+
+from starlette.datastructures import URL, Headers
+from starlette.responses import PlainTextResponse, RedirectResponse, Response
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
+
+
+class TrustedHostMiddleware:
+ def __init__(
+ self,
+ app: ASGIApp,
+ allowed_hosts: typing.Sequence[str] | None = None,
+ www_redirect: bool = True,
+ ) -> None:
+ if allowed_hosts is None:
+ allowed_hosts = ["*"]
+
+ for pattern in allowed_hosts:
+ assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
+ if pattern.startswith("*") and pattern != "*":
+ assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
+ self.app = app
+ self.allowed_hosts = list(allowed_hosts)
+ self.allow_any = "*" in allowed_hosts
+ self.www_redirect = www_redirect
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if self.allow_any or scope["type"] not in (
+ "http",
+ "websocket",
+ ): # pragma: no cover
+ await self.app(scope, receive, send)
+ return
+
+ headers = Headers(scope=scope)
+ host = headers.get("host", "").split(":")[0]
+ is_valid_host = False
+ found_www_redirect = False
+ for pattern in self.allowed_hosts:
+ if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
+ is_valid_host = True
+ break
+ elif "www." + host == pattern:
+ found_www_redirect = True
+
+ if is_valid_host:
+ await self.app(scope, receive, send)
+ else:
+ response: Response
+ if found_www_redirect and self.www_redirect:
+ url = URL(scope=scope)
+ redirect_url = url.replace(netloc="www." + url.netloc)
+ response = RedirectResponse(url=str(redirect_url))
+ else:
+ response = PlainTextResponse("Invalid host header", status_code=400)
+ await response(scope, receive, send)
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py b/.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py
new file mode 100644
index 00000000..6e0a3fae
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py
@@ -0,0 +1,152 @@
+from __future__ import annotations
+
+import io
+import math
+import sys
+import typing
+import warnings
+
+import anyio
+from anyio.abc import ObjectReceiveStream, ObjectSendStream
+
+from starlette.types import Receive, Scope, Send
+
+warnings.warn(
+ "starlette.middleware.wsgi is deprecated and will be removed in a future release. "
+ "Please refer to https://github.com/abersheeran/a2wsgi as a replacement.",
+ DeprecationWarning,
+)
+
+
+def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]:
+ """
+ Builds a scope and request body into a WSGI environ object.
+ """
+
+ script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
+ path_info = scope["path"].encode("utf8").decode("latin1")
+ if path_info.startswith(script_name):
+ path_info = path_info[len(script_name) :]
+
+ environ = {
+ "REQUEST_METHOD": scope["method"],
+ "SCRIPT_NAME": script_name,
+ "PATH_INFO": path_info,
+ "QUERY_STRING": scope["query_string"].decode("ascii"),
+ "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
+ "wsgi.version": (1, 0),
+ "wsgi.url_scheme": scope.get("scheme", "http"),
+ "wsgi.input": io.BytesIO(body),
+ "wsgi.errors": sys.stdout,
+ "wsgi.multithread": True,
+ "wsgi.multiprocess": True,
+ "wsgi.run_once": False,
+ }
+
+ # Get server name and port - required in WSGI, not in ASGI
+ server = scope.get("server") or ("localhost", 80)
+ environ["SERVER_NAME"] = server[0]
+ environ["SERVER_PORT"] = server[1]
+
+ # Get client IP address
+ if scope.get("client"):
+ environ["REMOTE_ADDR"] = scope["client"][0]
+
+ # Go through headers and make them into environ entries
+ for name, value in scope.get("headers", []):
+ name = name.decode("latin1")
+ if name == "content-length":
+ corrected_name = "CONTENT_LENGTH"
+ elif name == "content-type":
+ corrected_name = "CONTENT_TYPE"
+ else:
+ corrected_name = f"HTTP_{name}".upper().replace("-", "_")
+ # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in
+ # case
+ value = value.decode("latin1")
+ if corrected_name in environ:
+ value = environ[corrected_name] + "," + value
+ environ[corrected_name] = value
+ return environ
+
+
+class WSGIMiddleware:
+ def __init__(self, app: typing.Callable[..., typing.Any]) -> None:
+ self.app = app
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ assert scope["type"] == "http"
+ responder = WSGIResponder(self.app, scope)
+ await responder(receive, send)
+
+
+class WSGIResponder:
+ stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
+ stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
+
+ def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None:
+ self.app = app
+ self.scope = scope
+ self.status = None
+ self.response_headers = None
+ self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
+ self.response_started = False
+ self.exc_info: typing.Any = None
+
+ async def __call__(self, receive: Receive, send: Send) -> None:
+ body = b""
+ more_body = True
+ while more_body:
+ message = await receive()
+ body += message.get("body", b"")
+ more_body = message.get("more_body", False)
+ environ = build_environ(self.scope, body)
+
+ async with anyio.create_task_group() as task_group:
+ task_group.start_soon(self.sender, send)
+ async with self.stream_send:
+ await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
+ if self.exc_info is not None:
+ raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
+
+ async def sender(self, send: Send) -> None:
+ async with self.stream_receive:
+ async for message in self.stream_receive:
+ await send(message)
+
+ def start_response(
+ self,
+ status: str,
+ response_headers: list[tuple[str, str]],
+ exc_info: typing.Any = None,
+ ) -> None:
+ self.exc_info = exc_info
+ if not self.response_started: # pragma: no branch
+ self.response_started = True
+ status_code_string, _ = status.split(" ", 1)
+ status_code = int(status_code_string)
+ headers = [
+ (name.strip().encode("ascii").lower(), value.strip().encode("ascii"))
+ for name, value in response_headers
+ ]
+ anyio.from_thread.run(
+ self.stream_send.send,
+ {
+ "type": "http.response.start",
+ "status": status_code,
+ "headers": headers,
+ },
+ )
+
+ def wsgi(
+ self,
+ environ: dict[str, typing.Any],
+ start_response: typing.Callable[..., typing.Any],
+ ) -> None:
+ for chunk in self.app(environ, start_response):
+ anyio.from_thread.run(
+ self.stream_send.send,
+ {"type": "http.response.body", "body": chunk, "more_body": True},
+ )
+
+ anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})