about summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlexander Kabui2024-09-19 12:14:14 +0300
committerGitHub2024-09-19 12:14:14 +0300
commitadbd8b1a2fad59c95beb2d7ec2d2dec6165bb8d5 (patch)
treeac8a185adb4057627d8935b8dc1c72d5b3d13a2c
parentd72aa3b0ef39cef58a38a588cb2d07762a7ca424 (diff)
parent3f62d5d3e86f3b52488fb50ec00c88f77d03ab27 (diff)
downloadgenenetwork2-adbd8b1a2fad59c95beb2d7ec2d2dec6165bb8d5.tar.gz
Merge pull request #875 from genenetwork/chores/implement-redirect-for-gn2-auth
Chores/implement redirect for users 
-rw-r--r--gn2/wqflask/oauth2/checks.py24
-rw-r--r--gn2/wqflask/oauth2/session.py8
-rw-r--r--gn2/wqflask/oauth2/toplevel.py3
3 files changed, 26 insertions, 9 deletions
diff --git a/gn2/wqflask/oauth2/checks.py b/gn2/wqflask/oauth2/checks.py
index 4a5a117f..38e7e22f 100644
--- a/gn2/wqflask/oauth2/checks.py
+++ b/gn2/wqflask/oauth2/checks.py
@@ -2,8 +2,9 @@
 from functools import wraps
 from urllib.parse import urljoin
 
-from flask import flash, request, redirect
+from flask import flash, request, redirect, url_for
 from authlib.integrations.requests_client import OAuth2Session
+from werkzeug.routing import BuildError
 
 from . import session
 from .client import (
@@ -12,6 +13,7 @@ from .client import (
     authserver_uri,
     oauth2_clientid,
     oauth2_clientsecret)
+from .request_utils import authserver_authorise_uri
 
 
 def require_oauth2(func):
@@ -19,11 +21,17 @@ def require_oauth2(func):
     @wraps(func)
     def __token_valid__(*args, **kwargs):
         """Check that the user is logged in and their token is valid."""
-
-        def __clear_session__(_no_token):
-            session.clear_session_info()
-            flash("You need to be logged in.", "alert-warning")
-            return redirect("/")
+        def __redirect_to_login__(_token):
+            """
+            Save the current user request to session then
+            redirect to the login page.
+            """
+            try:
+                redirect_url = url_for(request.endpoint, _method="GET", **request.args)
+            except BuildError:
+                redirect_url = "/"
+            session.set_redirect_url(redirect_url)
+            return redirect(authserver_authorise_uri())
 
         def __with_token__(token):
             resp = oauth2_client().get(
@@ -32,9 +40,9 @@ def require_oauth2(func):
             if not user_details.get("error", False):
                 return func(*args, **kwargs)
 
-            return __clear_session__(token)
+            return __redirect_to_login__(token)
 
-        return session.user_token().either(__clear_session__, __with_token__)
+        return session.user_token().either(__redirect_to_login__, __with_token__)
 
     return __token_valid__
 
diff --git a/gn2/wqflask/oauth2/session.py b/gn2/wqflask/oauth2/session.py
index b91534b0..78c766a8 100644
--- a/gn2/wqflask/oauth2/session.py
+++ b/gn2/wqflask/oauth2/session.py
@@ -24,6 +24,7 @@ class SessionInfo(TypedDict):
     masquerade: Optional[UserDetails]
     refreshing_token: bool
     auth_server_jwks: Optional[dict[str, Any]]
+    redirect_url: Optional[str]
 
 __SESSION_KEY__ = "GN::2::session_info" # Do not use this outside this module!!
 
@@ -118,3 +119,10 @@ def toggle_token_refreshing():
 def is_token_refreshing():
     """Returns whether the token is being refreshed or not."""
     return session_info().get("token_refreshing", False)
+
+def set_redirect_url(url):
+    """Save the current endpoint object"""
+    return save_session_info({
+        **session_info(),
+        "redirect_url": url
+    })
diff --git a/gn2/wqflask/oauth2/toplevel.py b/gn2/wqflask/oauth2/toplevel.py
index 24d60311..425c598e 100644
--- a/gn2/wqflask/oauth2/toplevel.py
+++ b/gn2/wqflask/oauth2/toplevel.py
@@ -81,7 +81,8 @@ def authorisation_code():
                 "token": session.user_token(),
                 "logged_in": True
             })
-            return redirect("/")
+            redirect_url = session.session_info().get("redirect_url", "/")
+            return redirect(redirect_url)
 
         return no_token_post("auth/token", json=request_data).either(
             lambda err: __error__(process_error(err)), __success__)