about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-07-31 15:27:59 -0500
committerFrederick Muriuki Muriithi2024-07-31 15:27:59 -0500
commitb9a0a5d4b64c65c9472b253db7a28cb91328401f (patch)
tree137946c7121d189153f8a92237c8b8d66412cdef
parente693056b193c9c08bf0a0e99df4352cfeb83de2f (diff)
downloadgenenetwork2-b9a0a5d4b64c65c9472b253db7a28cb91328401f.tar.gz
Synchronise token refreshes
The application can be run in a multi-threaded server, leading to a
situation where the multiple threads attempt to get a new JWT using
the exact same refresh token.

This synchronises the various threads ensuring only a single thread is
able to retrieve the new JWT that all the rest of the threads then
use.
-rw-r--r--gn2/wqflask/oauth2/client.py37
-rw-r--r--gn2/wqflask/oauth2/session.py24
2 files changed, 56 insertions, 5 deletions
diff --git a/gn2/wqflask/oauth2/client.py b/gn2/wqflask/oauth2/client.py
index 770777b5..0d4615e8 100644
--- a/gn2/wqflask/oauth2/client.py
+++ b/gn2/wqflask/oauth2/client.py
@@ -1,5 +1,7 @@
 """Common oauth2 client utilities."""
 import json
+import time
+import random
 import requests
 from typing import Optional
 from urllib.parse import urljoin
@@ -38,10 +40,36 @@ def oauth2_client():
     def __update_token__(token, refresh_token=None, access_token=None):
         """Update the token when refreshed."""
         session.set_user_token(token)
+        return token
 
-    def __client__(token) -> OAuth2Session:
+    def __validate_token__(token):
         _jwt = jwt.decode(token["access_token"],
                           app.config["AUTH_SERVER_SSL_PUBLIC_KEY"])
+        return token
+
+    def __delay__():
+        """Do a tiny delay."""
+        time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100))))
+
+    def __refresh_token__(token):
+        """Synchronise token refresh."""
+        if session.is_token_expired():
+            __delay__()
+            if session.is_token_refreshing():
+                while session.is_token_refreshing():
+                    __delay__()
+                    _token = session.user_token().either(None, lambda _tok: _tok)
+                    return _token
+
+            session.toggle_token_refreshing()
+            _client = __client__(token)
+            _client.get(urljoin(authserver_uri(), "auth/user/"))
+            session.toggle_token_refreshing()
+            return _client.token
+
+        return token
+
+    def __client__(token) -> OAuth2Session:
         client = OAuth2Session(
             oauth2_clientid(),
             oauth2_clientsecret(),
@@ -51,9 +79,10 @@ def oauth2_client():
             token=token,
             update_token=__update_token__)
         return client
-    return session.user_token().either(
-        lambda _notok: __client__(None),
-        lambda token: __client__(token))
+    return session.user_token().then(__validate_token__).then(
+        __refresh_token__).either(
+            lambda _notok: __client__(None),
+            lambda token: __client__(token))
 
 def __no_token__(_err) -> Left:
     """Handle situation where request is attempted with no token."""
diff --git a/gn2/wqflask/oauth2/session.py b/gn2/wqflask/oauth2/session.py
index eec48a7f..92181ccf 100644
--- a/gn2/wqflask/oauth2/session.py
+++ b/gn2/wqflask/oauth2/session.py
@@ -22,6 +22,7 @@ class SessionInfo(TypedDict):
     user_agent: str
     ip_addr: str
     masquerade: Optional[UserDetails]
+    refreshing_token: bool
 
 __SESSION_KEY__ = "GN::2::session_info" # Do not use this outside this module!!
 
@@ -61,7 +62,8 @@ def session_info() -> SessionInfo:
             "user_agent": request.headers.get("User-Agent"),
             "ip_addr": request.environ.get("HTTP_X_FORWARDED_FOR",
                                            request.remote_addr),
-            "masquerading": None
+            "masquerading": None,
+            "token_refreshing": False
         }))
 
 
@@ -102,3 +104,23 @@ def unset_masquerading():
         "user": the_session["masquerading"],
         "masquerading": None
     })
+
+
+def toggle_token_refreshing():
+    """Toggle the state of the token_refreshing variable."""
+    _session = session_info()
+    return save_session_info({
+        **_session,
+        "token_refreshing": not _session.get("token_refreshing", False)})
+
+
+def is_token_expired():
+    """Check whether the token is expired."""
+    return user_token().either(
+        lambda _no_token: False,
+        lambda token: datetime.now().timestamp() > token["expires_at"])
+
+
+def is_token_refreshing():
+    """Returns whether the token is being refreshed or not."""
+    return session_info().get("token_refreshing", False)