about summary refs log tree commit diff
path: root/gn2/wqflask/oauth2/checks.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn2/wqflask/oauth2/checks.py')
-rw-r--r--gn2/wqflask/oauth2/checks.py24
1 files changed, 16 insertions, 8 deletions
diff --git a/gn2/wqflask/oauth2/checks.py b/gn2/wqflask/oauth2/checks.py
index 4a5a117f..38e7e22f 100644
--- a/gn2/wqflask/oauth2/checks.py
+++ b/gn2/wqflask/oauth2/checks.py
@@ -2,8 +2,9 @@
 from functools import wraps
 from urllib.parse import urljoin
 
-from flask import flash, request, redirect
+from flask import flash, request, redirect, url_for
 from authlib.integrations.requests_client import OAuth2Session
+from werkzeug.routing import BuildError
 
 from . import session
 from .client import (
@@ -12,6 +13,7 @@ from .client import (
     authserver_uri,
     oauth2_clientid,
     oauth2_clientsecret)
+from .request_utils import authserver_authorise_uri
 
 
 def require_oauth2(func):
@@ -19,11 +21,17 @@ def require_oauth2(func):
     @wraps(func)
     def __token_valid__(*args, **kwargs):
         """Check that the user is logged in and their token is valid."""
-
-        def __clear_session__(_no_token):
-            session.clear_session_info()
-            flash("You need to be logged in.", "alert-warning")
-            return redirect("/")
+        def __redirect_to_login__(_token):
+            """
+            Save the current user request to session then
+            redirect to the login page.
+            """
+            try:
+                redirect_url = url_for(request.endpoint, _method="GET", **request.args)
+            except BuildError:
+                redirect_url = "/"
+            session.set_redirect_url(redirect_url)
+            return redirect(authserver_authorise_uri())
 
         def __with_token__(token):
             resp = oauth2_client().get(
@@ -32,9 +40,9 @@ def require_oauth2(func):
             if not user_details.get("error", False):
                 return func(*args, **kwargs)
 
-            return __clear_session__(token)
+            return __redirect_to_login__(token)
 
-        return session.user_token().either(__clear_session__, __with_token__)
+        return session.user_token().either(__redirect_to_login__, __with_token__)
 
     return __token_valid__