about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-08-01 12:19:34 -0500
committerAlexander_Kabui2024-08-28 15:02:46 +0300
commita9a8ef79a10c58a514d5aac0b2b1c9000a57f9f8 (patch)
tree497fa5f5816510a9de2eecedb3ff773143e6ac08
parent9f4fa60a843ca764116286c77057824638b5c8d0 (diff)
downloadgenenetwork2-a9a8ef79a10c58a514d5aac0b2b1c9000a57f9f8.tar.gz
Use JWKs from auth server public endpoint
* Fetch keys from auth server
* Validate token is signed with one of the keys from server
* Ensure refreshing of token is still synchronised
-rw-r--r--gn2/wqflask/app_errors.py2
-rw-r--r--gn2/wqflask/oauth2/client.py95
-rw-r--r--gn2/wqflask/oauth2/session.py8
3 files changed, 84 insertions, 21 deletions
diff --git a/gn2/wqflask/app_errors.py b/gn2/wqflask/app_errors.py
index 7c07fde6..bafe773b 100644
--- a/gn2/wqflask/app_errors.py
+++ b/gn2/wqflask/app_errors.py
@@ -51,7 +51,7 @@ def handle_invalid_token_error(exc: InvalidTokenError):
     flash("An invalid session token was detected. "
           "You have been logged out of the system.",
           "alert-danger")
-    current_app.logger.error("Invalit token detected. %s", request.url, exc_info=True)
+    current_app.logger.error("Invalid token detected. %s", request.url, exc_info=True)
     session.clear_session_info()
     return redirect("/")
 
diff --git a/gn2/wqflask/oauth2/client.py b/gn2/wqflask/oauth2/client.py
index 0d4615e8..6f137f52 100644
--- a/gn2/wqflask/oauth2/client.py
+++ b/gn2/wqflask/oauth2/client.py
@@ -5,10 +5,12 @@ import random
 import requests
 from typing import Optional
 from urllib.parse import urljoin
+from datetime import datetime, timedelta
 
 from flask import current_app as app
 from pymonad.either import Left, Right, Either
-from authlib.jose import jwt
+from authlib.jose import KeySet, JsonWebKey, JsonWebToken
+from authlib.jose.errors import BadSignatureError
 from authlib.integrations.requests_client import OAuth2Session
 
 from gn2.wqflask.oauth2 import session
@@ -36,30 +38,96 @@ def user_logged_in():
     return suser["logged_in"] and suser["token"].is_right()
 
 
+def __make_token_validator__(keys: KeySet):
+    """Make a token validator function."""
+    def __validator__(token: dict):
+        for key in keys.keys:
+            try:
+                # Fixes CVE-2016-10555. See
+                # https://docs.authlib.org/en/latest/jose/jwt.html
+                jwt = JsonWebToken(["RS256"])
+                jwt.decode(token["access_token"], key)
+                return Right(token)
+            except BadSignatureError:
+                pass
+
+        return Left("INVALID-TOKEN")
+
+    return __validator__
+
+
+def auth_server_jwks() -> Optional[KeySet]:
+    """Fetch the auth-server JSON Web Keys information."""
+    _jwks = session.session_info().get("auth_server_jwks")
+    if bool(_jwks):
+        return {
+            "last-updated": _jwks["last-updated"],
+            "jwks": KeySet([
+                JsonWebKey.import_key(key) for key in _jwks.get(
+                    "auth_server_jwks", {}).get(
+                        "jwks", {"keys": []})["keys"]])}
+
+
+def __validate_token__(keys):
+    """Validate that the token is really from the auth server."""
+    def __index__(_sess):
+        return _sess
+    return session.user_token().then(__make_token_validator__(keys)).then(
+        session.set_user_token).either(__index__, __index__)
+
+
+def __update_auth_server_jwks__():
+    """Updates the JWKs every 2 hours or so."""
+    jwks = auth_server_jwks()
+    if bool(jwks):
+        last_updated = jwks.get("last-updated")
+        now = datetime.now().timestamp()
+        if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds:
+            return __validate_token__(jwks["jwks"])
+
+    jwksuri = urljoin(authserver_uri(), "auth/public-jwks")
+    jwks = KeySet([
+        JsonWebKey.import_key(key)
+        for key in requests.get(jwksuri).json()["jwks"]])
+    return __validate_token__(jwks)
+
+
+def is_token_expired(token):
+    """Check whether the token has expired."""
+    __update_auth_server_jwks__()
+    jwks = auth_server_jwks()
+    if bool(jwks):
+        for jwk in jwks["jwks"].keys:
+            try:
+                jwt = JsonWebToken(["RS256"]).decode(
+                    token["access_token"], key=jwk)
+                return datetime.now().timestamp() > jwt["exp"]
+            except BadSignatureError as _bse:
+                pass
+
+    return False
+
+
 def oauth2_client():
     def __update_token__(token, refresh_token=None, access_token=None):
         """Update the token when refreshed."""
         session.set_user_token(token)
         return token
 
-    def __validate_token__(token):
-        _jwt = jwt.decode(token["access_token"],
-                          app.config["AUTH_SERVER_SSL_PUBLIC_KEY"])
-        return token
-
     def __delay__():
         """Do a tiny delay."""
         time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100))))
 
     def __refresh_token__(token):
         """Synchronise token refresh."""
-        if session.is_token_expired():
+        if is_token_expired(token):
             __delay__()
             if session.is_token_refreshing():
                 while session.is_token_refreshing():
                     __delay__()
-                    _token = session.user_token().either(None, lambda _tok: _tok)
-                    return _token
+
+                _token = session.user_token().either(None, lambda _tok: _tok)
+                return _token
 
             session.toggle_token_refreshing()
             _client = __client__(token)
@@ -79,10 +147,11 @@ def oauth2_client():
             token=token,
             update_token=__update_token__)
         return client
-    return session.user_token().then(__validate_token__).then(
-        __refresh_token__).either(
-            lambda _notok: __client__(None),
-            lambda token: __client__(token))
+
+    __update_auth_server_jwks__()
+    return session.user_token().then(__refresh_token__).either(
+        lambda _notok: __client__(None),
+        lambda token: __client__(token))
 
 def __no_token__(_err) -> Left:
     """Handle situation where request is attempted with no token."""
diff --git a/gn2/wqflask/oauth2/session.py b/gn2/wqflask/oauth2/session.py
index 92181ccf..b91534b0 100644
--- a/gn2/wqflask/oauth2/session.py
+++ b/gn2/wqflask/oauth2/session.py
@@ -23,6 +23,7 @@ class SessionInfo(TypedDict):
     ip_addr: str
     masquerade: Optional[UserDetails]
     refreshing_token: bool
+    auth_server_jwks: Optional[dict[str, Any]]
 
 __SESSION_KEY__ = "GN::2::session_info" # Do not use this outside this module!!
 
@@ -114,13 +115,6 @@ def toggle_token_refreshing():
         "token_refreshing": not _session.get("token_refreshing", False)})
 
 
-def is_token_expired():
-    """Check whether the token is expired."""
-    return user_token().either(
-        lambda _no_token: False,
-        lambda token: datetime.now().timestamp() > token["expires_at"])
-
-
 def is_token_refreshing():
     """Returns whether the token is being refreshed or not."""
     return session_info().get("token_refreshing", False)