about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-09-04 18:35:45 -0500
committerFrederick Muriuki Muriithi2024-09-04 18:35:45 -0500
commitc00f1b948d51cb6c567a6d583ac32c8b78a29692 (patch)
tree8045c31e953ace668d762c29c22826105fab166e
parent5ce4394c34048c6bbf253124df7b3bed22094d92 (diff)
downloadgn-uploader-c00f1b948d51cb6c567a6d583ac32c8b78a29692.tar.gz
Update retrival of JSON Web Keys
* Introduce the function `auth_server_jwks()` to use for fetching the
  keys from the session if present, or from the server when absent or
  out-of-date. It also handles updating the keys in the session.
* Remove the unnecessary verification of JWKs until the point where
  that is needed, i.e. at the point(s) where there is need to verify
  authorisation.
-rw-r--r--uploader/oauth2/client.py79
1 files changed, 34 insertions, 45 deletions
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py
index a3e4ba3..8ad8f7c 100644
--- a/uploader/oauth2/client.py
+++ b/uploader/oauth2/client.py
@@ -8,10 +8,8 @@ from flask import request, current_app as app
 
 from pymonad.either import Left, Right, Either
 
-from authlib.jose import jwt
 from authlib.common.urls import url_decode
 from authlib.jose import KeySet, JsonWebKey
-from authlib.jose.errors import BadSignatureError
 from authlib.integrations.requests_client import OAuth2Session
 
 from uploader import session
@@ -36,49 +34,41 @@ def oauth2_clientsecret():
     return app.config["OAUTH2_CLIENT_SECRET"]
 
 
-def __make_token_validator__(keys: KeySet):
-    """Make a token validator function."""
-    def __validator__(token: dict):
-        for key in keys.keys:
-            try:
-                jwt.decode(token["access_token"], key)
-                return Right(token)
-            except BadSignatureError:
-                pass
-
-        return Left("INVALID-TOKEN")
-
-    return __validator__
-
-
-def __validate_token__(sess_info):
-    """Validate that the token is really from the auth server."""
-    info = sess_info
-    info["user"]["token"] = info["user"]["token"].then(__make_token_validator__(
-        KeySet([JsonWebKey.import_key(key) for key in info.get(
-            "auth_server_jwks", {}).get(
-                "jwks", {"keys": []})["keys"]])))
-    return session.save_session_info(info)
-
-
-def __update_auth_server_jwks__(sess_info):
-    """Updates the JWKs every 2 hours or so."""
-    jwks = sess_info.get("auth_server_jwks")
-    if bool(jwks):
-        last_updated = jwks.get("last-updated")
-        now = datetime.now().timestamp()
-        if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds:
-            return __validate_token__({**sess_info, "auth_server_jwks": jwks})
-
-    jwksuri = urljoin(authserver_uri(), "auth/public-jwks")
-    return __validate_token__({
-        **sess_info,
-        "auth_server_jwks": {
+def __fetch_auth_server_jwks__() -> KeySet:
+    """Fetch the JWKs from the auth server."""
+    return KeySet([
+        JsonWebKey.import_key(key)
+        for key in requests.get(
+                urljoin(authserver_uri(), "auth/public-jwks")
+        ).json()["jwks"]])
+
+
+def __update_auth_server_jwks__(jwks) -> KeySet:
+    """Update the JWKs from the servers if necessary."""
+    last_updated = jwks["last-updated"]
+    now = datetime.now().timestamp()
+    # Maybe the `two_hours` variable below can be made into a configuration
+    # variable and passed in to this function
+    two_hours = timedelta(hours=2).seconds
+    if bool(last_updated) and (now - last_updated) < two_hours:
+        return jwks["jwks"]
+
+    return session.set_auth_server_jwks(__fetch_auth_server_jwks__())
+
+
+def auth_server_jwks() -> KeySet:
+    """Fetch the auth-server JSON Web Keys information."""
+    _jwks = session.session_info().get("auth_server_jwks")
+    if bool(_jwks):
+        return __update_auth_server_jwks__({
+            "last-updated": _jwks["last-updated"],
             "jwks": KeySet([
-                JsonWebKey.import_key(key)
-                for key in requests.get(jwksuri).json()["jwks"]]).as_dict(),
-            "last-updated": datetime.now().timestamp()
-        }
+                JsonWebKey.import_key(key) for key in _jwks.get(
+                        "jwks", {"keys": []})["keys"]])
+        })
+
+    return __update_auth_server_jwks__({
+        "last-updated": (datetime.now() - timedelta(hours=3)).timestamp()
     })
 
 
@@ -111,7 +101,6 @@ def oauth2_client():
             ("client_secret_post", __json_auth__))
         return client
 
-    __update_auth_server_jwks__(session.session_info())
     return session.user_token().either(
         lambda _notok: __client__(None),
         __client__)