diff options
author | Frederick Muriuki Muriithi | 2024-09-05 10:26:14 -0500 |
---|---|---|
committer | Frederick Muriuki Muriithi | 2024-09-05 10:26:29 -0500 |
commit | f5b60f2909c6683345ae0f1070e84e40c41af5ad (patch) | |
tree | b562c455ffb4567f347543f7f54a92da15acb369 /uploader | |
parent | 40dbb863ffb99721d4b736eb237ce94f15e61e48 (diff) | |
download | gn-uploader-f5b60f2909c6683345ae0f1070e84e40c41af5ad.tar.gz |
Synchronise token refreshing.
When running flask with multiple threads/workers, as happens when
using gunicorn, there is a potential for more than one thread running
with an expired token, leading to multiple uncoordinated token
refreshes.
This commit coordinates the threads in the case there is need to
refresh a token, ensuring only one thread does the token refresh.
Diffstat (limited to 'uploader')
-rw-r--r-- | uploader/oauth2/client.py | 43 | ||||
-rw-r--r-- | uploader/session.py | 13 |
2 files changed, 54 insertions, 2 deletions
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py index 8ad8f7c..e119cc3 100644 --- a/uploader/oauth2/client.py +++ b/uploader/oauth2/client.py @@ -1,5 +1,7 @@ """OAuth2 client utilities.""" import json +import time +import random from datetime import datetime, timedelta from urllib.parse import urljoin, urlparse @@ -9,7 +11,8 @@ from flask import request, current_app as app from pymonad.either import Left, Right, Either from authlib.common.urls import url_decode -from authlib.jose import KeySet, JsonWebKey +from authlib.jose.errors import BadSignatureError +from authlib.jose import KeySet, JsonWebKey, JsonWebToken from authlib.integrations.requests_client import OAuth2Session from uploader import session @@ -101,7 +104,43 @@ def oauth2_client(): ("client_secret_post", __json_auth__)) return client - return session.user_token().either( + def __token_expired__(token): + """Check whether the token has expired.""" + jwks = auth_server_jwks() + if bool(jwks): + for jwk in jwks.keys: + try: + jwt = JsonWebToken(["RS256"]).decode( + token["access_token"], key=jwk) + return datetime.now().timestamp() > jwt["exp"] + except BadSignatureError as _bse: + pass + + return False + + def __delay__(): + """Do a tiny delay.""" + time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100)))) + + def __refresh_token__(token): + """Refresh the token if necessary — synchronise amongst threads.""" + if __token_expired__(token): + __delay__() + if session.is_token_refreshing(): + while session.is_token_refreshing(): + __delay__() + + return session.user_token().either(None, lambda _tok: _tok) + + session.toggle_token_refreshing() + _client = __client__(token) + _client.get(urljoin(authserver_uri(), "auth/user/")) + session.toggle_token_refreshing() + return _client.token + + return token + + return session.user_token().then(__refresh_token__).either( lambda _notok: __client__(None), __client__) diff --git a/uploader/session.py b/uploader/session.py index 019d959..399f28c 100644 --- a/uploader/session.py +++ b/uploader/session.py @@ -103,3 +103,16 @@ def set_auth_server_jwks(keyset: KeySet) -> KeySet: } }) return keyset + + +def toggle_token_refreshing(): + """Toggle the state of the token_refreshing variable.""" + _session = session_info() + return save_session_info({ + **_session, + "token_refreshing": not _session.get("token_refreshing", False)}) + + +def is_token_refreshing(): + """Returns whether the token is being refreshed or not.""" + return session_info().get("token_refreshing", False) |