"""Authorisation utilities."""
import logging
from functools import wraps
from typing import Callable

from flask import flash, redirect
from pymonad.either import Left, Right, Either

from authlib.jose import KeySet, JsonWebToken
from authlib.jose.errors import BadSignatureError, JoseError

from uploader import session
from uploader.oauth2.client import auth_server_jwks


def require_login(function: Callable) -> Callable:
    """Check that the user is logged in before executing `function`.

    If `session.user_token()` yields a Left (no usable token), the session
    info is cleared, a flash message is queued and the user is redirected
    to "/". Otherwise `function` is called with its original arguments.
    """
    @wraps(function)
    def __is_session_valid__(*args, **kwargs):
        """Run `function` only if the session holds a user token."""
        def __clear_session__(_no_token):
            session.clear_session_info()
            flash("You need to be logged in.", "alert-danger")
            return redirect("/")

        return session.user_token().either(
            __clear_session__,
            lambda token: function(*args, **kwargs))
    return __is_session_valid__


def __validate_token__(jwks: KeySet, token: dict) -> Either:
    """Check that a token is signed by a key from the authorisation server.

    Parameters:
        jwks: key set from the authorisation server; every key in
            `jwks.keys` is tried in turn.
        token: token dict; its "access_token" entry is the JWT checked.

    Returns:
        `Right(token)` as soon as one key verifies the access token, or
        `Left({"token": token})` if no key does (including an empty key set).

    Raises:
        KeyError: if `token` has no "access_token" entry and `jwks` has at
            least one key.
    """
    # Fixes CVE-2016-10555 by only accepting RS256-signed tokens. See
    # https://docs.authlib.org/en/latest/jose/jwt.html
    # The decoder is built once, as it does not depend on the key.
    jwt = JsonWebToken(["RS256"])
    for key in jwks.keys:
        try:
            jwt.decode(token["access_token"], key)
        except (BadSignatureError, JoseError):
            # BadSignatureError: this key did not sign the token.
            # JoseError: any other authlib JOSE failure (e.g. malformed
            # token or a disallowed algorithm). Previously these escaped
            # as raw exceptions; now they fall through to a Left, so
            # callers get the uniform invalid-token handling.
            continue
        return Right(token)
    return Left({"token": token})


def require_token(func: Callable) -> Callable:
    """
    Wrap functions that require the user be authorised to perform the
    operations that the functions in question provide.

    The wrapped function is called with the validated token dict passed as
    the `token` keyword argument (overriding any caller-supplied `token`).
    If there is no session token, or it cannot be verified against the
    authorisation server's keys, a generic `Exception` is raised instead.
    """
    def __invalid_token__(failure):
        # NOTE(review): when signature validation fails, `failure` is
        # {"token": token}, so this debug line writes the raw access token
        # to the log. Consider redacting it.
        logging.debug("==========> Failure log: %s", failure)
        raise Exception(
            "You attempted to access a feature of the system that requires "
            "authorisation. Unfortunately, we could not verify you have the "
            "appropriate authorisation to perform the action you requested. "
            "You might need to log in, or if you already are logged in, you "
            "need to log out, then log back in to get a newer token/session.")

    @wraps(func)
    def __wrapper__(*args, **kwargs):
        # The key set is only fetched once a session token exists; the
        # token and keys are then bundled to call `__validate_token__`.
        return session.user_token().then(lambda tok: {
            "jwks": auth_server_jwks(),
            "token": tok
        }).then(lambda vals: __validate_token__(**vals)).either(
            __invalid_token__,
            lambda tok: func(*args, **{**kwargs, "token": tok}))
    return __wrapper__