aboutsummaryrefslogtreecommitdiff
path: root/wqflask
diff options
context:
space:
mode:
Diffstat (limited to 'wqflask')
-rw-r--r--wqflask/wqflask/oauth2/checks.py34
-rw-r--r--wqflask/wqflask/oauth2/routes.py6
2 files changed, 36 insertions, 4 deletions
diff --git a/wqflask/wqflask/oauth2/checks.py b/wqflask/wqflask/oauth2/checks.py
new file mode 100644
index 00000000..a2cf9ed4
--- /dev/null
+++ b/wqflask/wqflask/oauth2/checks.py
@@ -0,0 +1,34 @@
+"""Various checkers for OAuth2"""
+from functools import wraps
+from urllib.parse import urljoin
+
+from authlib.integrations.requests_client import OAuth2Session
+from flask import flash, request, session, url_for, redirect, current_app
+
+def user_logged_in():
+ """Check whether the user has logged in."""
+ return bool(session.get("oauth2_token", False))
+
+def require_oauth2(func):
+ """Decorator for ensuring user is logged in."""
+ @wraps(func)
+ def __token_valid__(*args, **kwargs):
+ """Check that the user is logged in and their token is valid."""
+ if user_logged_in():
+ config = current_app.config
+ client = OAuth2Session(
+ config["OAUTH2_CLIENT_ID"], config["OAUTH2_CLIENT_SECRET"],
+ token=session["oauth2_token"])
+ resp = client.get(
+ urljoin(config["GN_SERVER_URL"], "oauth2/user"))
+ user_details = resp.json()
+ if not user_details.get("error", False):
+ return func(*args, **kwargs)
+
+ session.pop("oauth2_token", None)
+ session.pop("user_details", None)
+
+ flash("You need to be logged in.", "alert-warning")
+ return redirect(url_for("oauth2.login", next=request.endpoint))
+
+ return __token_valid__
diff --git a/wqflask/wqflask/oauth2/routes.py b/wqflask/wqflask/oauth2/routes.py
index 931b8b61..a72501c4 100644
--- a/wqflask/wqflask/oauth2/routes.py
+++ b/wqflask/wqflask/oauth2/routes.py
@@ -9,11 +9,9 @@ from flask import (
flash, request, session, redirect, Blueprint, render_template,
current_app as app)
-oauth2 = Blueprint("oauth2", __name__)
+from .checks import require_oauth2, user_logged_in
-def user_logged_in():
- """Check whether the user has logged in."""
- return bool(session.get("oauth2_token", False))
+oauth2 = Blueprint("oauth2", __name__)
@oauth2.route("/login", methods=["GET", "POST"])
def login():