aboutsummaryrefslogtreecommitdiff
path: root/gn3/auth/authentication
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/auth/authentication')
-rw-r--r--gn3/auth/authentication/oauth2/server.py8
-rw-r--r--gn3/auth/authentication/oauth2/views.py73
2 files changed, 46 insertions, 35 deletions
diff --git a/gn3/auth/authentication/oauth2/server.py b/gn3/auth/authentication/oauth2/server.py
index e9946b4..7d7113a 100644
--- a/gn3/auth/authentication/oauth2/server.py
+++ b/gn3/auth/authentication/oauth2/server.py
@@ -4,6 +4,7 @@ import datetime
from typing import Callable
from flask import Flask, current_app
+from authlib.oauth2.rfc6749.errors import InvalidClientError
from authlib.integrations.flask_oauth2 import AuthorizationServer
# from authlib.oauth2.rfc7636 import CodeChallenge
@@ -24,7 +25,12 @@ def create_query_client_func() -> Callable:
# use current_app rather than passing the db_uri to avoid issues
# when config changes, e.g. while testing.
with db.connection(current_app.config["AUTH_DB"]) as conn:
- return client(conn, client_id).maybe(None, lambda clt: clt) # type: ignore[misc]
+ the_client = client(conn, client_id).maybe(
+ None, lambda clt: clt) # type: ignore[misc]
+ if bool(the_client):
+ return the_client
+ raise InvalidClientError(
+ "No client found for the given CLIENT_ID and CLIENT_SECRET.")
return __query_client__
diff --git a/gn3/auth/authentication/oauth2/views.py b/gn3/auth/authentication/oauth2/views.py
index e096002..f281295 100644
--- a/gn3/auth/authentication/oauth2/views.py
+++ b/gn3/auth/authentication/oauth2/views.py
@@ -2,6 +2,7 @@
import uuid
import traceback
+from authlib.oauth2.rfc6749.errors import InvalidClientError
from email_validator import validate_email, EmailNotValidError
from flask import (
flash,
@@ -38,42 +39,46 @@ def delete_client(client_id: uuid.UUID):
@auth.route("/authorise", methods=["GET", "POST"])
def authorise():
"""Authorise a user"""
- server = app.config["OAUTH2_SERVER"]
- client_id = uuid.UUID(request.args.get("client_id", str(uuid.uuid4())))
- client = server.query_client(client_id)
- if not bool(client):
- flash("Invalid OAuth2 client.", "alert-error")
- if request.method == "GET":
- client = server.query_client(request.args.get("client_id"))
- return render_template(
- "oauth2/authorise-user.html",
- client=client,
- scope=client.scope,
- response_type="code")
+ try:
+ server = app.config["OAUTH2_SERVER"]
+ client_id = uuid.UUID(request.args.get("client_id", str(uuid.uuid4())))
+ client = server.query_client(client_id)
+ if not bool(client):
+ flash("Invalid OAuth2 client.", "alert-error")
+ if request.method == "GET":
+ client = server.query_client(request.args.get("client_id"))
+ return render_template(
+ "oauth2/authorise-user.html",
+ client=client,
+ scope=client.scope,
+ response_type="code")
- form = request.form
- def __authorise__(conn: db.DbConnection) -> Response:
- email_passwd_msg = "Email or password is invalid!"
- redirect_response = redirect(url_for("oauth2.auth.authorise",
- client_id=client_id))
- try:
- email = validate_email(
- form.get("user:email"), check_deliverability=False)
- user = user_by_email(conn, email["email"])
- if valid_login(conn, user, form.get("user:password", "")):
- return server.create_authorization_response(request=request, grant_user=user)
- flash(email_passwd_msg, "alert-error")
- return redirect_response # type: ignore[return-value]
- except EmailNotValidError as _enve:
- app.logger.debug(traceback.format_exc())
- flash(email_passwd_msg, "alert-error")
- return redirect_response # type: ignore[return-value]
- except NotFoundError as _nfe:
- app.logger.debug(traceback.format_exc())
- flash(email_passwd_msg, "alert-error")
- return redirect_response # type: ignore[return-value]
+ form = request.form
+ def __authorise__(conn: db.DbConnection) -> Response:
+ email_passwd_msg = "Email or password is invalid!"
+ redirect_response = redirect(url_for("oauth2.auth.authorise",
+ client_id=client_id))
+ try:
+ email = validate_email(
+ form.get("user:email"), check_deliverability=False)
+ user = user_by_email(conn, email["email"])
+ if valid_login(conn, user, form.get("user:password", "")):
+ return server.create_authorization_response(request=request, grant_user=user)
+ flash(email_passwd_msg, "alert-error")
+ return redirect_response # type: ignore[return-value]
+ except EmailNotValidError as _enve:
+ app.logger.debug(traceback.format_exc())
+ flash(email_passwd_msg, "alert-error")
+ return redirect_response # type: ignore[return-value]
+ except NotFoundError as _nfe:
+ app.logger.debug(traceback.format_exc())
+ flash(email_passwd_msg, "alert-error")
+ return redirect_response # type: ignore[return-value]
- return with_db_connection(__authorise__)
+ return with_db_connection(__authorise__)
+ except InvalidClientError as ice:
+ return render_template(
+ "oauth2/oauth2_error.html", error=ice), ice.status_code
@auth.route("/token", methods=["POST"])
def token():