1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
|
"""OAuth2 client utilities."""
from urllib.parse import urljoin
from datetime import datetime, timedelta
import requests
from flask import current_app as app
from pymonad.either import Left, Right
from authlib.jose import jwt
from authlib.jose import KeySet, JsonWebKey
from authlib.jose.errors import BadSignatureError
from authlib.integrations.requests_client import OAuth2Session
from uploader import session
# Space-separated OAuth2 scopes this client requests (passed as `scope=` to
# the OAuth2Session built in `oauth2_client`).
SCOPE = " ".join((
    "profile", "group", "role", "resource", "register-client", "user",
    "masquerade", "introspect", "migrate-data"))
def authserver_uri():
    """Look up the base URL of the authorisation server.

    The value is the `AUTH_SERVER_URL` setting of the current Flask app.
    """
    config = app.config
    return config["AUTH_SERVER_URL"]
def oauth2_clientid():
    """Look up this application's OAuth2 client ID.

    The value is the `OAUTH2_CLIENT_ID` setting of the current Flask app.
    """
    client_id = app.config["OAUTH2_CLIENT_ID"]
    return client_id
def oauth2_clientsecret():
    """Look up this application's OAuth2 client secret.

    The value is the `OAUTH2_CLIENT_SECRET` setting of the current Flask app.
    """
    client_secret = app.config["OAUTH2_CLIENT_SECRET"]
    return client_secret
def __make_token_validator__(keys: KeySet):
    """Make a token validator function.

    Parameters
    ----------
    keys: KeySet
        The auth server's public JSON Web Keys, used to verify the token's
        signature.

    Returns
    -------
    A function taking a token and returning `Right(token)` (the token object
    unchanged) when its JWT verifies against `keys`, or
    `Left("INVALID-TOKEN")` when it does not.
    """
    # pylint: disable=import-outside-toplevel
    from authlib.jose.errors import JoseError

    def __validator__(token):
        # The stored user token is what `oauth2_client` hands to
        # `OAuth2Session(token=...)`, so it is presumably the OAuth2 token
        # dict; the JWT to verify is then its "access_token". A bare JWT
        # string is still accepted as-is.
        raw_jwt = (token.get("access_token") if isinstance(token, dict)
                   else token)
        if not raw_jwt:
            return Left("INVALID-TOKEN")
        try:
            jwt.decode(raw_jwt, keys)
        # BadSignatureError: signature does not match.
        # JoseError: any other JOSE failure, e.g. a malformed JWT.
        # ValueError: `KeySet` has no key matching the token's `kid`.
        except (BadSignatureError, JoseError, ValueError):
            return Left("INVALID-TOKEN")
        return Right(token)

    return __validator__
def __validate_token__(sess_info):
    """Validate that the session's user token really is from the auth server.

    The token is checked against the JWKs already cached in
    `sess_info["auth_server_jwks"]["jwks"]["keys"]`; with no cached keys the
    key set is empty. Refreshing that cache is done by
    `__update_auth_server_jwks__`, which calls this function. This function
    must not call back into it, or the two recurse without end.

    Returns the result of `session.save_session_info` for the session info
    with the validated token.
    """
    cached_keys = sess_info.get(
        "auth_server_jwks", {}).get("jwks", {"keys": []})["keys"]
    # `JsonWebKey` is a factory: keys are built with its `import_key`
    # classmethod. A list (not a generator) is used because `KeySet` may
    # iterate and size its keys more than once.
    keyset = KeySet([JsonWebKey.import_key(key) for key in cached_keys])
    user = sess_info["user"]
    # Build new dicts rather than mutating the caller's nested "user" dict.
    return session.save_session_info({
        **sess_info,
        "user": {
            **user,
            "token": user["token"].then(__make_token_validator__(keyset))
        }
    })
def __update_auth_server_jwks__(sess_info):
    """Refresh the auth server's cached JWKs if stale, then validate the token.

    The JWKs stored under `sess_info["auth_server_jwks"]` are reused while
    their "last-updated" timestamp is less than two hours old. Otherwise they
    are fetched again from the auth server's `auth/public-jwks` endpoint and
    stored with a fresh timestamp. In both cases the resulting session info
    is passed to `__validate_token__`, whose result is returned.

    Raises a `requests` exception if the JWKs endpoint cannot be reached,
    times out, or responds with an HTTP error status.
    """
    max_age_seconds = timedelta(hours=2).total_seconds()
    jwks = sess_info.get("auth_server_jwks")
    if jwks:
        last_updated = jwks.get("last-updated")
        now = datetime.now().timestamp()
        if last_updated and (now - last_updated) < max_age_seconds:
            return __validate_token__({**sess_info, "auth_server_jwks": jwks})

    jwksuri = urljoin(authserver_uri(), "auth/public-jwks")
    # A timeout stops an unresponsive auth server from hanging the request
    # forever.
    response = requests.get(jwksuri, timeout=30)
    response.raise_for_status()
    return __validate_token__({
        **sess_info,
        "auth_server_jwks": {
            # Re-serialised through KeySet so only public key material is
            # stored in the session.
            "jwks": KeySet([
                JsonWebKey.import_key(key)
                for key in response.json()["jwks"]]).as_dict(),
            "last-updated": datetime.now().timestamp()
        }
    })
def oauth2_client():
    """Build the OAuth2 client for use fetching data.

    First the session's user token is checked against the auth server's
    JWKs (refreshed if stale), and the session info is saved. If the session
    then yields a user token (a `Right`), the client is built with it;
    otherwise (a `Left`) the client is built without a token.

    Tokens refreshed by the client are written back to the session through
    the `update_token` callback.
    """
    def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument]
        """Update the token when refreshed."""
        # Tokens are bearer credentials: log only that a refresh happened,
        # never their values.
        app.logger.debug("IN `%s`: user token updated.", __name__)
        session.set_user_token(token)

    def __client__(token) -> OAuth2Session:
        return OAuth2Session(
            oauth2_clientid(),
            oauth2_clientsecret(),
            scope=SCOPE,
            # NOTE(review): with `urljoin`, the leading "/" discards any path
            # component of AUTH_SERVER_URL, whereas the JWKs URL uses the
            # relative "auth/public-jwks". Confirm which form is intended.
            token_endpoint=urljoin(authserver_uri(), "/auth/token"),
            token_endpoint_auth_method="client_secret_post",
            token=token,
            update_token=__update_token__)

    # Called for its side effect of validating and saving the session.
    __update_auth_server_jwks__(session.session_info())
    return session.user_token().either(
        lambda _notok: __client__(None),
        __client__)
|