diff options
Diffstat (limited to 'gn3/auth/authentication')
-rw-r--r-- | gn3/auth/authentication/oauth2/grants/password_grant.py | 10 | ||||
-rw-r--r-- | gn3/auth/authentication/users.py | 7 |
2 files changed, 11 insertions, 6 deletions
diff --git a/gn3/auth/authentication/oauth2/grants/password_grant.py b/gn3/auth/authentication/oauth2/grants/password_grant.py index 3ec7384..3233877 100644 --- a/gn3/auth/authentication/oauth2/grants/password_grant.py +++ b/gn3/auth/authentication/oauth2/grants/password_grant.py @@ -6,6 +6,8 @@ from authlib.oauth2.rfc6749 import grants from gn3.auth import db from gn3.auth.authentication.users import valid_login, user_by_email +from gn3.auth.authorisation.errors import NotFoundError + class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant): """Implement the 'Password' grant.""" TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"] @@ -13,6 +15,8 @@ class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant): def authenticate_user(self, username, password): "Authenticate the user with their username and password." with db.connection(app.config["AUTH_DB"]) as conn: - return user_by_email(conn, username).maybe( - None, - lambda user: valid_login(conn, user, password) and user) + try: + user = user_by_email(conn, username) + return user if valid_login(conn, user, password) else None + except NotFoundError as _nfe: + return None diff --git a/gn3/auth/authentication/users.py b/gn3/auth/authentication/users.py index ce01805..e65938e 100644 --- a/gn3/auth/authentication/users.py +++ b/gn3/auth/authentication/users.py @@ -6,6 +6,7 @@ import bcrypt from pymonad.maybe import Just, Maybe, Nothing from gn3.auth import db +from gn3.auth.authorisation.errors import NotFoundError class User(NamedTuple): """Class representing a user.""" @@ -25,16 +26,16 @@ DUMMY_USER = User(user_id=UUID("a391cf60-e8b7-4294-bd22-ddbbda4b3530"), email="gn3@dummy.user", name="Dummy user to use as placeholder") -def user_by_email(conn: db.DbConnection, email: str) -> Maybe: +def user_by_email(conn: db.DbConnection, email: str) -> User: """Retrieve user from database by their email address""" with db.cursor(conn) as cursor: cursor.execute("SELECT * FROM users WHERE email=?", (email,)) row = cursor.fetchone() if row: - return Just(User(UUID(row["user_id"]), row["email"], row["name"])) + return User(UUID(row["user_id"]), row["email"], row["name"]) - return Nothing + raise NotFoundError(f"Could not find user with email {email}") def user_by_id(conn: db.DbConnection, user_id: UUID) -> Maybe: """Retrieve user from database by their user id""" |