about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/websockets/legacy/auth.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/websockets/legacy/auth.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/websockets/legacy/auth.py')
-rw-r--r--.venv/lib/python3.12/site-packages/websockets/legacy/auth.py190
1 files changed, 190 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/websockets/legacy/auth.py b/.venv/lib/python3.12/site-packages/websockets/legacy/auth.py
new file mode 100644
index 00000000..a262fcd7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/websockets/legacy/auth.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import functools
+import hmac
+import http
+from collections.abc import Awaitable, Iterable
+from typing import Any, Callable, cast
+
+from ..datastructures import Headers
+from ..exceptions import InvalidHeader
+from ..headers import build_www_authenticate_basic, parse_authorization_basic
+from .server import HTTPResponse, WebSocketServerProtocol
+
+
+__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]
+
+Credentials = tuple[str, str]
+
+
+def is_credentials(value: Any) -> bool:
+    try:
+        username, password = value
+    except (TypeError, ValueError):
+        return False
+    else:
+        return isinstance(username, str) and isinstance(password, str)
+
+
+class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
+    """
+    WebSocket server protocol that enforces HTTP Basic Auth.
+
+    """
+
+    realm: str = ""
+    """
+    Scope of protection.
+
+    If provided, it should contain only ASCII characters because the
+    encoding of non-ASCII characters is undefined.
+    """
+
+    username: str | None = None
+    """Username of the authenticated user."""
+
+    def __init__(
+        self,
+        *args: Any,
+        realm: str | None = None,
+        check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        if realm is not None:
+            self.realm = realm  # shadow class attribute
+        self._check_credentials = check_credentials
+        super().__init__(*args, **kwargs)
+
+    async def check_credentials(self, username: str, password: str) -> bool:
+        """
+        Check whether credentials are authorized.
+
+        This coroutine may be overridden in a subclass, for example to
+        authenticate against a database or an external service.
+
+        Args:
+            username: HTTP Basic Auth username.
+            password: HTTP Basic Auth password.
+
+        Returns:
+            :obj:`True` if the handshake should continue;
+            :obj:`False` if it should fail with an HTTP 401 error.
+
+        """
+        if self._check_credentials is not None:
+            return await self._check_credentials(username, password)
+
+        return False
+
+    async def process_request(
+        self,
+        path: str,
+        request_headers: Headers,
+    ) -> HTTPResponse | None:
+        """
+        Check HTTP Basic Auth and return an HTTP 401 response if needed.
+
+        """
+        try:
+            authorization = request_headers["Authorization"]
+        except KeyError:
+            return (
+                http.HTTPStatus.UNAUTHORIZED,
+                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
+                b"Missing credentials\n",
+            )
+
+        try:
+            username, password = parse_authorization_basic(authorization)
+        except InvalidHeader:
+            return (
+                http.HTTPStatus.UNAUTHORIZED,
+                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
+                b"Unsupported credentials\n",
+            )
+
+        if not await self.check_credentials(username, password):
+            return (
+                http.HTTPStatus.UNAUTHORIZED,
+                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
+                b"Invalid credentials\n",
+            )
+
+        self.username = username
+
+        return await super().process_request(path, request_headers)
+
+
+def basic_auth_protocol_factory(
+    realm: str | None = None,
+    credentials: Credentials | Iterable[Credentials] | None = None,
+    check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
+    create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None,
+) -> Callable[..., BasicAuthWebSocketServerProtocol]:
+    """
+    Protocol factory that enforces HTTP Basic Auth.
+
+    :func:`basic_auth_protocol_factory` is designed to integrate with
+    :func:`~websockets.legacy.server.serve` like this::
+
+        serve(
+            ...,
+            create_protocol=basic_auth_protocol_factory(
+                realm="my dev server",
+                credentials=("hello", "iloveyou"),
+            )
+        )
+
+    Args:
+        realm: Scope of protection. It should contain only ASCII characters
+            because the encoding of non-ASCII characters is undefined.
+            Refer to section 2.2 of :rfc:`7235` for details.
+        credentials: Hard coded authorized credentials. It can be a
+            ``(username, password)`` pair or a list of such pairs.
+        check_credentials: Coroutine that verifies credentials.
+            It receives ``username`` and ``password`` arguments
+            and returns a :class:`bool`. One of ``credentials`` or
+            ``check_credentials`` must be provided but not both.
+        create_protocol: Factory that creates the protocol. By default, this
+            is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced
+            by a subclass.
+    Raises:
+        TypeError: If the ``credentials`` or ``check_credentials`` argument is
+            wrong.
+
+    """
+    if (credentials is None) == (check_credentials is None):
+        raise TypeError("provide either credentials or check_credentials")
+
+    if credentials is not None:
+        if is_credentials(credentials):
+            credentials_list = [cast(Credentials, credentials)]
+        elif isinstance(credentials, Iterable):
+            credentials_list = list(cast(Iterable[Credentials], credentials))
+            if not all(is_credentials(item) for item in credentials_list):
+                raise TypeError(f"invalid credentials argument: {credentials}")
+        else:
+            raise TypeError(f"invalid credentials argument: {credentials}")
+
+        credentials_dict = dict(credentials_list)
+
+        async def check_credentials(username: str, password: str) -> bool:
+            try:
+                expected_password = credentials_dict[username]
+            except KeyError:
+                return False
+            return hmac.compare_digest(expected_password, password)
+
+    if create_protocol is None:
+        create_protocol = BasicAuthWebSocketServerProtocol
+
+    # Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] |
+    # Callable[..., BasicAuthWebSocketServerProtocol]" not callable  [misc]
+    create_protocol = cast(
+        Callable[..., BasicAuthWebSocketServerProtocol], create_protocol
+    )
+    return functools.partial(
+        create_protocol,
+        realm=realm,
+        check_credentials=check_credentials,
+    )