about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py')
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py
new file mode 100644
index 00000000..2d1c999e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+import typing
+
+from starlette.datastructures import URL, Headers
+from starlette.responses import PlainTextResponse, RedirectResponse, Response
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
+
+
+class TrustedHostMiddleware:
+    def __init__(
+        self,
+        app: ASGIApp,
+        allowed_hosts: typing.Sequence[str] | None = None,
+        www_redirect: bool = True,
+    ) -> None:
+        if allowed_hosts is None:
+            allowed_hosts = ["*"]
+
+        for pattern in allowed_hosts:
+            assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
+            if pattern.startswith("*") and pattern != "*":
+                assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
+        self.app = app
+        self.allowed_hosts = list(allowed_hosts)
+        self.allow_any = "*" in allowed_hosts
+        self.www_redirect = www_redirect
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        if self.allow_any or scope["type"] not in (
+            "http",
+            "websocket",
+        ):  # pragma: no cover
+            await self.app(scope, receive, send)
+            return
+
+        headers = Headers(scope=scope)
+        host = headers.get("host", "").split(":")[0]
+        is_valid_host = False
+        found_www_redirect = False
+        for pattern in self.allowed_hosts:
+            if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
+                is_valid_host = True
+                break
+            elif "www." + host == pattern:
+                found_www_redirect = True
+
+        if is_valid_host:
+            await self.app(scope, receive, send)
+        else:
+            response: Response
+            if found_www_redirect and self.www_redirect:
+                url = URL(scope=scope)
+                redirect_url = url.replace(netloc="www." + url.netloc)
+                response = RedirectResponse(url=str(redirect_url))
+            else:
+                response = PlainTextResponse("Invalid host header", status_code=400)
+            await response(scope, receive, send)