diff options
Diffstat (limited to 'gn3/auth/authentication/oauth2')
-rw-r--r-- | gn3/auth/authentication/oauth2/models/oauth2client.py | 33 |
1 files changed, 26 insertions, 7 deletions
diff --git a/gn3/auth/authentication/oauth2/models/oauth2client.py b/gn3/auth/authentication/oauth2/models/oauth2client.py index 14c4c94..564ed32 100644 --- a/gn3/auth/authentication/oauth2/models/oauth2client.py +++ b/gn3/auth/authentication/oauth2/models/oauth2client.py @@ -1,13 +1,13 @@ """OAuth2 Client model.""" import json -import uuid import datetime +from uuid import UUID from typing import Sequence, Optional, NamedTuple from pymonad.maybe import Just, Maybe, Nothing from gn3.auth import db -from gn3.auth.authentication.users import User, user_by_id, same_password +from gn3.auth.authentication.users import User, users, user_by_id, same_password from gn3.auth.authorisation.errors import NotFoundError @@ -18,7 +18,7 @@ class OAuth2Client(NamedTuple): This is defined according to the mixin at https://docs.authlib.org/en/latest/specs/rfc6749.html#authlib.oauth2.rfc6749.ClientMixin """ - client_id: uuid.UUID + client_id: UUID client_secret: str client_id_issued_at: datetime.datetime client_secret_expires_at: datetime.datetime @@ -129,7 +129,7 @@ class OAuth2Client(NamedTuple): """Return the default redirect uri""" return self.client_metadata.get("default_redirect_uri", "") -def client(conn: db.DbConnection, client_id: uuid.UUID, +def client(conn: db.DbConnection, client_id: UUID, user: Optional[User] = None) -> Maybe: """Retrieve a client by its ID""" with db.cursor(conn) as cursor: @@ -145,7 +145,7 @@ def client(conn: db.DbConnection, client_id: uuid.UUID, the_user = None return Just( - OAuth2Client(uuid.UUID(result["client_id"]), + OAuth2Client(UUID(result["client_id"]), result["client_secret"], datetime.datetime.fromtimestamp( result["client_id_issued_at"]), @@ -156,7 +156,7 @@ def client(conn: db.DbConnection, client_id: uuid.UUID, return Nothing -def client_by_id_and_secret(conn: db.DbConnection, client_id: uuid.UUID, +def client_by_id_and_secret(conn: db.DbConnection, client_id: UUID, client_secret: str) -> OAuth2Client: """Retrieve a client by its ID and secret""" with db.cursor(conn) as cursor: @@ -171,7 +171,7 @@ def client_by_id_and_secret(conn: db.DbConnection, client_id: uuid.UUID, datetime.datetime.fromtimestamp( row["client_secret_expires_at"]), json.loads(row["client_metadata"]), - user_by_id(conn, uuid.UUID(row["user_id"]))) + user_by_id(conn, UUID(row["user_id"]))) raise NotFoundError("Could not find client with the given credentials.") @@ -203,3 +203,22 @@ def save_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client "user_id": str(the_client.user.user_id) }) return the_client + +def oauth2_clients(conn: db.DbConnection) -> tuple[OAuth2Client, ...]: + """Fetch a list of all OAuth2 clients.""" + with db.cursor(conn) as cursor: + cursor.execute("SELECT * FROM oauth2_clients") + clients_rs = cursor.fetchall() + the_users = { + usr.user_id: usr for usr in users( + conn, tuple({UUID(result["user_id"]) for result in clients_rs})) + } + return tuple(OAuth2Client(UUID(result["client_id"]), + result["client_secret"], + datetime.datetime.fromtimestamp( + result["client_id_issued_at"]), + datetime.datetime.fromtimestamp( + result["client_secret_expires_at"]), + json.loads(result["client_metadata"]), + the_users[UUID(result["user_id"])]) + for result in clients_rs) |