about summary refs log tree commit diff
path: root/gn3/auth/authentication
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/auth/authentication')
-rw-r--r--gn3/auth/authentication/oauth2/models/oauth2client.py43
-rw-r--r--gn3/auth/authentication/users.py12
2 files changed, 44 insertions, 11 deletions
diff --git a/gn3/auth/authentication/oauth2/models/oauth2client.py b/gn3/auth/authentication/oauth2/models/oauth2client.py
index b7d37be..da20200 100644
--- a/gn3/auth/authentication/oauth2/models/oauth2client.py
+++ b/gn3/auth/authentication/oauth2/models/oauth2client.py
@@ -7,7 +7,7 @@ from typing import Sequence, Optional, NamedTuple
 from pymonad.maybe import Just, Maybe, Nothing
 
 from gn3.auth import db
-from gn3.auth.authentication.users import User, user_by_id
+from gn3.auth.authentication.users import User, user_by_id, same_password
 
 from gn3.auth.authorisation.errors import NotFoundError
 
@@ -161,16 +161,45 @@ def client_by_id_and_secret(conn: db.DbConnection, client_id: uuid.UUID,
     """Retrieve a client by its ID and secret"""
     with db.cursor(conn) as cursor:
         cursor.execute(
-            "SELECT * FROM oauth2_clients WHERE client_id=? AND "
-            "client_secret=?",
-            (str(client_id), client_secret))
+            "SELECT * FROM oauth2_clients WHERE client_id=?",
+            (str(client_id),))
         row = cursor.fetchone()
-        if bool(row):
+        if bool(row) and same_password(client_secret, row["client_secret"]):
             return OAuth2Client(
                 client_id, client_secret,
                 datetime.datetime.fromtimestamp(row["client_id_issued_at"]),
-                datetime.datetime.fromtimestamp(row["client_secret_expires_at"]),
+                datetime.datetime.fromtimestamp(
+                    row["client_secret_expires_at"]),
                 json.loads(row["client_metadata"]),
                 user_by_id(conn, uuid.UUID(row["user_id"])))
 
-        raise NotFoundError(f"Could not find client with ID '{client_id}'")
+        raise NotFoundError("Could not find client with the given credentials.")
+
+def save_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client:
+    """Persist the client details into the database."""
+    with db.cursor(conn) as cursor:
+        query = (
+            "INSERT INTO oauth2_clients "
+            "(client_id, client_secret, client_id_issued_at, "
+            "client_secret_expires_at, client_metadata, user_id) "
+            "VALUES "
+            "(:client_id, :client_secret, :client_id_issued_at, "
+            ":client_secret_expires_at, :client_metadata, :user_id) "
+            "ON CONFLICT (client_id) DO UPDATE SET "
+            "client_secret=:client_secret, "
+            "client_id_issued_at=:client_id_issued_at, "
+            "client_secret_expires_at=:client_secret_expires_at, "
+            "client_metadata=:client_metadata, user_id=:user_id")
+        cursor.execute(
+            query,
+            {
+                "client_id": str(the_client.client_id),
+                "client_secret": the_client.client_secret,
+                "client_id_issued_at": (
+                    the_client.client_id_issued_at.timestamp()),
+                "client_secret_expires_at": (
+                    the_client.client_secret_expires_at.timestamp()),
+                "client_metadata": json.dumps(the_client.client_metadata),
+                "user_id": str(the_client.user.user_id)
+            })
+        return the_client
diff --git a/gn3/auth/authentication/users.py b/gn3/auth/authentication/users.py
index 17e89ae..8b4f115 100644
--- a/gn3/auth/authentication/users.py
+++ b/gn3/auth/authentication/users.py
@@ -48,6 +48,13 @@ def user_by_id(conn: db.DbConnection, user_id: UUID) -> User:
 
     raise NotFoundError(f"Could not find user with ID {user_id}")
 
+def same_password(password: str, hashed: str) -> bool:
+    """Check that `raw_password` is hashed to `hash`"""
+    try:
+        return hasher().verify(hashed, password)
+    except VerifyMismatchError as _vme:
+        return False
+
 def valid_login(conn: db.DbConnection, user: User, password: str) -> bool:
     """Check the validity of the provided credentials for login."""
     with db.cursor(conn) as cursor:
@@ -61,10 +68,7 @@ def valid_login(conn: db.DbConnection, user: User, password: str) -> bool:
     if row is None:
         return False
 
-    try:
-        return hasher().verify(row["password"], password)
-    except VerifyMismatchError as _vme:
-        return False
+    return same_password(password, row["password"])
 
 def save_user(cursor: db.DbCursor, email: str, name: str) -> User:
     """