diff options
Diffstat (limited to 'gn3/auth')
4 files changed, 11 insertions, 9 deletions
diff --git a/gn3/auth/authentication/oauth2/grants/password_grant.py b/gn3/auth/authentication/oauth2/grants/password_grant.py index 91fdb7c..3ec7384 100644 --- a/gn3/auth/authentication/oauth2/grants/password_grant.py +++ b/gn3/auth/authentication/oauth2/grants/password_grant.py @@ -15,4 +15,4 @@ class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant): with db.connection(app.config["AUTH_DB"]) as conn: return user_by_email(conn, username).maybe( None, - lambda user: valid_login(conn, user, password)) + lambda user: valid_login(conn, user, password) and user) diff --git a/gn3/auth/authentication/oauth2/models/oauth2client.py b/gn3/auth/authentication/oauth2/models/oauth2client.py index efaff54..ac3bdb1 100644 --- a/gn3/auth/authentication/oauth2/models/oauth2client.py +++ b/gn3/auth/authentication/oauth2/models/oauth2client.py @@ -2,7 +2,7 @@ import json import uuid import datetime -from typing import NamedTuple, Sequence +from typing import Sequence, Optional, NamedTuple from pymonad.maybe import Just, Maybe, Nothing @@ -127,12 +127,15 @@ class OAuth2Client(NamedTuple): """Return the default redirect uri""" return self.client_metadata.get("default_redirect_uri", "") -def client(conn: db.DbConnection, client_id: uuid.UUID) -> Maybe: +def client(conn: db.DbConnection, client_id: uuid.UUID, + user: Optional[User] = None) -> Maybe: """Retrieve a client by its ID""" with db.cursor(conn) as cursor: cursor.execute( "SELECT * FROM oauth2_clients WHERE client_id=?", (str(client_id),)) result = cursor.fetchone() + the_user = user or user_by_id(conn, result["user_id"]).maybe( + None, lambda usr: usr) if result: return Just( OAuth2Client(uuid.UUID(result["client_id"]), @@ -142,8 +145,6 @@ def client(conn: db.DbConnection, client_id: uuid.UUID) -> Maybe: datetime.datetime.fromtimestamp( result["client_secret_expires_at"]), json.loads(result["client_metadata"]), - user_by_id( # type: ignore[misc] - conn, uuid.UUID(result["user_id"])).maybe( - None, lambda usr: usr))) + the_user)) return Nothing diff --git a/gn3/auth/authentication/oauth2/models/oauth2token.py b/gn3/auth/authentication/oauth2/models/oauth2token.py index 70421b4..ce7caae 100644 --- a/gn3/auth/authentication/oauth2/models/oauth2token.py +++ b/gn3/auth/authentication/oauth2/models/oauth2token.py @@ -49,9 +49,10 @@ class OAuth2Token(NamedTuple): return self.revoked def __token_from_resultset__(conn: db.DbConnection, rset) -> Maybe: - the_client = client(conn, uuid.UUID(rset["client_id"])) - the_user = user_by_id(conn, uuid.UUID(rset["user_id"])) __identity__ = lambda val: val + the_user = user_by_id(conn, uuid.UUID(rset["user_id"])) + the_client = client(conn, uuid.UUID(rset["client_id"]), + the_user.maybe(None, __identity__)) if the_client.is_just() and the_user.is_just(): return Just(OAuth2Token(token_id=uuid.UUID(rset["token_id"]), diff --git a/gn3/auth/authentication/oauth2/server.py b/gn3/auth/authentication/oauth2/server.py index 960625d..73c9340 100644 --- a/gn3/auth/authentication/oauth2/server.py +++ b/gn3/auth/authentication/oauth2/server.py @@ -36,7 +36,7 @@ def create_save_token_func(token_model: type) -> Callable: save_token( conn, token_model( token_id=uuid.uuid4(), client=request.client, - user=request.client.user, + user=request.user, **{ "refresh_token": None, "revoked": False, "issued_at": datetime.datetime.now(), |