"""OAuth2 client utilities.""" from datetime import datetime, timedelta from urllib.parse import urljoin, urlparse import requests from flask import request, current_app as app from pymonad.either import Left, Right from authlib.jose import jwt from authlib.jose import KeySet, JsonWebKey from authlib.jose.errors import BadSignatureError from authlib.integrations.requests_client import OAuth2Session from uploader import session SCOPE = ("profile group role resource register-client user masquerade " "introspect migrate-data") def authserver_uri(): """Return URI to authorisation server.""" return app.config["AUTH_SERVER_URL"] def oauth2_clientid(): """Return the client id.""" return app.config["OAUTH2_CLIENT_ID"] def oauth2_clientsecret(): """Return the client secret.""" return app.config["OAUTH2_CLIENT_SECRET"] def __make_token_validator__(keys: KeySet): """Make a token validator function.""" def __validator__(token: str): try: jwt.decode(token, keys) return Right(token) except BadSignatureError: return Left("INVALID-TOKEN") return __validator__ def __validate_token__(sess_info): """Validate that the token is really from the auth server.""" info = __update_auth_server_jwks__(sess_info) info["user"]["token"] = info["user"]["token"].then(__make_token_validator__( KeySet(JsonWebKey(key) for key in info.get( "auth_server_jwks", {}).get( "jwks", {"keys": []})["keys"]))) return session.save_session_info(info) def __update_auth_server_jwks__(sess_info): """Updates the JWKs every 2 hours or so.""" jwks = sess_info.get("auth_server_jwks") if bool(jwks): last_updated = jwks.get("last-updated") now = datetime.now().timestamp() if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds: return __validate_token__({**sess_info, "auth_server_jwks": jwks}) jwksuri = urljoin(authserver_uri(), "auth/public-jwks") return __validate_token__({ **sess_info, "auth_server_jwks": { "jwks": KeySet([ JsonWebKey.import_key(key) for key in requests.get(jwksuri).json()["jwks"]]).as_dict(), "last-updated": datetime.now().timestamp() } }) def oauth2_client(): """Build the OAuth2 client for use fetching data.""" def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument] """Update the token when refreshed.""" app.logger.debug(f"IN `{__name__}`:\n\trefresh_token: {refresh_token}" f"\n\taccess_token: {access_token}\n\ttoken: {token}") session.set_user_token(token) def __client__(token) -> OAuth2Session: client = OAuth2Session( oauth2_clientid(), oauth2_clientsecret(), scope=SCOPE, token_endpoint=urljoin(authserver_uri(), "/auth/token"), token_endpoint_auth_method="client_secret_post", token=token, update_token=__update_token__) return client __update_auth_server_jwks__(session.session_info()) return session.user_token().either( lambda _notok: __client__(None), __client__) def user_logged_in(): """Check whether the user has logged in.""" suser = session.session_info()["user"] # return suser["logged_in"] and suser["token"].is_right() return False def authserver_authorise_uri(): req_baseurl = urlparse(request.base_url, scheme=request.scheme) host_uri = f"{req_baseurl.scheme}://{req_baseurl.netloc}/" return urljoin( authserver_uri(), "auth/authorise?response_type=code" f"&client_id={oauth2_clientid()}" f"&redirect_uri={urljoin(host_uri, 'oauth2/code')}")