diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/uvicorn/middleware/proxy_headers.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/uvicorn/middleware/proxy_headers.py | 84 |
1 files changed, 84 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/uvicorn/middleware/proxy_headers.py b/.venv/lib/python3.12/site-packages/uvicorn/middleware/proxy_headers.py new file mode 100644 index 00000000..28277e1d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/uvicorn/middleware/proxy_headers.py @@ -0,0 +1,84 @@ +""" +This middleware can be used when a known proxy is fronting the application, +and is trusted to be properly setting the `X-Forwarded-Proto` and +`X-Forwarded-For` headers with the connecting client information. + +Modifies the `client` and `scheme` information so that they reference +the connecting client, rather that the connecting proxy. + +https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies +""" +from typing import List, Optional, Tuple, Union, cast + +from uvicorn._types import ( + ASGI3Application, + ASGIReceiveCallable, + ASGISendCallable, + HTTPScope, + Scope, + WebSocketScope, +) + + +class ProxyHeadersMiddleware: + def __init__( + self, + app: "ASGI3Application", + trusted_hosts: Union[List[str], str] = "127.0.0.1", + ) -> None: + self.app = app + if isinstance(trusted_hosts, str): + self.trusted_hosts = {item.strip() for item in trusted_hosts.split(",")} + else: + self.trusted_hosts = set(trusted_hosts) + self.always_trust = "*" in self.trusted_hosts + + def get_trusted_client_host( + self, x_forwarded_for_hosts: List[str] + ) -> Optional[str]: + if self.always_trust: + return x_forwarded_for_hosts[0] + + for host in reversed(x_forwarded_for_hosts): + if host not in self.trusted_hosts: + return host + + return None + + async def __call__( + self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable" + ) -> None: + if scope["type"] in ("http", "websocket"): + scope = cast(Union["HTTPScope", "WebSocketScope"], scope) + client_addr: Optional[Tuple[str, int]] = scope.get("client") + client_host = client_addr[0] if client_addr else None + + if self.always_trust or client_host in self.trusted_hosts: + headers = dict(scope["headers"]) + + if b"x-forwarded-proto" in headers: + # Determine if the incoming request was http or https based on + # the X-Forwarded-Proto header. + x_forwarded_proto = ( + headers[b"x-forwarded-proto"].decode("latin1").strip() + ) + if scope["type"] == "websocket": + scope["scheme"] = ( + "wss" if x_forwarded_proto == "https" else "ws" + ) + else: + scope["scheme"] = x_forwarded_proto + + if b"x-forwarded-for" in headers: + # Determine the client address from the last trusted IP in the + # X-Forwarded-For header. We've lost the connecting client's port + # information by now, so only include the host. + x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1") + x_forwarded_for_hosts = [ + item.strip() for item in x_forwarded_for.split(",") + ] + host = self.get_trusted_client_host(x_forwarded_for_hosts) + port = 0 + scope["client"] = (host, port) # type: ignore[arg-type] + + return await self.app(scope, receive, send) |