aboutsummaryrefslogtreecommitdiff
path: root/uploader/oauth2/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'uploader/oauth2/client.py')
-rw-r--r--uploader/oauth2/client.py101
1 files changed, 101 insertions, 0 deletions
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py
new file mode 100644
index 0000000..6e101ae
--- /dev/null
+++ b/uploader/oauth2/client.py
@@ -0,0 +1,101 @@
+"""OAuth2 client utilities."""
+from urllib.parse import urljoin
+from datetime import datetime, timedelta
+
+import requests
+from flask import current_app as app
+
+from pymonad.either import Left, Right
+
+from authlib.jose import jwt
+from authlib.jose import KeySet, JsonWebKey
+from authlib.jose.errors import BadSignatureError
+from authlib.integrations.requests_client import OAuth2Session
+
+from uploader import session
+
+SCOPE = ("profile group role resource register-client user masquerade "
+ "introspect migrate-data")
+
+
+def authserver_uri():
+ """Return URI to authorisation server."""
+ return app.config["AUTH_SERVER_URL"]
+
+
+def oauth2_clientid():
+ """Return the client id."""
+ return app.config["OAUTH2_CLIENT_ID"]
+
+
+def oauth2_clientsecret():
+ """Return the client secret."""
+ return app.config["OAUTH2_CLIENT_SECRET"]
+
+
+def __make_token_validator__(keys: KeySet):
+ """Make a token validator function."""
+ def __validator__(token: str):
+ try:
+ jwt.decode(token, keys)
+ return Right(token)
+ except BadSignatureError:
+ return Left("INVALID-TOKEN")
+
+ return __validator__
+
+
+def __validate_token__(sess_info):
+ """Validate that the token is really from the auth server."""
+ info = __update_auth_server_jwks__(sess_info)
+ info["user"]["token"] = info["user"]["token"].then(__make_token_validator__(
+ KeySet(JsonWebKey(key) for key in info.get(
+ "auth_server_jwks", {}).get(
+ "jwks", {"keys": []})["keys"])))
+ return session.save_session_info(info)
+
+
+def __update_auth_server_jwks__(sess_info):
+ """Updates the JWKs every 2 hours or so."""
+ jwks = sess_info.get("auth_server_jwks")
+ if bool(jwks):
+ last_updated = jwks.get("last-updated")
+ now = datetime.now().timestamp()
+ if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds:
+ return __validate_token__({**sess_info, "auth_server_jwks": jwks})
+
+ jwksuri = urljoin(authserver_uri(), "auth/public-jwks")
+ return __validate_token__({
+ **sess_info,
+ "auth_server_jwks": {
+ "jwks": KeySet([
+ JsonWebKey.import_key(key)
+ for key in requests.get(jwksuri).json()["jwks"]]).as_dict(),
+ "last-updated": datetime.now().timestamp()
+ }
+ })
+
+
+def oauth2_client():
+ """Build the OAuth2 client for use fetching data."""
+ def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument]
+ """Update the token when refreshed."""
+ app.logger.debug(f"IN `{__name__}`:\n\trefresh_token: {refresh_token}"
+ f"\n\taccess_token: {access_token}\n\ttoken: {token}")
+ session.set_user_token(token)
+
+ def __client__(token) -> OAuth2Session:
+ client = OAuth2Session(
+ oauth2_clientid(),
+ oauth2_clientsecret(),
+ scope=SCOPE,
+ token_endpoint=urljoin(authserver_uri(), "/auth/token"),
+ token_endpoint_auth_method="client_secret_post",
+ token=token,
+ update_token=__update_token__)
+ return client
+
+ __update_auth_server_jwks__(session.session_info())
+ return session.user_token().either(
+ lambda _notok: __client__(None),
+ __client__)