aboutsummaryrefslogtreecommitdiff
"""OAuth2 client utilities."""
import json
import time
import random
from datetime import datetime, timedelta
from urllib.parse import urljoin, urlparse

import requests
from flask import request, current_app as app

from pymonad.either import Left, Right, Either

from authlib.common.urls import url_decode
from authlib.jose.errors import BadSignatureError
from authlib.jose import KeySet, JsonWebKey, JsonWebToken
from authlib.integrations.requests_client import OAuth2Session

from uploader import session
import uploader.monadic_requests as mrequests

# Space-separated OAuth2 scopes requested by this client.
SCOPE = " ".join((
    "profile", "group", "role", "resource", "register-client", "user",
    "masquerade", "introspect", "migrate-data"))


def authserver_uri():
    """Look up the authorisation server's base URI in the app config.

    Reads the `AUTH_SERVER_URL` setting; a missing setting raises KeyError.
    """
    base_uri = app.config["AUTH_SERVER_URL"]
    return base_uri


def oauth2_clientid():
    """Look up this application's OAuth2 client id in the app config.

    Reads the `OAUTH2_CLIENT_ID` setting; a missing setting raises KeyError.
    """
    client_id = app.config["OAUTH2_CLIENT_ID"]
    return client_id


def oauth2_clientsecret():
    """Look up this application's OAuth2 client secret in the app config.

    Reads the `OAUTH2_CLIENT_SECRET` setting; a missing setting raises
    KeyError.
    """
    client_secret = app.config["OAUTH2_CLIENT_SECRET"]
    return client_secret


def __fetch_auth_server_jwks__(timeout: float = 30) -> KeySet:
    """Fetch the JWKs from the auth server.

    Args:
        timeout: seconds to wait for the auth server before giving up.
            Without a timeout, `requests` waits indefinitely on an
            unresponsive server, hanging the calling worker.

    Returns:
        A KeySet built from each entry of the "jwks" field in the JSON body
        returned by the auth server's `auth/public-jwks` endpoint.
    """
    response = requests.get(
        urljoin(authserver_uri(), "auth/public-jwks"), timeout=timeout)
    return KeySet([
        JsonWebKey.import_key(key) for key in response.json()["jwks"]])


def __update_auth_server_jwks__(
        jwks, max_age: timedelta = timedelta(hours=2)) -> KeySet:
    """Update the JWKs from the servers if necessary.

    Args:
        jwks: dict with a "last-updated" POSIX timestamp and, optionally, a
            "jwks" entry holding the previously cached KeySet.
        max_age: how long cached keys remain usable before they are re-fetched.
            Defaults to two hours, the previously hard-coded window.

    Returns:
        The cached "jwks" value when it is present and younger than
        `max_age`. Otherwise, fresh keys are fetched from the auth server,
        and the function returns what `session.set_auth_server_jwks` returns
        for them.
    """
    last_updated = jwks["last-updated"]
    # Only consult the cache when there is a timestamp AND cached keys; a
    # recent timestamp with no keys would otherwise raise KeyError below.
    if bool(last_updated) and "jwks" in jwks:
        age = datetime.now().timestamp() - last_updated
        # `total_seconds()` rather than `.seconds`: `.seconds` discards whole
        # days, so e.g. a one-day max_age would collapse to 0.
        if age < max_age.total_seconds():
            return jwks["jwks"]

    return session.set_auth_server_jwks(__fetch_auth_server_jwks__())


def auth_server_jwks() -> KeySet:
    """Return the auth server's JSON Web Keys, reusing the session cache.

    When the session holds no cached keys, the request is made to look stale
    so that fresh keys are fetched. Otherwise the cached key dicts are
    re-imported into a KeySet and handed over together with their timestamp.
    """
    cached = session.session_info().get("auth_server_jwks") or {}
    if not cached:
        # Nothing cached: back-date the "last update" by three hours. That is
        # older than the two-hour freshness window in
        # __update_auth_server_jwks__, so the keys get fetched.
        three_hours_ago = datetime.now() - timedelta(hours=3)
        return __update_auth_server_jwks__(
            {"last-updated": three_hours_ago.timestamp()})

    last_updated = cached["last-updated"]
    stored = cached.get("jwks", {"keys": []})
    keyset = KeySet([JsonWebKey.import_key(key) for key in stored["keys"]])
    return __update_auth_server_jwks__(
        {"last-updated": last_updated, "jwks": keyset})


