diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py new file mode 100644 index 00000000..2d1c999e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import typing + +from starlette.datastructures import URL, Headers +from starlette.responses import PlainTextResponse, RedirectResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send + +ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'." + + +class TrustedHostMiddleware: + def __init__( + self, + app: ASGIApp, + allowed_hosts: typing.Sequence[str] | None = None, + www_redirect: bool = True, + ) -> None: + if allowed_hosts is None: + allowed_hosts = ["*"] + + for pattern in allowed_hosts: + assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD + if pattern.startswith("*") and pattern != "*": + assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD + self.app = app + self.allowed_hosts = list(allowed_hosts) + self.allow_any = "*" in allowed_hosts + self.www_redirect = www_redirect + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if self.allow_any or scope["type"] not in ( + "http", + "websocket", + ): # pragma: no cover + await self.app(scope, receive, send) + return + + headers = Headers(scope=scope) + host = headers.get("host", "").split(":")[0] + is_valid_host = False + found_www_redirect = False + for pattern in self.allowed_hosts: + if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])): + is_valid_host = True + break + elif "www." + host == pattern: + found_www_redirect = True + + if is_valid_host: + await self.app(scope, receive, send) + else: + response: Response + if found_www_redirect and self.www_redirect: + url = URL(scope=scope) + redirect_url = url.replace(netloc="www." + url.netloc) + response = RedirectResponse(url=str(redirect_url)) + else: + response = PlainTextResponse("Invalid host header", status_code=400) + await response(scope, receive, send) |