diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/authentication.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/starlette/authentication.py | 147 |
1 files changed, 147 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/authentication.py b/.venv/lib/python3.12/site-packages/starlette/authentication.py new file mode 100644 index 00000000..4fd86641 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/authentication.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import functools +import inspect +import sys +import typing +from urllib.parse import urlencode + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +from starlette._utils import is_async_callable +from starlette.exceptions import HTTPException +from starlette.requests import HTTPConnection, Request +from starlette.responses import RedirectResponse +from starlette.websockets import WebSocket + +_P = ParamSpec("_P") + + +def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool: + for scope in scopes: + if scope not in conn.auth.scopes: + return False + return True + + +def requires( + scopes: str | typing.Sequence[str], + status_code: int = 403, + redirect: str | None = None, +) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]: + scopes_list = [scopes] if isinstance(scopes, str) else list(scopes) + + def decorator( + func: typing.Callable[_P, typing.Any], + ) -> typing.Callable[_P, typing.Any]: + sig = inspect.signature(func) + for idx, parameter in enumerate(sig.parameters.values()): + if parameter.name == "request" or parameter.name == "websocket": + type_ = parameter.name + break + else: + raise Exception(f'No "request" or "websocket" argument on function "{func}"') + + if type_ == "websocket": + # Handle websocket functions. (Always async) + @functools.wraps(func) + async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: + websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None) + assert isinstance(websocket, WebSocket) + + if not has_required_scope(websocket, scopes_list): + await websocket.close() + else: + await func(*args, **kwargs) + + return websocket_wrapper + + elif is_async_callable(func): + # Handle async request/response functions. + @functools.wraps(func) + async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any: + request = kwargs.get("request", args[idx] if idx < len(args) else None) + assert isinstance(request, Request) + + if not has_required_scope(request, scopes_list): + if redirect is not None: + orig_request_qparam = urlencode({"next": str(request.url)}) + next_url = f"{request.url_for(redirect)}?{orig_request_qparam}" + return RedirectResponse(url=next_url, status_code=303) + raise HTTPException(status_code=status_code) + return await func(*args, **kwargs) + + return async_wrapper + + else: + # Handle sync request/response functions. + @functools.wraps(func) + def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any: + request = kwargs.get("request", args[idx] if idx < len(args) else None) + assert isinstance(request, Request) + + if not has_required_scope(request, scopes_list): + if redirect is not None: + orig_request_qparam = urlencode({"next": str(request.url)}) + next_url = f"{request.url_for(redirect)}?{orig_request_qparam}" + return RedirectResponse(url=next_url, status_code=303) + raise HTTPException(status_code=status_code) + return func(*args, **kwargs) + + return sync_wrapper + + return decorator + + +class AuthenticationError(Exception): + pass + + +class AuthenticationBackend: + async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None: + raise NotImplementedError() # pragma: no cover + + +class AuthCredentials: + def __init__(self, scopes: typing.Sequence[str] | None = None): + self.scopes = [] if scopes is None else list(scopes) + + +class BaseUser: + @property + def is_authenticated(self) -> bool: + raise NotImplementedError() # pragma: no cover + + @property + def display_name(self) -> str: + raise NotImplementedError() # pragma: no cover + + @property + def identity(self) -> str: + raise NotImplementedError() # pragma: no cover + + +class SimpleUser(BaseUser): + def __init__(self, username: str) -> None: + self.username = username + + @property + def is_authenticated(self) -> bool: + return True + + @property + def display_name(self) -> str: + return self.username + + +class UnauthenticatedUser(BaseUser): + @property + def is_authenticated(self) -> bool: + return False + + @property + def display_name(self) -> str: + return "" |