about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/starlette/authentication.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/authentication.py')
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/authentication.py147
1 files changed, 147 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/authentication.py b/.venv/lib/python3.12/site-packages/starlette/authentication.py
new file mode 100644
index 00000000..4fd86641
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/authentication.py
@@ -0,0 +1,147 @@
+from __future__ import annotations
+
+import functools
+import inspect
+import sys
+import typing
+from urllib.parse import urlencode
+
+if sys.version_info >= (3, 10):  # pragma: no cover
+    from typing import ParamSpec
+else:  # pragma: no cover
+    from typing_extensions import ParamSpec
+
+from starlette._utils import is_async_callable
+from starlette.exceptions import HTTPException
+from starlette.requests import HTTPConnection, Request
+from starlette.responses import RedirectResponse
+from starlette.websockets import WebSocket
+
+_P = ParamSpec("_P")
+
+
+def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
+    for scope in scopes:
+        if scope not in conn.auth.scopes:
+            return False
+    return True
+
+
+def requires(
+    scopes: str | typing.Sequence[str],
+    status_code: int = 403,
+    redirect: str | None = None,
+) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
+    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
+
+    def decorator(
+        func: typing.Callable[_P, typing.Any],
+    ) -> typing.Callable[_P, typing.Any]:
+        sig = inspect.signature(func)
+        for idx, parameter in enumerate(sig.parameters.values()):
+            if parameter.name == "request" or parameter.name == "websocket":
+                type_ = parameter.name
+                break
+        else:
+            raise Exception(f'No "request" or "websocket" argument on function "{func}"')
+
+        if type_ == "websocket":
+            # Handle websocket functions. (Always async)
+            @functools.wraps(func)
+            async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
+                websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
+                assert isinstance(websocket, WebSocket)
+
+                if not has_required_scope(websocket, scopes_list):
+                    await websocket.close()
+                else:
+                    await func(*args, **kwargs)
+
+            return websocket_wrapper
+
+        elif is_async_callable(func):
+            # Handle async request/response functions.
+            @functools.wraps(func)
+            async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
+                request = kwargs.get("request", args[idx] if idx < len(args) else None)
+                assert isinstance(request, Request)
+
+                if not has_required_scope(request, scopes_list):
+                    if redirect is not None:
+                        orig_request_qparam = urlencode({"next": str(request.url)})
+                        next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
+                        return RedirectResponse(url=next_url, status_code=303)
+                    raise HTTPException(status_code=status_code)
+                return await func(*args, **kwargs)
+
+            return async_wrapper
+
+        else:
+            # Handle sync request/response functions.
+            @functools.wraps(func)
+            def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
+                request = kwargs.get("request", args[idx] if idx < len(args) else None)
+                assert isinstance(request, Request)
+
+                if not has_required_scope(request, scopes_list):
+                    if redirect is not None:
+                        orig_request_qparam = urlencode({"next": str(request.url)})
+                        next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
+                        return RedirectResponse(url=next_url, status_code=303)
+                    raise HTTPException(status_code=status_code)
+                return func(*args, **kwargs)
+
+            return sync_wrapper
+
+    return decorator
+
+
+class AuthenticationError(Exception):
+    pass
+
+
+class AuthenticationBackend:
+    async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
+        raise NotImplementedError()  # pragma: no cover
+
+
+class AuthCredentials:
+    def __init__(self, scopes: typing.Sequence[str] | None = None):
+        self.scopes = [] if scopes is None else list(scopes)
+
+
+class BaseUser:
+    @property
+    def is_authenticated(self) -> bool:
+        raise NotImplementedError()  # pragma: no cover
+
+    @property
+    def display_name(self) -> str:
+        raise NotImplementedError()  # pragma: no cover
+
+    @property
+    def identity(self) -> str:
+        raise NotImplementedError()  # pragma: no cover
+
+
+class SimpleUser(BaseUser):
+    def __init__(self, username: str) -> None:
+        self.username = username
+
+    @property
+    def is_authenticated(self) -> bool:
+        return True
+
+    @property
+    def display_name(self) -> str:
+        return self.username
+
+
+class UnauthenticatedUser(BaseUser):
+    @property
+    def is_authenticated(self) -> bool:
+        return False
+
+    @property
+    def display_name(self) -> str:
+        return ""