aboutsummaryrefslogtreecommitdiff
path: root/gn3/auth/authentication
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/auth/authentication')
-rw-r--r--gn3/auth/authentication/oauth2/models/oauth2client.py12
-rw-r--r--gn3/auth/authentication/oauth2/models/oauth2token.py14
-rw-r--r--gn3/auth/authentication/users.py7
3 files changed, 21 insertions, 12 deletions
diff --git a/gn3/auth/authentication/oauth2/models/oauth2client.py b/gn3/auth/authentication/oauth2/models/oauth2client.py
index 70b8f59..14f4d5d 100644
--- a/gn3/auth/authentication/oauth2/models/oauth2client.py
+++ b/gn3/auth/authentication/oauth2/models/oauth2client.py
@@ -9,6 +9,8 @@ from pymonad.maybe import Just, Maybe, Nothing
from gn3.auth import db
from gn3.auth.authentication.users import User, user_by_id
+from gn3.auth.authorisation.errors import NotFoundError
+
class OAuth2Client(NamedTuple):
"""
Client to the OAuth2 Server.
@@ -134,8 +136,12 @@ def client(conn: db.DbConnection, client_id: uuid.UUID,
cursor.execute(
"SELECT * FROM oauth2_clients WHERE client_id=?", (str(client_id),))
result = cursor.fetchone()
- the_user = user or user_by_id(conn, result["user_id"]).maybe(
- None, lambda usr: usr)# type: ignore
+ the_user = user
+ if not bool(the_user):
+ try:
+ the_user = user_by_id(conn, result["user_id"])
+ except NotFoundError as _nfe:
+ the_user = None
if result:
return Just(
OAuth2Client(uuid.UUID(result["client_id"]),
@@ -145,6 +151,6 @@ def client(conn: db.DbConnection, client_id: uuid.UUID,
datetime.datetime.fromtimestamp(
result["client_secret_expires_at"]),
json.loads(result["client_metadata"]),
- the_user))
+ the_user))# type: ignore[arg-type]
return Nothing
diff --git a/gn3/auth/authentication/oauth2/models/oauth2token.py b/gn3/auth/authentication/oauth2/models/oauth2token.py
index c1fcafb..72e20cc 100644
--- a/gn3/auth/authentication/oauth2/models/oauth2token.py
+++ b/gn3/auth/authentication/oauth2/models/oauth2token.py
@@ -8,6 +8,8 @@ from pymonad.maybe import Just, Maybe, Nothing
from gn3.auth import db
from gn3.auth.authentication.users import User, user_by_id
+from gn3.auth.authorisation.errors import NotFoundError
+
from .oauth2client import client, OAuth2Client
class OAuth2Token(NamedTuple):
@@ -50,11 +52,13 @@ class OAuth2Token(NamedTuple):
def __token_from_resultset__(conn: db.DbConnection, rset) -> Maybe:
__identity__ = lambda val: val
- the_user = user_by_id(conn, uuid.UUID(rset["user_id"]))
- the_client = client(conn, uuid.UUID(rset["client_id"]),
- the_user.maybe(None, __identity__))
+ try:
+ the_user = user_by_id(conn, uuid.UUID(rset["user_id"]))
+ except NotFoundError as _nfe:
+ the_user = None
+ the_client = client(conn, uuid.UUID(rset["client_id"]), the_user)
- if the_client.is_just() and the_user.is_just():
+ if the_client.is_just() and bool(the_user):
return Just(OAuth2Token(token_id=uuid.UUID(rset["token_id"]),
client=the_client.maybe(None, __identity__),
token_type=rset["token_type"],
@@ -65,7 +69,7 @@ def __token_from_resultset__(conn: db.DbConnection, rset) -> Maybe:
issued_at=datetime.datetime.fromtimestamp(
rset["issued_at"]),
expires_in=rset["expires_in"],
- user=the_user.maybe(None, __identity__)))
+ user=the_user))# type: ignore[arg-type]
return Nothing
diff --git a/gn3/auth/authentication/users.py b/gn3/auth/authentication/users.py
index e65938e..54838a3 100644
--- a/gn3/auth/authentication/users.py
+++ b/gn3/auth/authentication/users.py
@@ -3,7 +3,6 @@ from uuid import UUID, uuid4
from typing import Any, Tuple, NamedTuple
import bcrypt
-from pymonad.maybe import Just, Maybe, Nothing
from gn3.auth import db
from gn3.auth.authorisation.errors import NotFoundError
@@ -37,16 +36,16 @@ def user_by_email(conn: db.DbConnection, email: str) -> User:
raise NotFoundError(f"Could not find user with email {email}")
-def user_by_id(conn: db.DbConnection, user_id: UUID) -> Maybe:
+def user_by_id(conn: db.DbConnection, user_id: UUID) -> User:
"""Retrieve user from database by their user id"""
with db.cursor(conn) as cursor:
cursor.execute("SELECT * FROM users WHERE user_id=?", (str(user_id),))
row = cursor.fetchone()
if row:
- return Just(User(UUID(row["user_id"]), row["email"], row["name"]))
+ return User(UUID(row["user_id"]), row["email"], row["name"])
- return Nothing
+ raise NotFoundError(f"Could not find user with ID {user_id}")
def valid_login(conn: db.DbConnection, user: User, password: str) -> bool:
"""Check the validity of the provided credentials for login."""