def oauth2_client():
    """Build the OAuth2 client for use fetching data.

    The client is seeded with the session's user token, refreshed first if it
    has expired. When the session has no user token, the client has no token
    at all.
    """
    def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument]
        """Update the token when refreshed."""
        session.set_user_token(token)

    def __json_auth__(client, _method, uri, headers, body):
        """Send the token request body plus client credentials as JSON."""
        return (
            uri,
            {**headers, "Content-Type": "application/json"},
            json.dumps({
                **dict(url_decode(body)),
                "client_id": client.client_id,
                "client_secret": client.client_secret
            }))

    def __client__(token) -> OAuth2Session:
        """Create an OAuth2 session seeded with `token` (may be None)."""
        client = OAuth2Session(
            oauth2_clientid(),
            oauth2_clientsecret(),
            scope=SCOPE,
            # Relative path, like every other auth-server endpoint in this
            # module, so a path prefix in AUTH_SERVER_URL is kept.
            token_endpoint=urljoin(authserver_uri(), "auth/token"),
            token_endpoint_auth_method="client_secret_post",
            token=token,
            update_token=__update_token__)
        client.register_client_auth_method(
            ("client_secret_post", __json_auth__))
        return client

    def __token_expired__(token):
        """Check whether the token has expired.

        Each auth-server key is tried in turn; the first one that verifies
        the access token's signature decides. Returns False when no key
        verifies it.
        """
        jwks = auth_server_jwks()
        if not jwks:
            return False
        decoder = JsonWebToken(["RS256"])
        for jwk in jwks.keys:
            try:
                claims = decoder.decode(token["access_token"], key=jwk)
            except BadSignatureError:
                continue  # signed with a different key: try the next one
            return datetime.now().timestamp() > claims["exp"]
        return False

    def __delay__():
        """Sleep for a random 0-99 ms to stagger competing refreshers."""
        time.sleep(random.randint(0, 99) / 1000.0)

    def __refresh_token__(token):
        """Refresh the token if necessary — synchronise amongst threads."""
        if not __token_expired__(token):
            return token

        __delay__()
        if session.is_token_refreshing():
            # Another request is already refreshing: wait for it, then use
            # whatever token it left in the session (None if there is none).
            while session.is_token_refreshing():
                __delay__()
            return session.user_token().either(
                lambda _err: None, lambda _tok: _tok)

        session.toggle_token_refreshing()
        try:
            with __client__(token) as _client:
                # Presumably the authenticated request makes authlib refresh
                # the expired token (token_endpoint and update_token are set
                # in __client__) — NOTE(review): confirm against authlib.
                _client.get(urljoin(authserver_uri(), "auth/user/"))
                return _client.token
        finally:
            # Always clear the flag; otherwise waiters above spin forever.
            session.toggle_token_refreshing()

    return session.user_token().then(__refresh_token__).either(
        lambda _notok: __client__(None),
        __client__)


def user_logged_in():
    """Tell whether the current session belongs to a logged-in user.

    The session user's "logged_in" flag must be truthy and its stored token
    must be a Right value. When the flag is falsy, the flag value itself is
    returned.
    """
    user_info = session.session_info()["user"]
    logged_in = user_info["logged_in"]
    if not logged_in:
        return logged_in
    return user_info["token"].is_right()


