aboutsummaryrefslogtreecommitdiff
path: root/uploader/oauth2/client.py
blob: 6e101ae38c7b207fabca4c6172951371dfdedd36 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""OAuth2 client utilities."""
from urllib.parse import urljoin
from datetime import datetime, timedelta

import requests
from flask import current_app as app

from pymonad.either import Left, Right

from authlib.jose import jwt
from authlib.jose import KeySet, JsonWebKey
from authlib.jose.errors import BadSignatureError
from authlib.integrations.requests_client import OAuth2Session

from uploader import session

# Space-separated list of OAuth2 scopes this client passes to OAuth2Session.
SCOPE = " ".join((
    "profile", "group", "role", "resource", "register-client", "user",
    "masquerade", "introspect", "migrate-data"))


def authserver_uri():
    """Read the authorisation server's base URL from the Flask app config."""
    settings = app.config
    return settings["AUTH_SERVER_URL"]


def oauth2_clientid():
    """Look up this application's OAuth2 client ID in the Flask config."""
    settings = app.config
    return settings["OAUTH2_CLIENT_ID"]


def oauth2_clientsecret():
    """Look up this application's OAuth2 client secret in the Flask config."""
    settings = app.config
    return settings["OAUTH2_CLIENT_SECRET"]


def __make_token_validator__(keys: KeySet):
    """Return a closure that checks a token's signature against `keys`.

    The closure hands the token to `jwt.decode` with the given key set and
    yields `Right(token)` when decoding succeeds, or `Left("INVALID-TOKEN")`
    when `jwt.decode` raises `BadSignatureError`. Any other exception from
    `jwt.decode` propagates to the caller.
    """
    def __check_signature__(token: str):
        try:
            jwt.decode(token, keys)
        except BadSignatureError:
            return Left("INVALID-TOKEN")
        return Right(token)

    return __check_signature__


def __validate_token__(sess_info):
    """Validate that the token is really from the auth server.

    Builds a `KeySet` from the JWKs cached under
    `sess_info["auth_server_jwks"]["jwks"]["keys"]` (an empty set when none
    are cached). It then runs the user's token (`sess_info["user"]["token"]`)
    through a validator made from those keys and persists the updated session
    info with `session.save_session_info`, returning that call's result.

    This deliberately does NOT refresh the JWKs itself:
    `__update_auth_server_jwks__` refreshes them and then calls this function,
    so calling back into it from here recursed without end.
    """
    cached_jwks = (sess_info.get("auth_server_jwks") or {}).get(
        "jwks", {"keys": []})
    # Rebuild key objects from the cached dicts with `JsonWebKey.import_key`
    # (the same constructor used when the JWKs are fetched); `JsonWebKey(key)`
    # is not a valid way to construct a key. A real list is used because a
    # generator would be exhausted after the key set is searched once.
    keys = KeySet(
        [JsonWebKey.import_key(key) for key in cached_jwks["keys"]])
    sess_info["user"]["token"] = sess_info["user"]["token"].then(
        __make_token_validator__(keys))
    return session.save_session_info(sess_info)


def __update_auth_server_jwks__(sess_info):
    """Updates the JWKs every 2 hours or so.

    If `sess_info["auth_server_jwks"]` holds JWKs whose "last-updated"
    timestamp is less than 2 hours old, they are reused as-is. Otherwise the
    JWKs are fetched from the auth server's "auth/public-jwks" endpoint and
    stored, with the current timestamp, under "auth_server_jwks". In both
    cases the resulting session info is passed to `__validate_token__`, and
    its result is returned.
    """
    jwks = sess_info.get("auth_server_jwks")
    if bool(jwks):
        last_updated = jwks.get("last-updated")
        now = datetime.now().timestamp()
        # `total_seconds()` rather than `.seconds`: `.seconds` is only the
        # seconds-within-a-day component, so it would silently be wrong if
        # this refresh interval were ever set to a day or longer.
        max_age = timedelta(hours=2).total_seconds()
        if bool(last_updated) and (now - last_updated) < max_age:
            return __validate_token__({**sess_info, "auth_server_jwks": jwks})

    jwksuri = urljoin(authserver_uri(), "auth/public-jwks")
    # NOTE(review): no timeout is passed, so a hung auth server blocks this
    # call indefinitely — consider adding one.
    fetched_keys = requests.get(jwksuri).json()["jwks"]
    return __validate_token__({
        **sess_info,
        "auth_server_jwks": {
            # Cached as a plain dict (`KeySet.as_dict()`), presumably so the
            # session can serialise it; `__validate_token__` re-imports it.
            "jwks": KeySet([
                JsonWebKey.import_key(key)
                for key in fetched_keys]).as_dict(),
            "last-updated": datetime.now().timestamp()
        }
    })


def oauth2_client():
    """Build the OAuth2 client for use fetching data.

    First refreshes (if stale) and checks the auth server JWKs for the current
    session via `__update_auth_server_jwks__`. It then builds an
    `OAuth2Session` carrying the session's user token, or no token at all
    when `session.user_token()` yields a `Left`. When the session refreshes
    its token, the new token is stored with `session.set_user_token`.
    """
    def __update_token__(token, refresh_token=None, access_token=None):
        """Update the token when refreshed."""
        # NOTE(review): this writes raw token values to the debug log; make
        # sure debug logging is never enabled where logs may be exposed.
        app.logger.debug(
            "IN `%s`:\n\trefresh_token: %s\n\taccess_token: %s\n\ttoken: %s",
            __name__, refresh_token, access_token, token)
        session.set_user_token(token)

    def __client__(token) -> OAuth2Session:
        return OAuth2Session(
            oauth2_clientid(),
            oauth2_clientsecret(),
            scope=SCOPE,
            # Relative path, like the "auth/public-jwks" URL used to fetch
            # the JWKs, so both endpoints resolve against the same base URL;
            # a leading "/" makes `urljoin` discard any path prefix of the
            # base.
            token_endpoint=urljoin(authserver_uri(), "auth/token"),
            token_endpoint_auth_method="client_secret_post",
            token=token,
            update_token=__update_token__)

    __update_auth_server_jwks__(session.session_info())
    return session.user_token().either(
        lambda _notok: __client__(None),
        __client__)