From 4a52a71956a8d46fcb7294ac71734504bb09bcc2 Mon Sep 17 00:00:00 2001 From: S. Solomon Darnell Date: Fri, 28 Mar 2025 21:52:21 -0500 Subject: two version of R2R are here --- .../starlette/middleware/trustedhost.py | 60 ++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 .venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py (limited to '.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py') diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py new file mode 100644 index 00000000..2d1c999e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import typing + +from starlette.datastructures import URL, Headers +from starlette.responses import PlainTextResponse, RedirectResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send + +ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'." + + +class TrustedHostMiddleware: + def __init__( + self, + app: ASGIApp, + allowed_hosts: typing.Sequence[str] | None = None, + www_redirect: bool = True, + ) -> None: + if allowed_hosts is None: + allowed_hosts = ["*"] + + for pattern in allowed_hosts: + assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD + if pattern.startswith("*") and pattern != "*": + assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD + self.app = app + self.allowed_hosts = list(allowed_hosts) + self.allow_any = "*" in allowed_hosts + self.www_redirect = www_redirect + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if self.allow_any or scope["type"] not in ( + "http", + "websocket", + ): # pragma: no cover + await self.app(scope, receive, send) + return + + headers = Headers(scope=scope) + host = headers.get("host", "").split(":")[0] + is_valid_host = False + found_www_redirect = False + for pattern in self.allowed_hosts: + if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])): + is_valid_host = True + break + elif "www." + host == pattern: + found_www_redirect = True + + if is_valid_host: + await self.app(scope, receive, send) + else: + response: Response + if found_www_redirect and self.www_redirect: + url = URL(scope=scope) + redirect_url = url.replace(netloc="www." + url.netloc) + response = RedirectResponse(url=str(redirect_url)) + else: + response = PlainTextResponse("Invalid host header", status_code=400) + await response(scope, receive, send) -- cgit v1.2.3