From b9a0a5d4b64c65c9472b253db7a28cb91328401f Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Wed, 31 Jul 2024 15:27:59 -0500 Subject: Synchronise token refreshes The application can be run in a multi-threaded server, leading to a situation where the multiple threads attempt to get a new JWT using the exact same refresh token. This synchronises the various threads ensuring only a single thread is able to retrieve the new JWT that all the rest of the threads then use. --- gn2/wqflask/oauth2/client.py | 37 +++++++++++++++++++++++++++++++++---- gn2/wqflask/oauth2/session.py | 24 +++++++++++++++++++++++- 2 files changed, 56 insertions(+), 5 deletions(-) (limited to 'gn2/wqflask/oauth2') diff --git a/gn2/wqflask/oauth2/client.py b/gn2/wqflask/oauth2/client.py index 770777b5..0d4615e8 100644 --- a/gn2/wqflask/oauth2/client.py +++ b/gn2/wqflask/oauth2/client.py @@ -1,5 +1,7 @@ """Common oauth2 client utilities.""" import json +import time +import random import requests from typing import Optional from urllib.parse import urljoin @@ -38,10 +40,36 @@ def oauth2_client(): def __update_token__(token, refresh_token=None, access_token=None): """Update the token when refreshed.""" session.set_user_token(token) + return token - def __client__(token) -> OAuth2Session: + def __validate_token__(token): _jwt = jwt.decode(token["access_token"], app.config["AUTH_SERVER_SSL_PUBLIC_KEY"]) + return token + + def __delay__(): + """Do a tiny delay.""" + time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100)))) + + def __refresh_token__(token): + """Synchronise token refresh.""" + if session.is_token_expired(): + __delay__() + if session.is_token_refreshing(): + while session.is_token_refreshing(): + __delay__() + _token = session.user_token().either(None, lambda _tok: _tok) + return _token + + session.toggle_token_refreshing() + _client = __client__(token) + _client.get(urljoin(authserver_uri(), "auth/user/")) + session.toggle_token_refreshing() + return _client.token + + return token + + def __client__(token) -> OAuth2Session: client = OAuth2Session( oauth2_clientid(), oauth2_clientsecret(), @@ -51,9 +79,10 @@ def oauth2_client(): token=token, update_token=__update_token__) return client - return session.user_token().either( - lambda _notok: __client__(None), - lambda token: __client__(token)) + return session.user_token().then(__validate_token__).then( + __refresh_token__).either( + lambda _notok: __client__(None), + lambda token: __client__(token)) def __no_token__(_err) -> Left: """Handle situation where request is attempted with no token.""" diff --git a/gn2/wqflask/oauth2/session.py b/gn2/wqflask/oauth2/session.py index eec48a7f..92181ccf 100644 --- a/gn2/wqflask/oauth2/session.py +++ b/gn2/wqflask/oauth2/session.py @@ -22,6 +22,7 @@ class SessionInfo(TypedDict): user_agent: str ip_addr: str masquerade: Optional[UserDetails] + refreshing_token: bool __SESSION_KEY__ = "GN::2::session_info" # Do not use this outside this module!! @@ -61,7 +62,8 @@ def session_info() -> SessionInfo: "user_agent": request.headers.get("User-Agent"), "ip_addr": request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr), - "masquerading": None + "masquerading": None, + "token_refreshing": False })) @@ -102,3 +104,23 @@ def unset_masquerading(): "user": the_session["masquerading"], "masquerading": None }) + + +def toggle_token_refreshing(): + """Toggle the state of the token_refreshing variable.""" + _session = session_info() + return save_session_info({ + **_session, + "token_refreshing": not _session.get("token_refreshing", False)}) + + +def is_token_expired(): + """Check whether the token is expired.""" + return user_token().either( + lambda _no_token: False, + lambda token: datetime.now().timestamp() > token["expires_at"]) + + +def is_token_refreshing(): + """Returns whether the token is being refreshed or not.""" + return session_info().get("token_refreshing", False) -- cgit v1.2.3