aboutsummaryrefslogtreecommitdiff
path: root/gn3/auth/authentication
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/auth/authentication')
-rw-r--r--gn3/auth/authentication/oauth2/grants/password_grant.py10
-rw-r--r--gn3/auth/authentication/users.py7
2 files changed, 11 insertions, 6 deletions
diff --git a/gn3/auth/authentication/oauth2/grants/password_grant.py b/gn3/auth/authentication/oauth2/grants/password_grant.py
index 3ec7384..3233877 100644
--- a/gn3/auth/authentication/oauth2/grants/password_grant.py
+++ b/gn3/auth/authentication/oauth2/grants/password_grant.py
@@ -6,6 +6,8 @@ from authlib.oauth2.rfc6749 import grants
from gn3.auth import db
from gn3.auth.authentication.users import valid_login, user_by_email
+from gn3.auth.authorisation.errors import NotFoundError
+
class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant):
"""Implement the 'Password' grant."""
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]
@@ -13,6 +15,8 @@ class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant):
def authenticate_user(self, username, password):
"Authenticate the user with their username and password."
with db.connection(app.config["AUTH_DB"]) as conn:
- return user_by_email(conn, username).maybe(
- None,
- lambda user: valid_login(conn, user, password) and user)
+ try:
+ user = user_by_email(conn, username)
+ return user if valid_login(conn, user, password) else None
+ except NotFoundError as _nfe:
+ return None
diff --git a/gn3/auth/authentication/users.py b/gn3/auth/authentication/users.py
index ce01805..e65938e 100644
--- a/gn3/auth/authentication/users.py
+++ b/gn3/auth/authentication/users.py
@@ -6,6 +6,7 @@ import bcrypt
from pymonad.maybe import Just, Maybe, Nothing
from gn3.auth import db
+from gn3.auth.authorisation.errors import NotFoundError
class User(NamedTuple):
"""Class representing a user."""
@@ -25,16 +26,16 @@ DUMMY_USER = User(user_id=UUID("a391cf60-e8b7-4294-bd22-ddbbda4b3530"),
email="gn3@dummy.user",
name="Dummy user to use as placeholder")
-def user_by_email(conn: db.DbConnection, email: str) -> Maybe:
+def user_by_email(conn: db.DbConnection, email: str) -> User:
"""Retrieve user from database by their email address"""
with db.cursor(conn) as cursor:
cursor.execute("SELECT * FROM users WHERE email=?", (email,))
row = cursor.fetchone()
if row:
- return Just(User(UUID(row["user_id"]), row["email"], row["name"]))
+ return User(UUID(row["user_id"]), row["email"], row["name"])
- return Nothing
+ raise NotFoundError(f"Could not find user with email {email}")
def user_by_id(conn: db.DbConnection, user_id: UUID) -> Maybe:
"""Retrieve user from database by their user id"""