about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--uploader/oauth2/client.py43
-rw-r--r--uploader/session.py13
2 files changed, 54 insertions, 2 deletions
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py
index 8ad8f7c..e119cc3 100644
--- a/uploader/oauth2/client.py
+++ b/uploader/oauth2/client.py
@@ -1,5 +1,7 @@
 """OAuth2 client utilities."""
 import json
+import time
+import random
 from datetime import datetime, timedelta
 from urllib.parse import urljoin, urlparse
 
@@ -9,7 +11,8 @@ from flask import request, current_app as app
 from pymonad.either import Left, Right, Either
 
 from authlib.common.urls import url_decode
-from authlib.jose import KeySet, JsonWebKey
+from authlib.jose.errors import BadSignatureError
+from authlib.jose import KeySet, JsonWebKey, JsonWebToken
 from authlib.integrations.requests_client import OAuth2Session
 
 from uploader import session
@@ -101,7 +104,43 @@ def oauth2_client():
             ("client_secret_post", __json_auth__))
         return client
 
-    return session.user_token().either(
+    def __token_expired__(token):
+        """Check whether the token has expired."""
+        jwks = auth_server_jwks()
+        if bool(jwks):
+            for jwk in jwks.keys:
+                try:
+                    jwt = JsonWebToken(["RS256"]).decode(
+                        token["access_token"], key=jwk)
+                    return datetime.now().timestamp() > jwt["exp"]
+                except BadSignatureError as _bse:
+                    pass
+
+        return False
+
+    def __delay__():
+        """Do a tiny delay."""
+        time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100))))
+
+    def __refresh_token__(token):
+        """Refresh the token if necessary — synchronise amongst threads."""
+        if __token_expired__(token):
+            __delay__()
+            if session.is_token_refreshing():
+                while session.is_token_refreshing():
+                    __delay__()
+
+                return session.user_token().either(None, lambda _tok: _tok)
+
+            session.toggle_token_refreshing()
+            _client = __client__(token)
+            _client.get(urljoin(authserver_uri(), "auth/user/"))
+            session.toggle_token_refreshing()
+            return _client.token
+
+        return token
+
+    return session.user_token().then(__refresh_token__).either(
         lambda _notok: __client__(None),
         __client__)
 
diff --git a/uploader/session.py b/uploader/session.py
index 019d959..399f28c 100644
--- a/uploader/session.py
+++ b/uploader/session.py
@@ -103,3 +103,16 @@ def set_auth_server_jwks(keyset: KeySet) -> KeySet:
         }
     })
     return keyset
+
+
+def toggle_token_refreshing():
+    """Toggle the state of the token_refreshing variable."""
+    _session = session_info()
+    return save_session_info({
+        **_session,
+        "token_refreshing": not _session.get("token_refreshing", False)})
+
+
+def is_token_refreshing():
+    """Returns whether the token is being refreshed or not."""
+    return session_info().get("token_refreshing", False)