From 20933e55de7927063dd159d116b468b4724a19a8 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Thu, 1 Aug 2024 12:19:34 -0500 Subject: Use JWKs from auth server public endpoint * Fetch keys from auth server * Validate token is signed with one of the keys from server * Ensure refreshing of token is still synchronised --- gn2/wqflask/app_errors.py | 2 +- gn2/wqflask/oauth2/client.py | 95 +++++++++++++++++++++++++++++++++++++------ gn2/wqflask/oauth2/session.py | 8 +--- 3 files changed, 84 insertions(+), 21 deletions(-) diff --git a/gn2/wqflask/app_errors.py b/gn2/wqflask/app_errors.py index 7c07fde6..bafe773b 100644 --- a/gn2/wqflask/app_errors.py +++ b/gn2/wqflask/app_errors.py @@ -51,7 +51,7 @@ def handle_invalid_token_error(exc: InvalidTokenError): flash("An invalid session token was detected. " "You have been logged out of the system.", "alert-danger") - current_app.logger.error("Invalit token detected. %s", request.url, exc_info=True) + current_app.logger.error("Invalid token detected. %s", request.url, exc_info=True) session.clear_session_info() return redirect("/") diff --git a/gn2/wqflask/oauth2/client.py b/gn2/wqflask/oauth2/client.py index 0d4615e8..6f137f52 100644 --- a/gn2/wqflask/oauth2/client.py +++ b/gn2/wqflask/oauth2/client.py @@ -5,10 +5,12 @@ import random import requests from typing import Optional from urllib.parse import urljoin +from datetime import datetime, timedelta from flask import current_app as app from pymonad.either import Left, Right, Either -from authlib.jose import jwt +from authlib.jose import KeySet, JsonWebKey, JsonWebToken +from authlib.jose.errors import BadSignatureError from authlib.integrations.requests_client import OAuth2Session from gn2.wqflask.oauth2 import session @@ -36,30 +38,96 @@ def user_logged_in(): return suser["logged_in"] and suser["token"].is_right() +def __make_token_validator__(keys: KeySet): + """Make a token validator function.""" + def __validator__(token: dict): + for key in keys.keys: + try: + # Fixes CVE-2016-10555. See + # https://docs.authlib.org/en/latest/jose/jwt.html + jwt = JsonWebToken(["RS256"]) + jwt.decode(token["access_token"], key) + return Right(token) + except BadSignatureError: + pass + + return Left("INVALID-TOKEN") + + return __validator__ + + +def auth_server_jwks() -> Optional[KeySet]: + """Fetch the auth-server JSON Web Keys information.""" + _jwks = session.session_info().get("auth_server_jwks") + if bool(_jwks): + return { + "last-updated": _jwks["last-updated"], + "jwks": KeySet([ + JsonWebKey.import_key(key) for key in _jwks.get( + "auth_server_jwks", {}).get( + "jwks", {"keys": []})["keys"]])} + + +def __validate_token__(keys): + """Validate that the token is really from the auth server.""" + def __index__(_sess): + return _sess + return session.user_token().then(__make_token_validator__(keys)).then( + session.set_user_token).either(__index__, __index__) + + +def __update_auth_server_jwks__(): + """Updates the JWKs every 2 hours or so.""" + jwks = auth_server_jwks() + if bool(jwks): + last_updated = jwks.get("last-updated") + now = datetime.now().timestamp() + if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds: + return __validate_token__(jwks["jwks"]) + + jwksuri = urljoin(authserver_uri(), "auth/public-jwks") + jwks = KeySet([ + JsonWebKey.import_key(key) + for key in requests.get(jwksuri).json()["jwks"]]) + return __validate_token__(jwks) + + +def is_token_expired(token): + """Check whether the token has expired.""" + __update_auth_server_jwks__() + jwks = auth_server_jwks() + if bool(jwks): + for jwk in jwks["jwks"].keys: + try: + jwt = JsonWebToken(["RS256"]).decode( + token["access_token"], key=jwk) + return datetime.now().timestamp() > jwt["exp"] + except BadSignatureError as _bse: + pass + + return False + + def oauth2_client(): def __update_token__(token, refresh_token=None, access_token=None): """Update the token when refreshed.""" session.set_user_token(token) return token - def __validate_token__(token): - _jwt = jwt.decode(token["access_token"], - app.config["AUTH_SERVER_SSL_PUBLIC_KEY"]) - return token - def __delay__(): """Do a tiny delay.""" time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100)))) def __refresh_token__(token): """Synchronise token refresh.""" - if session.is_token_expired(): + if is_token_expired(token): __delay__() if session.is_token_refreshing(): while session.is_token_refreshing(): __delay__() - _token = session.user_token().either(None, lambda _tok: _tok) - return _token + + _token = session.user_token().either(None, lambda _tok: _tok) + return _token session.toggle_token_refreshing() _client = __client__(token) @@ -79,10 +147,11 @@ def oauth2_client(): token=token, update_token=__update_token__) return client - return session.user_token().then(__validate_token__).then( - __refresh_token__).either( - lambda _notok: __client__(None), - lambda token: __client__(token)) + + __update_auth_server_jwks__() + return session.user_token().then(__refresh_token__).either( + lambda _notok: __client__(None), + lambda token: __client__(token)) def __no_token__(_err) -> Left: """Handle situation where request is attempted with no token.""" diff --git a/gn2/wqflask/oauth2/session.py b/gn2/wqflask/oauth2/session.py index 92181ccf..b91534b0 100644 --- a/gn2/wqflask/oauth2/session.py +++ b/gn2/wqflask/oauth2/session.py @@ -23,6 +23,7 @@ class SessionInfo(TypedDict): ip_addr: str masquerade: Optional[UserDetails] refreshing_token: bool + auth_server_jwks: Optional[dict[str, Any]] __SESSION_KEY__ = "GN::2::session_info" # Do not use this outside this module!! @@ -114,13 +115,6 @@ def toggle_token_refreshing(): "token_refreshing": not _session.get("token_refreshing", False)}) -def is_token_expired(): - """Check whether the token is expired.""" - return user_token().either( - lambda _no_token: False, - lambda token: datetime.now().timestamp() > token["expires_at"]) - - def is_token_refreshing(): """Returns whether the token is being refreshed or not.""" return session_info().get("token_refreshing", False) -- cgit v1.2.3