aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py')
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py b/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py
new file mode 100644
index 00000000..8555ee07
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+import typing
+
+from starlette.authentication import (
+ AuthCredentials,
+ AuthenticationBackend,
+ AuthenticationError,
+ UnauthenticatedUser,
+)
+from starlette.requests import HTTPConnection
+from starlette.responses import PlainTextResponse, Response
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+
+class AuthenticationMiddleware:
+ def __init__(
+ self,
+ app: ASGIApp,
+ backend: AuthenticationBackend,
+ on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
+ ) -> None:
+ self.app = app
+ self.backend = backend
+ self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = (
+ on_error if on_error is not None else self.default_on_error
+ )
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] not in ["http", "websocket"]:
+ await self.app(scope, receive, send)
+ return
+
+ conn = HTTPConnection(scope)
+ try:
+ auth_result = await self.backend.authenticate(conn)
+ except AuthenticationError as exc:
+ response = self.on_error(conn, exc)
+ if scope["type"] == "websocket":
+ await send({"type": "websocket.close", "code": 1000})
+ else:
+ await response(scope, receive, send)
+ return
+
+ if auth_result is None:
+ auth_result = AuthCredentials(), UnauthenticatedUser()
+ scope["auth"], scope["user"] = auth_result
+ await self.app(scope, receive, send)
+
+ @staticmethod
+ def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
+ return PlainTextResponse(str(exc), status_code=400)