about summary refs log tree commit diff
path: root/gn_auth/auth/authentication
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth/auth/authentication')
-rw-r--r--gn_auth/auth/authentication/oauth2/models/oauth2client.py5
-rw-r--r--gn_auth/auth/authentication/oauth2/views.py8
-rw-r--r--gn_auth/auth/authentication/users.py4
3 files changed, 10 insertions, 7 deletions
diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
index 1639e2e..fe12ff9 100644
--- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py
+++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
@@ -2,6 +2,7 @@
 import json
 import datetime
 from uuid import UUID
+from urllib.parse import urlparse
 from functools import cached_property
 from dataclasses import asdict, dataclass
 from typing import Any, Sequence, Optional
@@ -135,7 +136,9 @@ class OAuth2Client(ClientMixin):
         """
         Check whether the given `redirect_uri` is one of the expected ones.
         """
-        return redirect_uri in self.redirect_uris
+        uri = urlparse(redirect_uri)._replace(
+            query="")._replace(fragment="").geturl()
+        return uri in self.redirect_uris
 
     @cached_property
     def response_types(self) -> Sequence[str]:
diff --git a/gn_auth/auth/authentication/oauth2/views.py b/gn_auth/auth/authentication/oauth2/views.py
index 0e2c4eb..6c3de51 100644
--- a/gn_auth/auth/authentication/oauth2/views.py
+++ b/gn_auth/auth/authentication/oauth2/views.py
@@ -44,7 +44,7 @@ def authorise():
                               or str(uuid.uuid4()))
         client = server.query_client(client_id)
         if not bool(client):
-            flash("Invalid OAuth2 client.", "alert-danger")
+            flash("Invalid OAuth2 client.", "alert alert-danger")
 
         if request.method == "GET":
             def __forgot_password_table_exists__(conn):
@@ -88,15 +88,15 @@ def authorise():
                                     email=email["email"]),
                             code=307)
                     return server.create_authorization_response(request=request, grant_user=user)
-                flash(email_passwd_msg, "alert-danger")
+                flash(email_passwd_msg, "alert alert-danger")
                 return redirect_response # type: ignore[return-value]
             except EmailNotValidError as _enve:
                 app.logger.debug(traceback.format_exc())
-                flash(email_passwd_msg, "alert-danger")
+                flash(email_passwd_msg, "alert alert-danger")
                 return redirect_response # type: ignore[return-value]
             except NotFoundError as _nfe:
                 app.logger.debug(traceback.format_exc())
-                flash(email_passwd_msg, "alert-danger")
+                flash(email_passwd_msg, "alert alert-danger")
                 return redirect_response # type: ignore[return-value]
 
         return with_db_connection(__authorise__)
diff --git a/gn_auth/auth/authentication/users.py b/gn_auth/auth/authentication/users.py
index 140ce36..fded79f 100644
--- a/gn_auth/auth/authentication/users.py
+++ b/gn_auth/auth/authentication/users.py
@@ -1,6 +1,6 @@
 """User-specific code and data structures."""
 import datetime
-from typing import Tuple
+from typing import Tuple, Union
 from uuid import UUID, uuid4
 from dataclasses import dataclass
 
@@ -26,7 +26,7 @@ class User:
         return self.user_id
 
     @staticmethod
-    def from_sqlite3_row(row: sqlite3.Row):
+    def from_sqlite3_row(row: Union[sqlite3.Row, dict]):
         """Generate a user from a row in an SQLite3 resultset"""
         return User(user_id=UUID(row["user_id"]),
                     email=row["email"],