diff options
Diffstat (limited to 'uploader/oauth2/client.py')
-rw-r--r-- | uploader/oauth2/client.py | 101 |
1 files changed, 101 insertions, 0 deletions
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py new file mode 100644 index 0000000..6e101ae --- /dev/null +++ b/uploader/oauth2/client.py @@ -0,0 +1,101 @@ +"""OAuth2 client utilities.""" +from urllib.parse import urljoin +from datetime import datetime, timedelta + +import requests +from flask import current_app as app + +from pymonad.either import Left, Right + +from authlib.jose import jwt +from authlib.jose import KeySet, JsonWebKey +from authlib.jose.errors import BadSignatureError +from authlib.integrations.requests_client import OAuth2Session + +from uploader import session + +SCOPE = ("profile group role resource register-client user masquerade " + "introspect migrate-data") + + +def authserver_uri(): + """Return URI to authorisation server.""" + return app.config["AUTH_SERVER_URL"] + + +def oauth2_clientid(): + """Return the client id.""" + return app.config["OAUTH2_CLIENT_ID"] + + +def oauth2_clientsecret(): + """Return the client secret.""" + return app.config["OAUTH2_CLIENT_SECRET"] + + +def __make_token_validator__(keys: KeySet): + """Make a token validator function.""" + def __validator__(token: str): + try: + jwt.decode(token, keys) + return Right(token) + except BadSignatureError: + return Left("INVALID-TOKEN") + + return __validator__ + + +def __validate_token__(sess_info): + """Validate that the token is really from the auth server.""" + info = __update_auth_server_jwks__(sess_info) + info["user"]["token"] = info["user"]["token"].then(__make_token_validator__( + KeySet(JsonWebKey(key) for key in info.get( + "auth_server_jwks", {}).get( + "jwks", {"keys": []})["keys"]))) + return session.save_session_info(info) + + +def __update_auth_server_jwks__(sess_info): + """Updates the JWKs every 2 hours or so.""" + jwks = sess_info.get("auth_server_jwks") + if bool(jwks): + last_updated = jwks.get("last-updated") + now = datetime.now().timestamp() + if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds: + return __validate_token__({**sess_info, "auth_server_jwks": jwks}) + + jwksuri = urljoin(authserver_uri(), "auth/public-jwks") + return __validate_token__({ + **sess_info, + "auth_server_jwks": { + "jwks": KeySet([ + JsonWebKey.import_key(key) + for key in requests.get(jwksuri).json()["jwks"]]).as_dict(), + "last-updated": datetime.now().timestamp() + } + }) + + +def oauth2_client(): + """Build the OAuth2 client for use fetching data.""" + def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument] + """Update the token when refreshed.""" + app.logger.debug(f"IN `{__name__}`:\n\trefresh_token: {refresh_token}" + f"\n\taccess_token: {access_token}\n\ttoken: {token}") + session.set_user_token(token) + + def __client__(token) -> OAuth2Session: + client = OAuth2Session( + oauth2_clientid(), + oauth2_clientsecret(), + scope=SCOPE, + token_endpoint=urljoin(authserver_uri(), "/auth/token"), + token_endpoint_auth_method="client_secret_post", + token=token, + update_token=__update_token__) + return client + + __update_auth_server_jwks__(session.session_info()) + return session.user_token().either( + lambda _notok: __client__(None), + __client__) |