about summary refs log tree commit diff
path: root/gn_auth/auth/authentication/oauth2/views.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth/auth/authentication/oauth2/views.py')
-rw-r--r--gn_auth/auth/authentication/oauth2/views.py44
1 files changed, 35 insertions, 9 deletions
diff --git a/gn_auth/auth/authentication/oauth2/views.py b/gn_auth/auth/authentication/oauth2/views.py
index 22437a2..8cc123f 100644
--- a/gn_auth/auth/authentication/oauth2/views.py
+++ b/gn_auth/auth/authentication/oauth2/views.py
@@ -1,5 +1,6 @@
 """Endpoints for the oauth2 server"""
 import uuid
+import logging
 import traceback
 from urllib.parse import urlparse
 
@@ -9,6 +10,7 @@ from flask import (
     flash,
     request,
     url_for,
+    jsonify,
     redirect,
     Response,
     Blueprint,
@@ -17,6 +19,7 @@ from flask import (
 
 from gn_auth.auth.db import sqlite3 as db
 from gn_auth.auth.db.sqlite3 import with_db_connection
+from gn_auth.auth.jwks import jwks_directory, list_jwks
 from gn_auth.auth.errors import NotFoundError, ForbiddenAccess
 from gn_auth.auth.authentication.users import valid_login, user_by_email
 
@@ -25,8 +28,10 @@ from .endpoints.revocation import RevocationEndpoint
 from .endpoints.introspection import IntrospectionEndpoint
 
 
+logger = logging.getLogger(__name__)
 auth = Blueprint("auth", __name__)
 
+
 @auth.route("/delete-client/<uuid:client_id>", methods=["GET", "POST"])
 def delete_client(client_id: uuid.UUID):
     """Delete an OAuth2 client."""
@@ -42,9 +47,17 @@ def authorise():
                               or str(uuid.uuid4()))
         client = server.query_client(client_id)
         if not bool(client):
-            flash("Invalid OAuth2 client.", "alert-danger")
+            flash("Invalid OAuth2 client.", "alert alert-danger")
 
         if request.method == "GET":
+            def __forgot_password_table_exists__(conn):
+                with db.cursor(conn) as cursor:
+                    cursor.execute("SELECT name FROM sqlite_master "
+                                   "WHERE type='table' "
+                                   "AND name='forgot_password_tokens'")
+                    return bool(cursor.fetchone())
+                return False
+
             client = server.query_client(request.args.get("client_id"))
             _src = urlparse(request.args["redirect_uri"])
             return render_template(
@@ -53,7 +66,9 @@ def authorise():
                 scope=client.scope,
                 response_type=request.args["response_type"],
                 redirect_uri=request.args["redirect_uri"],
-                source_uri=f"{_src.scheme}://{_src.netloc}/")
+                source_uri=f"{_src.scheme}://{_src.netloc}/",
+                display_forgot_password=with_db_connection(
+                    __forgot_password_table_exists__))
 
         form = request.form
         def __authorise__(conn: db.DbConnection):
@@ -65,25 +80,26 @@ def authorise():
             try:
                 email = validate_email(
                     form.get("user:email"), check_deliverability=False)
-                user = user_by_email(conn, email["email"])
+                user = user_by_email(conn, email["email"])  # type: ignore
                 if valid_login(conn, user, form.get("user:password", "")):
                     if not user.verified:
                         return redirect(
                             url_for("oauth2.users.handle_unverified",
                                     response_type=form["response_type"],
                                     client_id=client_id,
-                                    redirect_uri=form["redirect_uri"]),
+                                    redirect_uri=form["redirect_uri"],
+                                    email=email["email"]),
                             code=307)
                     return server.create_authorization_response(request=request, grant_user=user)
-                flash(email_passwd_msg, "alert-danger")
+                flash(email_passwd_msg, "alert alert-danger")
                 return redirect_response # type: ignore[return-value]
             except EmailNotValidError as _enve:
-                app.logger.debug(traceback.format_exc())
-                flash(email_passwd_msg, "alert-danger")
+                logger.debug(traceback.format_exc())
+                flash(email_passwd_msg, "alert alert-danger")
                 return redirect_response # type: ignore[return-value]
             except NotFoundError as _nfe:
-                app.logger.debug(traceback.format_exc())
-                flash(email_passwd_msg, "alert-danger")
+                logger.debug(traceback.format_exc())
+                flash(email_passwd_msg, "alert alert-danger")
                 return redirect_response # type: ignore[return-value]
 
         return with_db_connection(__authorise__)
@@ -116,3 +132,13 @@ def introspect_token() -> Response:
                 IntrospectionEndpoint.ENDPOINT_NAME)
 
     raise ForbiddenAccess("You cannot access this endpoint")
+
+
+@auth.route("/public-jwks", methods=["GET"])
+def public_jwks():
+    """Provide the JWK public keys used by this application."""
+    return jsonify({
+        "documentation": (
+            "The keys are listed in order of creation, from the oldest (first) "
+            "to the newest (last)."),
+        "jwks": tuple(key.as_dict() for key in list_jwks(jwks_directory(app)))})