diff options
-rw-r--r-- | gn2/wqflask/oauth2/checks.py | 24 | ||||
-rw-r--r-- | gn2/wqflask/oauth2/session.py | 8 | ||||
-rw-r--r-- | gn2/wqflask/oauth2/toplevel.py | 3 |
3 files changed, 26 insertions, 9 deletions
diff --git a/gn2/wqflask/oauth2/checks.py b/gn2/wqflask/oauth2/checks.py index 4a5a117f..38e7e22f 100644 --- a/gn2/wqflask/oauth2/checks.py +++ b/gn2/wqflask/oauth2/checks.py @@ -2,8 +2,9 @@ from functools import wraps from urllib.parse import urljoin -from flask import flash, request, redirect +from flask import flash, request, redirect, url_for from authlib.integrations.requests_client import OAuth2Session +from werkzeug.routing import BuildError from . import session from .client import ( @@ -12,6 +13,7 @@ from .client import ( authserver_uri, oauth2_clientid, oauth2_clientsecret) +from .request_utils import authserver_authorise_uri def require_oauth2(func): @@ -19,11 +21,17 @@ def require_oauth2(func): @wraps(func) def __token_valid__(*args, **kwargs): """Check that the user is logged in and their token is valid.""" - - def __clear_session__(_no_token): - session.clear_session_info() - flash("You need to be logged in.", "alert-warning") - return redirect("/") + def __redirect_to_login__(_token): + """ + Save the current user request to session then + redirect to the login page. + """ + try: + redirect_url = url_for(request.endpoint, _method="GET", **request.args) + except BuildError: + redirect_url = "/" + session.set_redirect_url(redirect_url) + return redirect(authserver_authorise_uri()) def __with_token__(token): resp = oauth2_client().get( @@ -32,9 +40,9 @@ def require_oauth2(func): if not user_details.get("error", False): return func(*args, **kwargs) - return __clear_session__(token) + return __redirect_to_login__(token) - return session.user_token().either(__clear_session__, __with_token__) + return session.user_token().either(__redirect_to_login__, __with_token__) return __token_valid__ diff --git a/gn2/wqflask/oauth2/session.py b/gn2/wqflask/oauth2/session.py index b91534b0..78c766a8 100644 --- a/gn2/wqflask/oauth2/session.py +++ b/gn2/wqflask/oauth2/session.py @@ -24,6 +24,7 @@ class SessionInfo(TypedDict): masquerade: Optional[UserDetails] refreshing_token: bool auth_server_jwks: Optional[dict[str, Any]] + redirect_url: Optional[str] __SESSION_KEY__ = "GN::2::session_info" # Do not use this outside this module!! @@ -118,3 +119,10 @@ def toggle_token_refreshing(): def is_token_refreshing(): """Returns whether the token is being refreshed or not.""" return session_info().get("token_refreshing", False) + +def set_redirect_url(url): + """Save the current endpoint object""" + return save_session_info({ + **session_info(), + "redirect_url": url + }) diff --git a/gn2/wqflask/oauth2/toplevel.py b/gn2/wqflask/oauth2/toplevel.py index 24d60311..425c598e 100644 --- a/gn2/wqflask/oauth2/toplevel.py +++ b/gn2/wqflask/oauth2/toplevel.py @@ -81,7 +81,8 @@ def authorisation_code(): "token": session.user_token(), "logged_in": True }) - return redirect("/") + redirect_url = session.session_info().get("redirect_url", "/") + return redirect(redirect_url) return no_token_post("auth/token", json=request_data).either( lambda err: __error__(process_error(err)), __success__) |