diff options
-rw-r--r-- | uploader/oauth2/client.py | 79 |
1 files changed, 34 insertions, 45 deletions
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py index a3e4ba3..8ad8f7c 100644 --- a/uploader/oauth2/client.py +++ b/uploader/oauth2/client.py @@ -8,10 +8,8 @@ from flask import request, current_app as app from pymonad.either import Left, Right, Either -from authlib.jose import jwt from authlib.common.urls import url_decode from authlib.jose import KeySet, JsonWebKey -from authlib.jose.errors import BadSignatureError from authlib.integrations.requests_client import OAuth2Session from uploader import session @@ -36,49 +34,41 @@ def oauth2_clientsecret(): return app.config["OAUTH2_CLIENT_SECRET"] -def __make_token_validator__(keys: KeySet): - """Make a token validator function.""" - def __validator__(token: dict): - for key in keys.keys: - try: - jwt.decode(token["access_token"], key) - return Right(token) - except BadSignatureError: - pass - - return Left("INVALID-TOKEN") - - return __validator__ - - -def __validate_token__(sess_info): - """Validate that the token is really from the auth server.""" - info = sess_info - info["user"]["token"] = info["user"]["token"].then(__make_token_validator__( - KeySet([JsonWebKey.import_key(key) for key in info.get( - "auth_server_jwks", {}).get( - "jwks", {"keys": []})["keys"]]))) - return session.save_session_info(info) - - -def __update_auth_server_jwks__(sess_info): - """Updates the JWKs every 2 hours or so.""" - jwks = sess_info.get("auth_server_jwks") - if bool(jwks): - last_updated = jwks.get("last-updated") - now = datetime.now().timestamp() - if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds: - return __validate_token__({**sess_info, "auth_server_jwks": jwks}) - - jwksuri = urljoin(authserver_uri(), "auth/public-jwks") - return __validate_token__({ - **sess_info, - "auth_server_jwks": { +def __fetch_auth_server_jwks__() -> KeySet: + """Fetch the JWKs from the auth server.""" + return KeySet([ + JsonWebKey.import_key(key) + for key in requests.get( + urljoin(authserver_uri(), "auth/public-jwks") + ).json()["jwks"]]) + + +def __update_auth_server_jwks__(jwks) -> KeySet: + """Update the JWKs from the servers if necessary.""" + last_updated = jwks["last-updated"] + now = datetime.now().timestamp() + # Maybe the `two_hours` variable below can be made into a configuration + # variable and passed in to this function + two_hours = timedelta(hours=2).seconds + if bool(last_updated) and (now - last_updated) < two_hours: + return jwks["jwks"] + + return session.set_auth_server_jwks(__fetch_auth_server_jwks__()) + + +def auth_server_jwks() -> KeySet: + """Fetch the auth-server JSON Web Keys information.""" + _jwks = session.session_info().get("auth_server_jwks") + if bool(_jwks): + return __update_auth_server_jwks__({ + "last-updated": _jwks["last-updated"], "jwks": KeySet([ - JsonWebKey.import_key(key) - for key in requests.get(jwksuri).json()["jwks"]]).as_dict(), - "last-updated": datetime.now().timestamp() - } + JsonWebKey.import_key(key) for key in _jwks.get( + "jwks", {"keys": []})["keys"]]) + }) + + return __update_auth_server_jwks__({ + "last-updated": (datetime.now() - timedelta(hours=3)).timestamp() }) @@ -111,7 +101,6 @@ def oauth2_client(): ("client_secret_post", __json_auth__)) return client - __update_auth_server_jwks__(session.session_info()) return session.user_token().either( lambda _notok: __client__(None), __client__) |