aboutsummaryrefslogtreecommitdiff
path: root/uploader/oauth2/client.py
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-09-05 10:26:14 -0500
committerFrederick Muriuki Muriithi2024-09-05 10:26:29 -0500
commitf5b60f2909c6683345ae0f1070e84e40c41af5ad (patch)
treeb562c455ffb4567f347543f7f54a92da15acb369 /uploader/oauth2/client.py
parent40dbb863ffb99721d4b736eb237ce94f15e61e48 (diff)
downloadgn-uploader-f5b60f2909c6683345ae0f1070e84e40c41af5ad.tar.gz
Synchronise token refreshing.
When running flask with multiple threads/workers, as happens when using gunicorn, there is a potential for more than one thread running with an expired token, leading to multiple uncoordinated token refreshes. This commit coordinates the threads in the case there is need to refresh a token, ensuring only one thread does the token refresh.
Diffstat (limited to 'uploader/oauth2/client.py')
-rw-r--r--uploader/oauth2/client.py43
1 files changed, 41 insertions, 2 deletions
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py
index 8ad8f7c..e119cc3 100644
--- a/uploader/oauth2/client.py
+++ b/uploader/oauth2/client.py
@@ -1,5 +1,7 @@
"""OAuth2 client utilities."""
import json
+import time
+import random
from datetime import datetime, timedelta
from urllib.parse import urljoin, urlparse
@@ -9,7 +11,8 @@ from flask import request, current_app as app
from pymonad.either import Left, Right, Either
from authlib.common.urls import url_decode
-from authlib.jose import KeySet, JsonWebKey
+from authlib.jose.errors import BadSignatureError
+from authlib.jose import KeySet, JsonWebKey, JsonWebToken
from authlib.integrations.requests_client import OAuth2Session
from uploader import session
@@ -101,7 +104,43 @@ def oauth2_client():
("client_secret_post", __json_auth__))
return client
- return session.user_token().either(
+ def __token_expired__(token):
+ """Check whether the token has expired."""
+ jwks = auth_server_jwks()
+ if bool(jwks):
+ for jwk in jwks.keys:
+ try:
+ jwt = JsonWebToken(["RS256"]).decode(
+ token["access_token"], key=jwk)
+ return datetime.now().timestamp() > jwt["exp"]
+ except BadSignatureError as _bse:
+ pass
+
+ return False
+
+ def __delay__():
+ """Do a tiny delay."""
+ time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100))))
+
+ def __refresh_token__(token):
+ """Refresh the token if necessary — synchronise amongst threads."""
+ if __token_expired__(token):
+ __delay__()
+ if session.is_token_refreshing():
+ while session.is_token_refreshing():
+ __delay__()
+
+ return session.user_token().either(None, lambda _tok: _tok)
+
+ session.toggle_token_refreshing()
+ _client = __client__(token)
+ _client.get(urljoin(authserver_uri(), "auth/user/"))
+ session.toggle_token_refreshing()
+ return _client.token
+
+ return token
+
+ return session.user_token().then(__refresh_token__).either(
lambda _notok: __client__(None),
__client__)