diff options
Diffstat (limited to 'gn_auth/auth/authentication/oauth2/models')
-rw-r--r-- | gn_auth/auth/authentication/oauth2/models/oauth2client.py | 119 |
1 files changed, 76 insertions, 43 deletions
diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py index ea86772..b140a16 100644 --- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py +++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py @@ -1,17 +1,26 @@ """OAuth2 Client model.""" import json import datetime + from uuid import UUID -from typing import Sequence, Optional, NamedTuple +from dataclasses import dataclass +from functools import cached_property +from typing import Sequence, Optional +from authlib.oauth2.rfc6749 import ClientMixin from pymonad.maybe import Just, Maybe, Nothing from gn_auth.auth.db import sqlite3 as db -from gn_auth.auth.authentication.users import User, users, user_by_id, same_password +from gn_auth.auth.authentication.users import (User, + users, + user_by_id, + same_password) from gn_auth.auth.authorisation.errors import NotFoundError -class OAuth2Client(NamedTuple): + +@dataclass(frozen=True) +class OAuth2Client(ClientMixin): """ Client to the OAuth2 Server. @@ -29,12 +38,12 @@ class OAuth2Client(NamedTuple): """Check whether the `client_secret` matches this client.""" return same_password(client_secret, self.client_secret) - @property + @cached_property def token_endpoint_auth_method(self) -> str: """Return the token endpoint authorisation method.""" return self.client_metadata.get("token_endpoint_auth_method", ["none"]) - @property + @cached_property def client_type(self) -> str: """ Return the token endpoint authorisation method. @@ -64,12 +73,12 @@ class OAuth2Client(NamedTuple): and method == "client_secret_basic") return False - @property - def id(self):# pylint: disable=[invalid-name] + @cached_property + def id(self): # pylint: disable=[invalid-name] """Return the client_id.""" return self.client_id - @property + @cached_property def grant_types(self) -> Sequence[str]: """ Return the grant types that this client supports. @@ -88,7 +97,7 @@ class OAuth2Client(NamedTuple): """ return grant_type in self.grant_types - @property + @cached_property def redirect_uris(self) -> Sequence[str]: """Return the redirect_uris that this client supports.""" return self.client_metadata.get('redirect_uris', []) @@ -99,7 +108,7 @@ class OAuth2Client(NamedTuple): """ return redirect_uri in self.redirect_uris - @property + @cached_property def response_types(self) -> Sequence[str]: """Return the response_types that this client supports.""" return self.client_metadata.get("response_type", []) @@ -108,13 +117,14 @@ class OAuth2Client(NamedTuple): """Check whether this client supports `response_type`.""" return response_type in self.response_types - @property + @cached_property def scope(self) -> Sequence[str]: """Return valid scopes for this client.""" return tuple(set(self.client_metadata.get("scope", []))) def get_allowed_scope(self, scope: str) -> str: - """Return list of scopes in `scope` that are supported by this client.""" + """Return list of scopes in `scope` that are supported by this + client.""" if not bool(scope): return "" requested = scope.split() @@ -129,33 +139,39 @@ class OAuth2Client(NamedTuple): """Return the default redirect uri""" return self.client_metadata.get("default_redirect_uri", "") + def client(conn: db.DbConnection, client_id: UUID, user: Optional[User] = None) -> Maybe: """Retrieve a client by its ID""" with db.cursor(conn) as cursor: cursor.execute( - "SELECT * FROM oauth2_clients WHERE client_id=?", (str(client_id),)) + "SELECT * FROM oauth2_clients WHERE client_id=?", + (str(client_id),)) result = cursor.fetchone() the_user = user if result: if not bool(the_user): try: the_user = user_by_id(conn, result["user_id"]) - except NotFoundError as _nfe: + except NotFoundError: the_user = None return Just( - OAuth2Client(UUID(result["client_id"]), - result["client_secret"], - datetime.datetime.fromtimestamp( - result["client_id_issued_at"]), - datetime.datetime.fromtimestamp( - result["client_secret_expires_at"]), - json.loads(result["client_metadata"]), - the_user))# type: ignore[arg-type] - + OAuth2Client( + client_id=UUID(result["client_id"]), + client_secret=result["client_secret"], + client_id_issued_at=datetime.datetime.fromtimestamp( + result["client_id_issued_at"] + ), + client_secret_expires_at=datetime.datetime.fromtimestamp( + result["client_secret_expires_at"] + ), + client_metadata=json.loads(result["client_metadata"]), + user=the_user) # type: ignore[arg-type] + ) return Nothing + def client_by_id_and_secret(conn: db.DbConnection, client_id: UUID, client_secret: str) -> OAuth2Client: """Retrieve a client by its ID and secret""" @@ -166,16 +182,25 @@ def client_by_id_and_secret(conn: db.DbConnection, client_id: UUID, row = cursor.fetchone() if bool(row) and same_password(client_secret, row["client_secret"]): return OAuth2Client( - client_id, client_secret, - datetime.datetime.fromtimestamp(row["client_id_issued_at"]), - datetime.datetime.fromtimestamp( - row["client_secret_expires_at"]), - json.loads(row["client_metadata"]), - user_by_id(conn, UUID(row["user_id"]))) + client_id=client_id, + client_secret=client_secret, + client_id_issued_at=datetime.datetime.fromtimestamp( + row["client_id_issued_at"] + ), + client_secret_expires_at=datetime.datetime.fromtimestamp( + row["client_secret_expires_at"] + ), + client_metadata=json.loads(row["client_metadata"]), + user=user_by_id(conn, UUID(row["user_id"])) + ) + raise NotFoundError( + "Could not find client with the given credentials." + ) - raise NotFoundError("Could not find client with the given credentials.") -def save_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client: +def save_client( + conn: db.DbConnection, the_client: OAuth2Client +) -> OAuth2Client: """Persist the client details into the database.""" with db.cursor(conn) as cursor: query = ( @@ -204,6 +229,7 @@ def save_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client }) return the_client + def oauth2_clients(conn: db.DbConnection) -> tuple[OAuth2Client, ...]: """Fetch a list of all OAuth2 clients.""" with db.cursor(conn) as cursor: @@ -211,19 +237,26 @@ def oauth2_clients(conn: db.DbConnection) -> tuple[OAuth2Client, ...]: clients_rs = cursor.fetchall() the_users = { usr.user_id: usr for usr in users( - conn, tuple({UUID(result["user_id"]) for result in clients_rs})) + conn, tuple({UUID(result["user_id"]) + for result in clients_rs})) } - return tuple(OAuth2Client(UUID(result["client_id"]), - result["client_secret"], - datetime.datetime.fromtimestamp( - result["client_id_issued_at"]), - datetime.datetime.fromtimestamp( - result["client_secret_expires_at"]), - json.loads(result["client_metadata"]), - the_users[UUID(result["user_id"])]) - for result in clients_rs) - -def delete_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client: + return tuple( + OAuth2Client( + client_id=UUID(result["client_id"]), + client_secret=result["client_secret"], + client_id_issued_at=datetime.datetime.fromtimestamp( + result["client_id_issued_at"]), + client_secret_expires_at=datetime.datetime.fromtimestamp( + result["client_secret_expires_at"]), + client_metadata=json.loads(result["client_metadata"]), + user=the_users[UUID(result["user_id"])] + ) for result in clients_rs + ) + + +def delete_client( + conn: db.DbConnection, the_client: OAuth2Client +) -> OAuth2Client: """Delete the given client from the database""" with db.cursor(conn) as cursor: params = (str(the_client.client_id),) |