diff options
author | Frederick Muriuki Muriithi | 2024-08-01 12:19:34 -0500 |
---|---|---|
committer | Alexander_Kabui | 2024-08-28 15:02:46 +0300 |
commit | a9a8ef79a10c58a514d5aac0b2b1c9000a57f9f8 (patch) | |
tree | 497fa5f5816510a9de2eecedb3ff773143e6ac08 /gn2/wqflask/oauth2/client.py | |
parent | 9f4fa60a843ca764116286c77057824638b5c8d0 (diff) | |
download | genenetwork2-a9a8ef79a10c58a514d5aac0b2b1c9000a57f9f8.tar.gz |
Use JWKs from auth server public endpoint
* Fetch keys from auth server
* Validate token is signed with one of the keys from server
* Ensure refreshing of token is still synchronised
Diffstat (limited to 'gn2/wqflask/oauth2/client.py')
-rw-r--r-- | gn2/wqflask/oauth2/client.py | 95 |
1 files changed, 82 insertions, 13 deletions
diff --git a/gn2/wqflask/oauth2/client.py b/gn2/wqflask/oauth2/client.py index 0d4615e8..6f137f52 100644 --- a/gn2/wqflask/oauth2/client.py +++ b/gn2/wqflask/oauth2/client.py @@ -5,10 +5,12 @@ import random import requests from typing import Optional from urllib.parse import urljoin +from datetime import datetime, timedelta from flask import current_app as app from pymonad.either import Left, Right, Either -from authlib.jose import jwt +from authlib.jose import KeySet, JsonWebKey, JsonWebToken +from authlib.jose.errors import BadSignatureError from authlib.integrations.requests_client import OAuth2Session from gn2.wqflask.oauth2 import session @@ -36,30 +38,96 @@ def user_logged_in(): return suser["logged_in"] and suser["token"].is_right() +def __make_token_validator__(keys: KeySet): + """Make a token validator function.""" + def __validator__(token: dict): + for key in keys.keys: + try: + # Fixes CVE-2016-10555. See + # https://docs.authlib.org/en/latest/jose/jwt.html + jwt = JsonWebToken(["RS256"]) + jwt.decode(token["access_token"], key) + return Right(token) + except BadSignatureError: + pass + + return Left("INVALID-TOKEN") + + return __validator__ + + +def auth_server_jwks() -> Optional[KeySet]: + """Fetch the auth-server JSON Web Keys information.""" + _jwks = session.session_info().get("auth_server_jwks") + if bool(_jwks): + return { + "last-updated": _jwks["last-updated"], + "jwks": KeySet([ + JsonWebKey.import_key(key) for key in _jwks.get( + "auth_server_jwks", {}).get( + "jwks", {"keys": []})["keys"]])} + + +def __validate_token__(keys): + """Validate that the token is really from the auth server.""" + def __index__(_sess): + return _sess + return session.user_token().then(__make_token_validator__(keys)).then( + session.set_user_token).either(__index__, __index__) + + +def __update_auth_server_jwks__(): + """Updates the JWKs every 2 hours or so.""" + jwks = auth_server_jwks() + if bool(jwks): + last_updated = jwks.get("last-updated") + now = datetime.now().timestamp() + if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds: + return __validate_token__(jwks["jwks"]) + + jwksuri = urljoin(authserver_uri(), "auth/public-jwks") + jwks = KeySet([ + JsonWebKey.import_key(key) + for key in requests.get(jwksuri).json()["jwks"]]) + return __validate_token__(jwks) + + +def is_token_expired(token): + """Check whether the token has expired.""" + __update_auth_server_jwks__() + jwks = auth_server_jwks() + if bool(jwks): + for jwk in jwks["jwks"].keys: + try: + jwt = JsonWebToken(["RS256"]).decode( + token["access_token"], key=jwk) + return datetime.now().timestamp() > jwt["exp"] + except BadSignatureError as _bse: + pass + + return False + + def oauth2_client(): def __update_token__(token, refresh_token=None, access_token=None): """Update the token when refreshed.""" session.set_user_token(token) return token - def __validate_token__(token): - _jwt = jwt.decode(token["access_token"], - app.config["AUTH_SERVER_SSL_PUBLIC_KEY"]) - return token - def __delay__(): """Do a tiny delay.""" time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100)))) def __refresh_token__(token): """Synchronise token refresh.""" - if session.is_token_expired(): + if is_token_expired(token): __delay__() if session.is_token_refreshing(): while session.is_token_refreshing(): __delay__() - _token = session.user_token().either(None, lambda _tok: _tok) - return _token + + _token = session.user_token().either(None, lambda _tok: _tok) + return _token session.toggle_token_refreshing() _client = __client__(token) @@ -79,10 +147,11 @@ def oauth2_client(): token=token, update_token=__update_token__) return client - return session.user_token().then(__validate_token__).then( - __refresh_token__).either( - lambda _notok: __client__(None), - lambda token: __client__(token)) + + __update_auth_server_jwks__() + return session.user_token().then(__refresh_token__).either( + lambda _notok: __client__(None), + lambda token: __client__(token)) def __no_token__(_err) -> Left: """Handle situation where request is attempted with no token.""" |