aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py')
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py
new file mode 100644
index 00000000..2d1c999e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+import typing
+
+from starlette.datastructures import URL, Headers
+from starlette.responses import PlainTextResponse, RedirectResponse, Response
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
+
+
+class TrustedHostMiddleware:
+ def __init__(
+ self,
+ app: ASGIApp,
+ allowed_hosts: typing.Sequence[str] | None = None,
+ www_redirect: bool = True,
+ ) -> None:
+ if allowed_hosts is None:
+ allowed_hosts = ["*"]
+
+ for pattern in allowed_hosts:
+ assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
+ if pattern.startswith("*") and pattern != "*":
+ assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
+ self.app = app
+ self.allowed_hosts = list(allowed_hosts)
+ self.allow_any = "*" in allowed_hosts
+ self.www_redirect = www_redirect
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if self.allow_any or scope["type"] not in (
+ "http",
+ "websocket",
+ ): # pragma: no cover
+ await self.app(scope, receive, send)
+ return
+
+ headers = Headers(scope=scope)
+ host = headers.get("host", "").split(":")[0]
+ is_valid_host = False
+ found_www_redirect = False
+ for pattern in self.allowed_hosts:
+ if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
+ is_valid_host = True
+ break
+ elif "www." + host == pattern:
+ found_www_redirect = True
+
+ if is_valid_host:
+ await self.app(scope, receive, send)
+ else:
+ response: Response
+ if found_www_redirect and self.www_redirect:
+ url = URL(scope=scope)
+ redirect_url = url.replace(netloc="www." + url.netloc)
+ response = RedirectResponse(url=str(redirect_url))
+ else:
+ response = PlainTextResponse("Invalid host header", status_code=400)
+ await response(scope, receive, send)