diff options
Diffstat (limited to 'gn3/auth/authentication')
-rw-r--r-- | gn3/auth/authentication/oauth2/models/oauth2client.py | 33 | ||||
-rw-r--r-- | gn3/auth/authentication/users.py | 16 |
2 files changed, 42 insertions, 7 deletions
diff --git a/gn3/auth/authentication/oauth2/models/oauth2client.py b/gn3/auth/authentication/oauth2/models/oauth2client.py index 14c4c94..564ed32 100644 --- a/gn3/auth/authentication/oauth2/models/oauth2client.py +++ b/gn3/auth/authentication/oauth2/models/oauth2client.py @@ -1,13 +1,13 @@ """OAuth2 Client model.""" import json -import uuid import datetime +from uuid import UUID from typing import Sequence, Optional, NamedTuple from pymonad.maybe import Just, Maybe, Nothing from gn3.auth import db -from gn3.auth.authentication.users import User, user_by_id, same_password +from gn3.auth.authentication.users import User, users, user_by_id, same_password from gn3.auth.authorisation.errors import NotFoundError @@ -18,7 +18,7 @@ class OAuth2Client(NamedTuple): This is defined according to the mixin at https://docs.authlib.org/en/latest/specs/rfc6749.html#authlib.oauth2.rfc6749.ClientMixin """ - client_id: uuid.UUID + client_id: UUID client_secret: str client_id_issued_at: datetime.datetime client_secret_expires_at: datetime.datetime @@ -129,7 +129,7 @@ class OAuth2Client(NamedTuple): """Return the default redirect uri""" return self.client_metadata.get("default_redirect_uri", "") -def client(conn: db.DbConnection, client_id: uuid.UUID, +def client(conn: db.DbConnection, client_id: UUID, user: Optional[User] = None) -> Maybe: """Retrieve a client by its ID""" with db.cursor(conn) as cursor: @@ -145,7 +145,7 @@ def client(conn: db.DbConnection, client_id: uuid.UUID, the_user = None return Just( - OAuth2Client(uuid.UUID(result["client_id"]), + OAuth2Client(UUID(result["client_id"]), result["client_secret"], datetime.datetime.fromtimestamp( result["client_id_issued_at"]), @@ -156,7 +156,7 @@ def client(conn: db.DbConnection, client_id: uuid.UUID, return Nothing -def client_by_id_and_secret(conn: db.DbConnection, client_id: uuid.UUID, +def client_by_id_and_secret(conn: db.DbConnection, client_id: UUID, client_secret: str) -> OAuth2Client: """Retrieve a client by its ID and secret""" with db.cursor(conn) as cursor: @@ -171,7 +171,7 @@ def client_by_id_and_secret(conn: db.DbConnection, client_id: uuid.UUID, datetime.datetime.fromtimestamp( row["client_secret_expires_at"]), json.loads(row["client_metadata"]), - user_by_id(conn, uuid.UUID(row["user_id"]))) + user_by_id(conn, UUID(row["user_id"]))) raise NotFoundError("Could not find client with the given credentials.") @@ -203,3 +203,22 @@ def save_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client "user_id": str(the_client.user.user_id) }) return the_client + +def oauth2_clients(conn: db.DbConnection) -> tuple[OAuth2Client, ...]: + """Fetch a list of all OAuth2 clients.""" + with db.cursor(conn) as cursor: + cursor.execute("SELECT * FROM oauth2_clients") + clients_rs = cursor.fetchall() + the_users = { + usr.user_id: usr for usr in users( + conn, tuple({UUID(result["user_id"]) for result in clients_rs})) + } + return tuple(OAuth2Client(UUID(result["client_id"]), + result["client_secret"], + datetime.datetime.fromtimestamp( + result["client_id_issued_at"]), + datetime.datetime.fromtimestamp( + result["client_secret_expires_at"]), + json.loads(result["client_metadata"]), + the_users[UUID(result["user_id"])]) + for result in clients_rs) diff --git a/gn3/auth/authentication/users.py b/gn3/auth/authentication/users.py index 8b4f115..0e72ed2 100644 --- a/gn3/auth/authentication/users.py +++ b/gn3/auth/authentication/users.py @@ -110,3 +110,19 @@ def set_user_password( "ON CONFLICT (user_id) DO UPDATE SET password=:hash"), {"user_id": str(user.user_id), "hash": hashed_password}) return user, hashed_password + +def users(conn: db.DbConnection, + ids: tuple[UUID, ...] = tuple()) -> tuple[User, ...]: + """ + Fetch all users with the given `ids`. If `ids` is empty, return ALL users. + """ + params = ", ".join(["?"] * len(ids)) + with db.cursor(conn) as cursor: + query = "SELECT * FROM users" + ( + f" WHERE user_id IN ({params})" + if len(ids) > 0 else "") + print(query) + cursor.execute(query, tuple(str(the_id) for the_id in ids)) + return tuple(User(UUID(row["user_id"]), row["email"], row["name"]) + for row in cursor.fetchall()) + return tuple() |