aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/middleware/cors.py')
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/cors.py172
1 files changed, 172 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py b/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py
new file mode 100644
index 00000000..61502691
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py
@@ -0,0 +1,172 @@
+from __future__ import annotations
+
+import functools
+import re
+import typing
+
+from starlette.datastructures import Headers, MutableHeaders
+from starlette.responses import PlainTextResponse, Response
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+
+ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
+SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}
+
+
+class CORSMiddleware:
+ def __init__(
+ self,
+ app: ASGIApp,
+ allow_origins: typing.Sequence[str] = (),
+ allow_methods: typing.Sequence[str] = ("GET",),
+ allow_headers: typing.Sequence[str] = (),
+ allow_credentials: bool = False,
+ allow_origin_regex: str | None = None,
+ expose_headers: typing.Sequence[str] = (),
+ max_age: int = 600,
+ ) -> None:
+ if "*" in allow_methods:
+ allow_methods = ALL_METHODS
+
+ compiled_allow_origin_regex = None
+ if allow_origin_regex is not None:
+ compiled_allow_origin_regex = re.compile(allow_origin_regex)
+
+ allow_all_origins = "*" in allow_origins
+ allow_all_headers = "*" in allow_headers
+ preflight_explicit_allow_origin = not allow_all_origins or allow_credentials
+
+ simple_headers = {}
+ if allow_all_origins:
+ simple_headers["Access-Control-Allow-Origin"] = "*"
+ if allow_credentials:
+ simple_headers["Access-Control-Allow-Credentials"] = "true"
+ if expose_headers:
+ simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
+
+ preflight_headers = {}
+ if preflight_explicit_allow_origin:
+ # The origin value will be set in preflight_response() if it is allowed.
+ preflight_headers["Vary"] = "Origin"
+ else:
+ preflight_headers["Access-Control-Allow-Origin"] = "*"
+ preflight_headers.update(
+ {
+ "Access-Control-Allow-Methods": ", ".join(allow_methods),
+ "Access-Control-Max-Age": str(max_age),
+ }
+ )
+ allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
+ if allow_headers and not allow_all_headers:
+ preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
+ if allow_credentials:
+ preflight_headers["Access-Control-Allow-Credentials"] = "true"
+
+ self.app = app
+ self.allow_origins = allow_origins
+ self.allow_methods = allow_methods
+ self.allow_headers = [h.lower() for h in allow_headers]
+ self.allow_all_origins = allow_all_origins
+ self.allow_all_headers = allow_all_headers
+ self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
+ self.allow_origin_regex = compiled_allow_origin_regex
+ self.simple_headers = simple_headers
+ self.preflight_headers = preflight_headers
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] != "http": # pragma: no cover
+ await self.app(scope, receive, send)
+ return
+
+ method = scope["method"]
+ headers = Headers(scope=scope)
+ origin = headers.get("origin")
+
+ if origin is None:
+ await self.app(scope, receive, send)
+ return
+
+ if method == "OPTIONS" and "access-control-request-method" in headers:
+ response = self.preflight_response(request_headers=headers)
+ await response(scope, receive, send)
+ return
+
+ await self.simple_response(scope, receive, send, request_headers=headers)
+
+ def is_allowed_origin(self, origin: str) -> bool:
+ if self.allow_all_origins:
+ return True
+
+ if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
+ return True
+
+ return origin in self.allow_origins
+
+ def preflight_response(self, request_headers: Headers) -> Response:
+ requested_origin = request_headers["origin"]
+ requested_method = request_headers["access-control-request-method"]
+ requested_headers = request_headers.get("access-control-request-headers")
+
+ headers = dict(self.preflight_headers)
+ failures = []
+
+ if self.is_allowed_origin(origin=requested_origin):
+ if self.preflight_explicit_allow_origin:
+ # The "else" case is already accounted for in self.preflight_headers
+ # and the value would be "*".
+ headers["Access-Control-Allow-Origin"] = requested_origin
+ else:
+ failures.append("origin")
+
+ if requested_method not in self.allow_methods:
+ failures.append("method")
+
+ # If we allow all headers, then we have to mirror back any requested
+ # headers in the response.
+ if self.allow_all_headers and requested_headers is not None:
+ headers["Access-Control-Allow-Headers"] = requested_headers
+ elif requested_headers is not None:
+ for header in [h.lower() for h in requested_headers.split(",")]:
+ if header.strip() not in self.allow_headers:
+ failures.append("headers")
+ break
+
+ # We don't strictly need to use 400 responses here, since its up to
+ # the browser to enforce the CORS policy, but its more informative
+ # if we do.
+ if failures:
+ failure_text = "Disallowed CORS " + ", ".join(failures)
+ return PlainTextResponse(failure_text, status_code=400, headers=headers)
+
+ return PlainTextResponse("OK", status_code=200, headers=headers)
+
+ async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
+ send = functools.partial(self.send, send=send, request_headers=request_headers)
+ await self.app(scope, receive, send)
+
+ async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
+ if message["type"] != "http.response.start":
+ await send(message)
+ return
+
+ message.setdefault("headers", [])
+ headers = MutableHeaders(scope=message)
+ headers.update(self.simple_headers)
+ origin = request_headers["Origin"]
+ has_cookie = "cookie" in request_headers
+
+ # If request includes any cookie headers, then we must respond
+ # with the specific origin instead of '*'.
+ if self.allow_all_origins and has_cookie:
+ self.allow_explicit_origin(headers, origin)
+
+ # If we only allow specific origins, then we have to mirror back
+ # the Origin header in the response.
+ elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
+ self.allow_explicit_origin(headers, origin)
+
+ await send(message)
+
+ @staticmethod
+ def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
+ headers["Access-Control-Allow-Origin"] = origin
+ headers.add_vary_header("Origin")