diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/starlette/middleware | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/middleware')
11 files changed, 1275 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/__init__.py b/.venv/lib/python3.12/site-packages/starlette/middleware/__init__.py new file mode 100644 index 00000000..b99538a2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/__init__.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import sys +from collections.abc import Iterator +from typing import Any, Protocol + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +from starlette.types import ASGIApp + +P = ParamSpec("P") + + +class _MiddlewareFactory(Protocol[P]): + def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover + + +class Middleware: + def __init__( + self, + cls: _MiddlewareFactory[P], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + self.cls = cls + self.args = args + self.kwargs = kwargs + + def __iter__(self) -> Iterator[Any]: + as_tuple = (self.cls, self.args, self.kwargs) + return iter(as_tuple) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + args_strings = [f"{value!r}" for value in self.args] + option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()] + name = getattr(self.cls, "__name__", "") + args_repr = ", ".join([name] + args_strings + option_strings) + return f"{class_name}({args_repr})" diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py b/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py new file mode 100644 index 00000000..8555ee07 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import typing + +from starlette.authentication import ( + AuthCredentials, + AuthenticationBackend, + AuthenticationError, + UnauthenticatedUser, +) +from starlette.requests import HTTPConnection +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send + + +class AuthenticationMiddleware: + def __init__( + self, + app: ASGIApp, + backend: AuthenticationBackend, + on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None, + ) -> None: + self.app = app + self.backend = backend + self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = ( + on_error if on_error is not None else self.default_on_error + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ["http", "websocket"]: + await self.app(scope, receive, send) + return + + conn = HTTPConnection(scope) + try: + auth_result = await self.backend.authenticate(conn) + except AuthenticationError as exc: + response = self.on_error(conn, exc) + if scope["type"] == "websocket": + await send({"type": "websocket.close", "code": 1000}) + else: + await response(scope, receive, send) + return + + if auth_result is None: + auth_result = AuthCredentials(), UnauthenticatedUser() + scope["auth"], scope["user"] = auth_result + await self.app(scope, receive, send) + + @staticmethod + def default_on_error(conn: HTTPConnection, exc: Exception) -> Response: + return PlainTextResponse(str(exc), status_code=400) diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/base.py b/.venv/lib/python3.12/site-packages/starlette/middleware/base.py new file mode 100644 index 00000000..2a59337e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/base.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import typing + +import anyio + +from starlette._utils import collapse_excgroups +from starlette.requests import ClientDisconnect, Request +from starlette.responses import AsyncContentStream, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] +DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]] +T = typing.TypeVar("T") + + +class _CachedRequest(Request): + """ + If the user calls Request.body() from their dispatch function + we cache the entire request body in memory and pass that to downstream middlewares, + but if they call Request.stream() then all we do is send an + empty body so that downstream things don't hang forever. + """ + + def __init__(self, scope: Scope, receive: Receive): + super().__init__(scope, receive) + self._wrapped_rcv_disconnected = False + self._wrapped_rcv_consumed = False + self._wrapped_rc_stream = self.stream() + + async def wrapped_receive(self) -> Message: + # wrapped_rcv state 1: disconnected + if self._wrapped_rcv_disconnected: + # we've already sent a disconnect to the downstream app + # we don't need to wait to get another one + # (although most ASGI servers will just keep sending it) + return {"type": "http.disconnect"} + # wrapped_rcv state 1: consumed but not yet disconnected + if self._wrapped_rcv_consumed: + # since the downstream app has consumed us all that is left + # is to send it a disconnect + if self._is_disconnected: + # the middleware has already seen the disconnect + # since we know the client is disconnected no need to wait + # for the message + self._wrapped_rcv_disconnected = True + return {"type": "http.disconnect"} + # we don't know yet if the client is disconnected or not + # so we'll wait until we get that message + msg = await self.receive() + if msg["type"] != "http.disconnect": # pragma: no cover + # at this point a disconnect is all that we should be receiving + # if we get something else, things went wrong somewhere + raise RuntimeError(f"Unexpected message received: {msg['type']}") + self._wrapped_rcv_disconnected = True + return msg + + # wrapped_rcv state 3: not yet consumed + if getattr(self, "_body", None) is not None: + # body() was called, we return it even if the client disconnected + self._wrapped_rcv_consumed = True + return { + "type": "http.request", + "body": self._body, + "more_body": False, + } + elif self._stream_consumed: + # stream() was called to completion + # return an empty body so that downstream apps don't hang + # waiting for a disconnect + self._wrapped_rcv_consumed = True + return { + "type": "http.request", + "body": b"", + "more_body": False, + } + else: + # body() was never called and stream() wasn't consumed + try: + stream = self.stream() + chunk = await stream.__anext__() + self._wrapped_rcv_consumed = self._stream_consumed + return { + "type": "http.request", + "body": chunk, + "more_body": not self._stream_consumed, + } + except ClientDisconnect: + self._wrapped_rcv_disconnected = True + return {"type": "http.disconnect"} + + +class BaseHTTPMiddleware: + def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None: + self.app = app + self.dispatch_func = self.dispatch if dispatch is None else dispatch + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = _CachedRequest(scope, receive) + wrapped_receive = request.wrapped_receive + response_sent = anyio.Event() + app_exc: Exception | None = None + + async def call_next(request: Request) -> Response: + async def receive_or_disconnect() -> Message: + if response_sent.is_set(): + return {"type": "http.disconnect"} + + async with anyio.create_task_group() as task_group: + + async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T: + result = await func() + task_group.cancel_scope.cancel() + return result + + task_group.start_soon(wrap, response_sent.wait) + message = await wrap(wrapped_receive) + + if response_sent.is_set(): + return {"type": "http.disconnect"} + + return message + + async def send_no_error(message: Message) -> None: + try: + await send_stream.send(message) + except anyio.BrokenResourceError: + # recv_stream has been closed, i.e. response_sent has been set. + return + + async def coro() -> None: + nonlocal app_exc + + with send_stream: + try: + await self.app(scope, receive_or_disconnect, send_no_error) + except Exception as exc: + app_exc = exc + + task_group.start_soon(coro) + + try: + message = await recv_stream.receive() + info = message.get("info", None) + if message["type"] == "http.response.debug" and info is not None: + message = await recv_stream.receive() + except anyio.EndOfStream: + if app_exc is not None: + raise app_exc + raise RuntimeError("No response returned.") + + assert message["type"] == "http.response.start" + + async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async for message in recv_stream: + assert message["type"] == "http.response.body" + body = message.get("body", b"") + if body: + yield body + if not message.get("more_body", False): + break + + response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info) + response.raw_headers = message["headers"] + return response + + streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream() + send_stream, recv_stream = streams + with recv_stream, send_stream, collapse_excgroups(): + async with anyio.create_task_group() as task_group: + response = await self.dispatch_func(request, call_next) + await response(scope, wrapped_receive, send) + response_sent.set() + recv_stream.close() + + if app_exc is not None: + raise app_exc + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + raise NotImplementedError() # pragma: no cover + + +class _StreamingResponse(Response): + def __init__( + self, + content: AsyncContentStream, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + info: typing.Mapping[str, typing.Any] | None = None, + ) -> None: + self.info = info + self.body_iterator = content + self.status_code = status_code + self.media_type = media_type + self.init_headers(headers) + self.background = None + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if self.info is not None: + await send({"type": "http.response.debug", "info": self.info}) + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + + async for chunk in self.body_iterator: + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + if self.background: + await self.background() diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py b/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py new file mode 100644 index 00000000..61502691 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import functools +import re +import typing + +from starlette.datastructures import Headers, MutableHeaders +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT") +SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} + + +class CORSMiddleware: + def __init__( + self, + app: ASGIApp, + allow_origins: typing.Sequence[str] = (), + allow_methods: typing.Sequence[str] = ("GET",), + allow_headers: typing.Sequence[str] = (), + allow_credentials: bool = False, + allow_origin_regex: str | None = None, + expose_headers: typing.Sequence[str] = (), + max_age: int = 600, + ) -> None: + if "*" in allow_methods: + allow_methods = ALL_METHODS + + compiled_allow_origin_regex = None + if allow_origin_regex is not None: + compiled_allow_origin_regex = re.compile(allow_origin_regex) + + allow_all_origins = "*" in allow_origins + allow_all_headers = "*" in allow_headers + preflight_explicit_allow_origin = not allow_all_origins or allow_credentials + + simple_headers = {} + if allow_all_origins: + simple_headers["Access-Control-Allow-Origin"] = "*" + if allow_credentials: + simple_headers["Access-Control-Allow-Credentials"] = "true" + if expose_headers: + simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers) + + preflight_headers = {} + if preflight_explicit_allow_origin: + # The origin value will be set in preflight_response() if it is allowed. + preflight_headers["Vary"] = "Origin" + else: + preflight_headers["Access-Control-Allow-Origin"] = "*" + preflight_headers.update( + { + "Access-Control-Allow-Methods": ", ".join(allow_methods), + "Access-Control-Max-Age": str(max_age), + } + ) + allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers)) + if allow_headers and not allow_all_headers: + preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers) + if allow_credentials: + preflight_headers["Access-Control-Allow-Credentials"] = "true" + + self.app = app + self.allow_origins = allow_origins + self.allow_methods = allow_methods + self.allow_headers = [h.lower() for h in allow_headers] + self.allow_all_origins = allow_all_origins + self.allow_all_headers = allow_all_headers + self.preflight_explicit_allow_origin = preflight_explicit_allow_origin + self.allow_origin_regex = compiled_allow_origin_regex + self.simple_headers = simple_headers + self.preflight_headers = preflight_headers + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": # pragma: no cover + await self.app(scope, receive, send) + return + + method = scope["method"] + headers = Headers(scope=scope) + origin = headers.get("origin") + + if origin is None: + await self.app(scope, receive, send) + return + + if method == "OPTIONS" and "access-control-request-method" in headers: + response = self.preflight_response(request_headers=headers) + await response(scope, receive, send) + return + + await self.simple_response(scope, receive, send, request_headers=headers) + + def is_allowed_origin(self, origin: str) -> bool: + if self.allow_all_origins: + return True + + if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin): + return True + + return origin in self.allow_origins + + def preflight_response(self, request_headers: Headers) -> Response: + requested_origin = request_headers["origin"] + requested_method = request_headers["access-control-request-method"] + requested_headers = request_headers.get("access-control-request-headers") + + headers = dict(self.preflight_headers) + failures = [] + + if self.is_allowed_origin(origin=requested_origin): + if self.preflight_explicit_allow_origin: + # The "else" case is already accounted for in self.preflight_headers + # and the value would be "*". + headers["Access-Control-Allow-Origin"] = requested_origin + else: + failures.append("origin") + + if requested_method not in self.allow_methods: + failures.append("method") + + # If we allow all headers, then we have to mirror back any requested + # headers in the response. + if self.allow_all_headers and requested_headers is not None: + headers["Access-Control-Allow-Headers"] = requested_headers + elif requested_headers is not None: + for header in [h.lower() for h in requested_headers.split(",")]: + if header.strip() not in self.allow_headers: + failures.append("headers") + break + + # We don't strictly need to use 400 responses here, since its up to + # the browser to enforce the CORS policy, but its more informative + # if we do. + if failures: + failure_text = "Disallowed CORS " + ", ".join(failures) + return PlainTextResponse(failure_text, status_code=400, headers=headers) + + return PlainTextResponse("OK", status_code=200, headers=headers) + + async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None: + send = functools.partial(self.send, send=send, request_headers=request_headers) + await self.app(scope, receive, send) + + async def send(self, message: Message, send: Send, request_headers: Headers) -> None: + if message["type"] != "http.response.start": + await send(message) + return + + message.setdefault("headers", []) + headers = MutableHeaders(scope=message) + headers.update(self.simple_headers) + origin = request_headers["Origin"] + has_cookie = "cookie" in request_headers + + # If request includes any cookie headers, then we must respond + # with the specific origin instead of '*'. + if self.allow_all_origins and has_cookie: + self.allow_explicit_origin(headers, origin) + + # If we only allow specific origins, then we have to mirror back + # the Origin header in the response. + elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): + self.allow_explicit_origin(headers, origin) + + await send(message) + + @staticmethod + def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None: + headers["Access-Control-Allow-Origin"] = origin + headers.add_vary_header("Origin") diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py b/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py new file mode 100644 index 00000000..76ad776b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +import html +import inspect +import sys +import traceback +import typing + +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.requests import Request +from starlette.responses import HTMLResponse, PlainTextResponse, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +STYLES = """ +p { + color: #211c1c; +} +.traceback-container { + border: 1px solid #038BB8; +} +.traceback-title { + background-color: #038BB8; + color: lemonchiffon; + padding: 12px; + font-size: 20px; + margin-top: 0px; +} +.frame-line { + padding-left: 10px; + font-family: monospace; +} +.frame-filename { + font-family: monospace; +} +.center-line { + background-color: #038BB8; + color: #f9f6e1; + padding: 5px 0px 5px 5px; +} +.lineno { + margin-right: 5px; +} +.frame-title { + font-weight: unset; + padding: 10px 10px 10px 10px; + background-color: #E4F4FD; + margin-right: 10px; + color: #191f21; + font-size: 17px; + border: 1px solid #c7dce8; +} +.collapse-btn { + float: right; + padding: 0px 5px 1px 5px; + border: solid 1px #96aebb; + cursor: pointer; +} +.collapsed { + display: none; +} +.source-code { + font-family: courier; + font-size: small; + padding-bottom: 10px; +} +""" + +JS = """ +<script type="text/javascript"> + function collapse(element){ + const frameId = element.getAttribute("data-frame-id"); + const frame = document.getElementById(frameId); + + if (frame.classList.contains("collapsed")){ + element.innerHTML = "‒"; + frame.classList.remove("collapsed"); + } else { + element.innerHTML = "+"; + frame.classList.add("collapsed"); + } + } +</script> +""" + +TEMPLATE = """ +<html> + <head> + <style type='text/css'> + {styles} + </style> + <title>Starlette Debugger</title> + </head> + <body> + <h1>500 Server Error</h1> + <h2>{error}</h2> + <div class="traceback-container"> + <p class="traceback-title">Traceback</p> + <div>{exc_html}</div> + </div> + {js} + </body> +</html> +""" + +FRAME_TEMPLATE = """ +<div> + <p class="frame-title">File <span class="frame-filename">{frame_filename}</span>, + line <i>{frame_lineno}</i>, + in <b>{frame_name}</b> + <span class="collapse-btn" data-frame-id="{frame_filename}-{frame_lineno}" onclick="collapse(this)">{collapse_button}</span> + </p> + <div id="{frame_filename}-{frame_lineno}" class="source-code {collapsed}">{code_context}</div> +</div> +""" # noqa: E501 + +LINE = """ +<p><span class="frame-line"> +<span class="lineno">{lineno}.</span> {line}</span></p> +""" + +CENTER_LINE = """ +<p class="center-line"><span class="frame-line center-line"> +<span class="lineno">{lineno}.</span> {line}</span></p> +""" + + +class ServerErrorMiddleware: + """ + Handles returning 500 responses when a server error occurs. + + If 'debug' is set, then traceback responses will be returned, + otherwise the designated 'handler' will be called. + + This middleware class should generally be used to wrap *everything* + else up, so that unhandled exceptions anywhere in the stack + always result in an appropriate 500 response. + """ + + def __init__( + self, + app: ASGIApp, + handler: typing.Callable[[Request, Exception], typing.Any] | None = None, + debug: bool = False, + ) -> None: + self.app = app + self.handler = handler + self.debug = debug + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + response_started = False + + async def _send(message: Message) -> None: + nonlocal response_started, send + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + try: + await self.app(scope, receive, _send) + except Exception as exc: + request = Request(scope) + if self.debug: + # In debug mode, return traceback responses. + response = self.debug_response(request, exc) + elif self.handler is None: + # Use our default 500 error handler. + response = self.error_response(request, exc) + else: + # Use an installed 500 error handler. + if is_async_callable(self.handler): + response = await self.handler(request, exc) + else: + response = await run_in_threadpool(self.handler, request, exc) + + if not response_started: + await response(scope, receive, send) + + # We always continue to raise the exception. + # This allows servers to log the error, or allows test clients + # to optionally raise the error within the test case. + raise exc + + def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str: + values = { + # HTML escape - line could contain < or > + "line": html.escape(line).replace(" ", " "), + "lineno": (frame_lineno - frame_index) + index, + } + + if index != frame_index: + return LINE.format(**values) + return CENTER_LINE.format(**values) + + def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str: + code_context = "".join( + self.format_line( + index, + line, + frame.lineno, + frame.index, # type: ignore[arg-type] + ) + for index, line in enumerate(frame.code_context or []) + ) + + values = { + # HTML escape - filename could contain < or >, especially if it's a virtual + # file e.g. <stdin> in the REPL + "frame_filename": html.escape(frame.filename), + "frame_lineno": frame.lineno, + # HTML escape - if you try very hard it's possible to name a function with < + # or > + "frame_name": html.escape(frame.function), + "code_context": code_context, + "collapsed": "collapsed" if is_collapsed else "", + "collapse_button": "+" if is_collapsed else "‒", + } + return FRAME_TEMPLATE.format(**values) + + def generate_html(self, exc: Exception, limit: int = 7) -> str: + traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True) + + exc_html = "" + is_collapsed = False + exc_traceback = exc.__traceback__ + if exc_traceback is not None: + frames = inspect.getinnerframes(exc_traceback, limit) + for frame in reversed(frames): + exc_html += self.generate_frame_html(frame, is_collapsed) + is_collapsed = True + + if sys.version_info >= (3, 13): # pragma: no cover + exc_type_str = traceback_obj.exc_type_str + else: # pragma: no cover + exc_type_str = traceback_obj.exc_type.__name__ + + # escape error class and text + error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}" + + return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html) + + def generate_plain_text(self, exc: Exception) -> str: + return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) + + def debug_response(self, request: Request, exc: Exception) -> Response: + accept = request.headers.get("accept", "") + + if "text/html" in accept: + content = self.generate_html(exc) + return HTMLResponse(content, status_code=500) + content = self.generate_plain_text(exc) + return PlainTextResponse(content, status_code=500) + + def error_response(self, request: Request, exc: Exception) -> Response: + return PlainTextResponse("Internal Server Error", status_code=500) diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py b/.venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py new file mode 100644 index 00000000..981d2fca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import typing + +from starlette._exception_handler import ( + ExceptionHandlers, + StatusHandlers, + wrap_app_handling_exceptions, +) +from starlette.exceptions import HTTPException, WebSocketException +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.websockets import WebSocket + + +class ExceptionMiddleware: + def __init__( + self, + app: ASGIApp, + handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None, + debug: bool = False, + ) -> None: + self.app = app + self.debug = debug # TODO: We ought to handle 404 cases if debug is set. + self._status_handlers: StatusHandlers = {} + self._exception_handlers: ExceptionHandlers = { + HTTPException: self.http_exception, + WebSocketException: self.websocket_exception, + } + if handlers is not None: # pragma: no branch + for key, value in handlers.items(): + self.add_exception_handler(key, value) + + def add_exception_handler( + self, + exc_class_or_status_code: int | type[Exception], + handler: typing.Callable[[Request, Exception], Response], + ) -> None: + if isinstance(exc_class_or_status_code, int): + self._status_handlers[exc_class_or_status_code] = handler + else: + assert issubclass(exc_class_or_status_code, Exception) + self._exception_handlers[exc_class_or_status_code] = handler + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ("http", "websocket"): + await self.app(scope, receive, send) + return + + scope["starlette.exception_handlers"] = ( + self._exception_handlers, + self._status_handlers, + ) + + conn: Request | WebSocket + if scope["type"] == "http": + conn = Request(scope, receive, send) + else: + conn = WebSocket(scope, receive, send) + + await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send) + + def http_exception(self, request: Request, exc: Exception) -> Response: + assert isinstance(exc, HTTPException) + if exc.status_code in {204, 304}: + return Response(status_code=exc.status_code, headers=exc.headers) + return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers) + + async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None: + assert isinstance(exc, WebSocketException) + await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/gzip.py b/.venv/lib/python3.12/site-packages/starlette/middleware/gzip.py new file mode 100644 index 00000000..c7fd5b77 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/gzip.py @@ -0,0 +1,141 @@ +import gzip +import io +import typing + +from starlette.datastructures import Headers, MutableHeaders +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",) + + +class GZipMiddleware: + def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None: + self.app = app + self.minimum_size = minimum_size + self.compresslevel = compresslevel + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": # pragma: no cover + await self.app(scope, receive, send) + return + + headers = Headers(scope=scope) + responder: ASGIApp + if "gzip" in headers.get("Accept-Encoding", ""): + responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel) + else: + responder = IdentityResponder(self.app, self.minimum_size) + + await responder(scope, receive, send) + + +class IdentityResponder: + content_encoding: str + + def __init__(self, app: ASGIApp, minimum_size: int) -> None: + self.app = app + self.minimum_size = minimum_size + self.send: Send = unattached_send + self.initial_message: Message = {} + self.started = False + self.content_encoding_set = False + self.content_type_is_excluded = False + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + self.send = send + await self.app(scope, receive, self.send_with_compression) + + async def send_with_compression(self, message: Message) -> None: + message_type = message["type"] + if message_type == "http.response.start": + # Don't send the initial message until we've determined how to + # modify the outgoing headers correctly. + self.initial_message = message + headers = Headers(raw=self.initial_message["headers"]) + self.content_encoding_set = "content-encoding" in headers + self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES) + elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded): + if not self.started: + self.started = True + await self.send(self.initial_message) + await self.send(message) + elif message_type == "http.response.body" and not self.started: + self.started = True + body = message.get("body", b"") + more_body = message.get("more_body", False) + if len(body) < self.minimum_size and not more_body: + # Don't apply compression to small outgoing responses. + await self.send(self.initial_message) + await self.send(message) + elif not more_body: + # Standard response. + body = self.apply_compression(body, more_body=False) + + headers = MutableHeaders(raw=self.initial_message["headers"]) + headers.add_vary_header("Accept-Encoding") + if body != message["body"]: + headers["Content-Encoding"] = self.content_encoding + headers["Content-Length"] = str(len(body)) + message["body"] = body + + await self.send(self.initial_message) + await self.send(message) + else: + # Initial body in streaming response. + body = self.apply_compression(body, more_body=True) + + headers = MutableHeaders(raw=self.initial_message["headers"]) + headers.add_vary_header("Accept-Encoding") + if body != message["body"]: + headers["Content-Encoding"] = self.content_encoding + del headers["Content-Length"] + message["body"] = body + + await self.send(self.initial_message) + await self.send(message) + elif message_type == "http.response.body": # pragma: no branch + # Remaining body in streaming response. + body = message.get("body", b"") + more_body = message.get("more_body", False) + + message["body"] = self.apply_compression(body, more_body=more_body) + + await self.send(message) + + def apply_compression(self, body: bytes, *, more_body: bool) -> bytes: + """Apply compression on the response body. + + If more_body is False, any compression file should be closed. If it + isn't, it won't be closed automatically until all background tasks + complete. + """ + return body + + +class GZipResponder(IdentityResponder): + content_encoding = "gzip" + + def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None: + super().__init__(app, minimum_size) + + self.gzip_buffer = io.BytesIO() + self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + with self.gzip_buffer, self.gzip_file: + await super().__call__(scope, receive, send) + + def apply_compression(self, body: bytes, *, more_body: bool) -> bytes: + self.gzip_file.write(body) + if not more_body: + self.gzip_file.close() + + body = self.gzip_buffer.getvalue() + self.gzip_buffer.seek(0) + self.gzip_buffer.truncate() + + return body + + +async def unattached_send(message: Message) -> typing.NoReturn: + raise RuntimeError("send awaitable not set") # pragma: no cover diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/httpsredirect.py b/.venv/lib/python3.12/site-packages/starlette/middleware/httpsredirect.py new file mode 100644 index 00000000..a8359067 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/httpsredirect.py @@ -0,0 +1,19 @@ +from starlette.datastructures import URL +from starlette.responses import RedirectResponse +from starlette.types import ASGIApp, Receive, Scope, Send + + +class HTTPSRedirectMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] in ("http", "websocket") and scope["scheme"] in ("http", "ws"): + url = URL(scope=scope) + redirect_scheme = {"http": "https", "ws": "wss"}[url.scheme] + netloc = url.hostname if url.port in (80, 443) else url.netloc + url = url.replace(scheme=redirect_scheme, netloc=netloc) + response = RedirectResponse(url, status_code=307) + await response(scope, receive, send) + else: + await self.app(scope, receive, send) diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/sessions.py b/.venv/lib/python3.12/site-packages/starlette/middleware/sessions.py new file mode 100644 index 00000000..5f9fcd88 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/sessions.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import json +import typing +from base64 import b64decode, b64encode + +import itsdangerous +from itsdangerous.exc import BadSignature + +from starlette.datastructures import MutableHeaders, Secret +from starlette.requests import HTTPConnection +from starlette.types import ASGIApp, Message, Receive, Scope, Send + + +class SessionMiddleware: + def __init__( + self, + app: ASGIApp, + secret_key: str | Secret, + session_cookie: str = "session", + max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds + path: str = "/", + same_site: typing.Literal["lax", "strict", "none"] = "lax", + https_only: bool = False, + domain: str | None = None, + ) -> None: + self.app = app + self.signer = itsdangerous.TimestampSigner(str(secret_key)) + self.session_cookie = session_cookie + self.max_age = max_age + self.path = path + self.security_flags = "httponly; samesite=" + same_site + if https_only: # Secure flag can be used with HTTPS only + self.security_flags += "; secure" + if domain is not None: + self.security_flags += f"; domain={domain}" + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ("http", "websocket"): # pragma: no cover + await self.app(scope, receive, send) + return + + connection = HTTPConnection(scope) + initial_session_was_empty = True + + if self.session_cookie in connection.cookies: + data = connection.cookies[self.session_cookie].encode("utf-8") + try: + data = self.signer.unsign(data, max_age=self.max_age) + scope["session"] = json.loads(b64decode(data)) + initial_session_was_empty = False + except BadSignature: + scope["session"] = {} + else: + scope["session"] = {} + + async def send_wrapper(message: Message) -> None: + if message["type"] == "http.response.start": + if scope["session"]: + # We have session data to persist. + data = b64encode(json.dumps(scope["session"]).encode("utf-8")) + data = self.signer.sign(data) + headers = MutableHeaders(scope=message) + header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( + session_cookie=self.session_cookie, + data=data.decode("utf-8"), + path=self.path, + max_age=f"Max-Age={self.max_age}; " if self.max_age else "", + security_flags=self.security_flags, + ) + headers.append("Set-Cookie", header_value) + elif not initial_session_was_empty: + # The session has been cleared. + headers = MutableHeaders(scope=message) + header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( + session_cookie=self.session_cookie, + data="null", + path=self.path, + expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ", + security_flags=self.security_flags, + ) + headers.append("Set-Cookie", header_value) + await send(message) + + await self.app(scope, receive, send_wrapper) diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py new file mode 100644 index 00000000..2d1c999e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import typing + +from starlette.datastructures import URL, Headers +from starlette.responses import PlainTextResponse, RedirectResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send + +ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'." + + +class TrustedHostMiddleware: + def __init__( + self, + app: ASGIApp, + allowed_hosts: typing.Sequence[str] | None = None, + www_redirect: bool = True, + ) -> None: + if allowed_hosts is None: + allowed_hosts = ["*"] + + for pattern in allowed_hosts: + assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD + if pattern.startswith("*") and pattern != "*": + assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD + self.app = app + self.allowed_hosts = list(allowed_hosts) + self.allow_any = "*" in allowed_hosts + self.www_redirect = www_redirect + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if self.allow_any or scope["type"] not in ( + "http", + "websocket", + ): # pragma: no cover + await self.app(scope, receive, send) + return + + headers = Headers(scope=scope) + host = headers.get("host", "").split(":")[0] + is_valid_host = False + found_www_redirect = False + for pattern in self.allowed_hosts: + if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])): + is_valid_host = True + break + elif "www." + host == pattern: + found_www_redirect = True + + if is_valid_host: + await self.app(scope, receive, send) + else: + response: Response + if found_www_redirect and self.www_redirect: + url = URL(scope=scope) + redirect_url = url.replace(netloc="www." + url.netloc) + response = RedirectResponse(url=str(redirect_url)) + else: + response = PlainTextResponse("Invalid host header", status_code=400) + await response(scope, receive, send) diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py b/.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py new file mode 100644 index 00000000..6e0a3fae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import io +import math +import sys +import typing +import warnings + +import anyio +from anyio.abc import ObjectReceiveStream, ObjectSendStream + +from starlette.types import Receive, Scope, Send + +warnings.warn( + "starlette.middleware.wsgi is deprecated and will be removed in a future release. " + "Please refer to https://github.com/abersheeran/a2wsgi as a replacement.", + DeprecationWarning, +) + + +def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]: + """ + Builds a scope and request body into a WSGI environ object. + """ + + script_name = scope.get("root_path", "").encode("utf8").decode("latin1") + path_info = scope["path"].encode("utf8").decode("latin1") + if path_info.startswith(script_name): + path_info = path_info[len(script_name) :] + + environ = { + "REQUEST_METHOD": scope["method"], + "SCRIPT_NAME": script_name, + "PATH_INFO": path_info, + "QUERY_STRING": scope["query_string"].decode("ascii"), + "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}", + "wsgi.version": (1, 0), + "wsgi.url_scheme": scope.get("scheme", "http"), + "wsgi.input": io.BytesIO(body), + "wsgi.errors": sys.stdout, + "wsgi.multithread": True, + "wsgi.multiprocess": True, + "wsgi.run_once": False, + } + + # Get server name and port - required in WSGI, not in ASGI + server = scope.get("server") or ("localhost", 80) + environ["SERVER_NAME"] = server[0] + environ["SERVER_PORT"] = server[1] + + # Get client IP address + if scope.get("client"): + environ["REMOTE_ADDR"] = scope["client"][0] + + # Go through headers and make them into environ entries + for name, value in scope.get("headers", []): + name = name.decode("latin1") + if name == "content-length": + corrected_name = "CONTENT_LENGTH" + elif name == "content-type": + corrected_name = "CONTENT_TYPE" + else: + corrected_name = f"HTTP_{name}".upper().replace("-", "_") + # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in + # case + value = value.decode("latin1") + if corrected_name in environ: + value = environ[corrected_name] + "," + value + environ[corrected_name] = value + return environ + + +class WSGIMiddleware: + def __init__(self, app: typing.Callable[..., typing.Any]) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + responder = WSGIResponder(self.app, scope) + await responder(receive, send) + + +class WSGIResponder: + stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]] + stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] + + def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None: + self.app = app + self.scope = scope + self.status = None + self.response_headers = None + self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf) + self.response_started = False + self.exc_info: typing.Any = None + + async def __call__(self, receive: Receive, send: Send) -> None: + body = b"" + more_body = True + while more_body: + message = await receive() + body += message.get("body", b"") + more_body = message.get("more_body", False) + environ = build_environ(self.scope, body) + + async with anyio.create_task_group() as task_group: + task_group.start_soon(self.sender, send) + async with self.stream_send: + await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) + if self.exc_info is not None: + raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) + + async def sender(self, send: Send) -> None: + async with self.stream_receive: + async for message in self.stream_receive: + await send(message) + + def start_response( + self, + status: str, + response_headers: list[tuple[str, str]], + exc_info: typing.Any = None, + ) -> None: + self.exc_info = exc_info + if not self.response_started: # pragma: no branch + self.response_started = True + status_code_string, _ = status.split(" ", 1) + status_code = int(status_code_string) + headers = [ + (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) + for name, value in response_headers + ] + anyio.from_thread.run( + self.stream_send.send, + { + "type": "http.response.start", + "status": status_code, + "headers": headers, + }, + ) + + def wsgi( + self, + environ: dict[str, typing.Any], + start_response: typing.Callable[..., typing.Any], + ) -> None: + for chunk in self.app(environ, start_response): + anyio.from_thread.run( + self.stream_send.send, + {"type": "http.response.body", "body": chunk, "more_body": True}, + ) + + anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""}) |