def authserver_authorise_uri():
    """Build up the authorisation URI.

    Returns:
        The auth server's `auth/authorise` endpoint carrying an
        authorisation-code request (`response_type=code`) for this client.
        It redirects back to `oauth2/code` on the scheme and host the current
        request arrived on.
    """
    from urllib.parse import urlencode  # pylint: disable=[import-outside-toplevel]
    req_baseurl = urlparse(request.base_url, scheme=request.scheme)
    host_uri = f"{req_baseurl.scheme}://{req_baseurl.netloc}/"
    # URL-encode the parameter values so reserved characters (`&`, `#`, `?`,
    # ...) cannot corrupt or inject into the query string.
    query = urlencode({
        "response_type": "code",
        "client_id": oauth2_clientid(),
        "redirect_uri": urljoin(host_uri, "oauth2/code")
    })
    return urljoin(authserver_uri(), f"auth/authorise?{query}")


def __no_token__(_err) -> Left:
    """Build the failure result for a request attempted without a token.

    The error value is ignored. The returned Left wraps a synthetic
    `requests` Response with status 400 and a JSON error body, so callers can
    handle it like any other failed auth-server response.
    """
    error_body = {
        "error": "AuthenticationError",
        "error-description": (
            "You need to authenticate to access requested information.")
    }
    response = requests.models.Response()
    response.status_code = 400
    response._content = json.dumps(error_body).encode("utf-8")#pylint: disable=[protected-access]
    return Left(response)


def oauth2_get(url, **kwargs) -> Either:
    """Do a get request to the authentication/authorisation server.

    `url` is resolved against the auth-server base URI. The request goes out
    through a fresh OAuth2 client with a JSON content type merged into any
    caller-supplied headers. Other keyword arguments go to the client's `get`.

    Returns:
        Right(decoded JSON body) for a success status.
        Left(response) for any other status.
        Left(exception) if sending the request or decoding the body raised.
        The `__no_token__` result when the session has no user token.
    """
    def __get__(_token) -> Either:
        _uri = urljoin(authserver_uri(), url)
        try:
            # A new session is built per call; close it when done so its
            # connection pool is not leaked.
            with oauth2_client() as client:
                resp = client.get(
                    _uri,
                    **{
                        **kwargs,
                        "headers": {
                            **kwargs.get("headers", {}),
                            "Content-Type": "application/json"
                        }
                    })
                if resp.status_code in mrequests.SUCCESS_CODES:
                    return Right(resp.json())
                return Left(resp)
        except Exception as exc:#pylint: disable=[broad-except]
            app.logger.error("Error retrieving data from auth server: (GET %s)",
                             _uri,
                             exc_info=True)
            return Left(exc)
    return session.user_token().either(__no_token__, __get__)


def oauth2_post(url, data=None, json=None, **kwargs):#pylint: disable=[redefined-outer-name]
    """Do a POST request to the authentication/authorisation server.

    `url` is resolved against the auth-server base URI. `data` and `json` are
    merged together with the client id and secret. The merged payload is
    passed as `data` when `data` is truthy and as `json` when `json` is
    truthy; in the latter case a JSON content type is added to the headers.
    Other keyword arguments go to the client's `post`.

    Returns:
        Right(decoded JSON body) for a success status.
        Left(response) for any other status.
        Left(exception) if building or sending the request, or decoding the
        body, raised.
        The `__no_token__` result when the session has no user token.
    """
    def __post__(_token) -> Either:
        _uri = urljoin(authserver_uri(), url)
        _headers = ({
                        **kwargs.get("headers", {}),
                        "Content-Type": "application/json"
                    }
                    if bool(json) else kwargs.get("headers", {}))
        try:
            request_data = {
                **(data or {}),
                **(json or {}),
                "client_id": oauth2_clientid(),
                "client_secret": oauth2_clientsecret()
            }
            # A new session is built per call; close it when done so its
            # connection pool is not leaked.
            with oauth2_client() as client:
                resp = client.post(
                    _uri,
                    data=(request_data if bool(data) else None),
                    json=(request_data if bool(json) else None),
                    **{**kwargs, "headers": _headers})
                if resp.status_code in mrequests.SUCCESS_CODES:
                    return Right(resp.json())
                return Left(resp)
        except Exception as exc:#pylint: disable=[broad-except]
            app.logger.error("Error retrieving data from auth server: (POST %s)",
                             _uri,
                             exc_info=True)
            return Left(exc)
    return session.user_token().either(__no_token__, __post__)