aboutsummaryrefslogtreecommitdiff
path: root/uploader/oauth2/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'uploader/oauth2/client.py')
-rw-r--r--uploader/oauth2/client.py230
1 files changed, 230 insertions, 0 deletions
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py
new file mode 100644
index 0000000..e119cc3
--- /dev/null
+++ b/uploader/oauth2/client.py
@@ -0,0 +1,230 @@
+"""OAuth2 client utilities."""
+import json
+import time
+import random
+from datetime import datetime, timedelta
+from urllib.parse import urljoin, urlparse
+
+import requests
+from flask import request, current_app as app
+
+from pymonad.either import Left, Right, Either
+
+from authlib.common.urls import url_decode
+from authlib.jose.errors import BadSignatureError
+from authlib.jose import KeySet, JsonWebKey, JsonWebToken
+from authlib.integrations.requests_client import OAuth2Session
+
+from uploader import session
+import uploader.monadic_requests as mrequests
+
+SCOPE = ("profile group role resource register-client user masquerade "
+ "introspect migrate-data")
+
+
+def authserver_uri():
+ """Return URI to authorisation server."""
+ return app.config["AUTH_SERVER_URL"]
+
+
+def oauth2_clientid():
+ """Return the client id."""
+ return app.config["OAUTH2_CLIENT_ID"]
+
+
+def oauth2_clientsecret():
+ """Return the client secret."""
+ return app.config["OAUTH2_CLIENT_SECRET"]
+
+
+def __fetch_auth_server_jwks__() -> KeySet:
+ """Fetch the JWKs from the auth server."""
+ return KeySet([
+ JsonWebKey.import_key(key)
+ for key in requests.get(
+ urljoin(authserver_uri(), "auth/public-jwks")
+ ).json()["jwks"]])
+
+
+def __update_auth_server_jwks__(jwks) -> KeySet:
+ """Update the JWKs from the servers if necessary."""
+ last_updated = jwks["last-updated"]
+ now = datetime.now().timestamp()
+ # Maybe the `two_hours` variable below can be made into a configuration
+ # variable and passed in to this function
+ two_hours = timedelta(hours=2).seconds
+ if bool(last_updated) and (now - last_updated) < two_hours:
+ return jwks["jwks"]
+
+ return session.set_auth_server_jwks(__fetch_auth_server_jwks__())
+
+
+def auth_server_jwks() -> KeySet:
+ """Fetch the auth-server JSON Web Keys information."""
+ _jwks = session.session_info().get("auth_server_jwks")
+ if bool(_jwks):
+ return __update_auth_server_jwks__({
+ "last-updated": _jwks["last-updated"],
+ "jwks": KeySet([
+ JsonWebKey.import_key(key) for key in _jwks.get(
+ "jwks", {"keys": []})["keys"]])
+ })
+
+ return __update_auth_server_jwks__({
+ "last-updated": (datetime.now() - timedelta(hours=3)).timestamp()
+ })
+
+
+def oauth2_client():
+ """Build the OAuth2 client for use fetching data."""
+ def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument]
+ """Update the token when refreshed."""
+ session.set_user_token(token)
+
+ def __json_auth__(client, _method, uri, headers, body):
+ return (
+ uri,
+ {**headers, "Content-Type": "application/json"},
+ json.dumps({
+ **dict(url_decode(body)),
+ "client_id": client.client_id,
+ "client_secret": client.client_secret
+ }))
+
+ def __client__(token) -> OAuth2Session:
+ client = OAuth2Session(
+ oauth2_clientid(),
+ oauth2_clientsecret(),
+ scope=SCOPE,
+ token_endpoint=urljoin(authserver_uri(), "/auth/token"),
+ token_endpoint_auth_method="client_secret_post",
+ token=token,
+ update_token=__update_token__)
+ client.register_client_auth_method(
+ ("client_secret_post", __json_auth__))
+ return client
+
+ def __token_expired__(token):
+ """Check whether the token has expired."""
+ jwks = auth_server_jwks()
+ if bool(jwks):
+ for jwk in jwks.keys:
+ try:
+ jwt = JsonWebToken(["RS256"]).decode(
+ token["access_token"], key=jwk)
+ return datetime.now().timestamp() > jwt["exp"]
+ except BadSignatureError as _bse:
+ pass
+
+ return False
+
+ def __delay__():
+ """Do a tiny delay."""
+ time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100))))
+
+ def __refresh_token__(token):
+ """Refresh the token if necessary — synchronise amongst threads."""
+ if __token_expired__(token):
+ __delay__()
+ if session.is_token_refreshing():
+ while session.is_token_refreshing():
+ __delay__()
+
+ return session.user_token().either(None, lambda _tok: _tok)
+
+ session.toggle_token_refreshing()
+ _client = __client__(token)
+ _client.get(urljoin(authserver_uri(), "auth/user/"))
+ session.toggle_token_refreshing()
+ return _client.token
+
+ return token
+
+ return session.user_token().then(__refresh_token__).either(
+ lambda _notok: __client__(None),
+ __client__)
+
+
+def user_logged_in():
+ """Check whether the user has logged in."""
+ suser = session.session_info()["user"]
+ return suser["logged_in"] and suser["token"].is_right()
+
+
+def authserver_authorise_uri():
+ """Build up the authorisation URI."""
+ req_baseurl = urlparse(request.base_url, scheme=request.scheme)
+ host_uri = f"{req_baseurl.scheme}://{req_baseurl.netloc}/"
+ return urljoin(
+ authserver_uri(),
+ "auth/authorise?response_type=code"
+ f"&client_id={oauth2_clientid()}"
+ f"&redirect_uri={urljoin(host_uri, 'oauth2/code')}")
+
+
+def __no_token__(_err) -> Left:
+ """Handle situation where request is attempted with no token."""
+ resp = requests.models.Response()
+ resp._content = json.dumps({#pylint: disable=[protected-access]
+ "error": "AuthenticationError",
+ "error-description": ("You need to authenticate to access requested "
+ "information.")}).encode("utf-8")
+ resp.status_code = 400
+ return Left(resp)
+
+
+def oauth2_get(url, **kwargs) -> Either:
+ """Do a get request to the authentication/authorisation server."""
+ def __get__(_token) -> Either:
+ _uri = urljoin(authserver_uri(), url)
+ try:
+ resp = oauth2_client().get(
+ _uri,
+ **{
+ **kwargs,
+ "headers": {
+ **kwargs.get("headers", {}),
+ "Content-Type": "application/json"
+ }
+ })
+ if resp.status_code in mrequests.SUCCESS_CODES:
+ return Right(resp.json())
+ return Left(resp)
+ except Exception as exc:#pylint: disable=[broad-except]
+ app.logger.error("Error retriving data from auth server: (GET %s)",
+ _uri,
+ exc_info=True)
+ return Left(exc)
+ return session.user_token().either(__no_token__, __get__)
+
+
+def oauth2_post(url, data=None, json=None, **kwargs):#pylint: disable=[redefined-outer-name]
+ """Do a POST request to the authentication/authorisation server."""
+ def __post__(_token) -> Either:
+ _uri = urljoin(authserver_uri(), url)
+ _headers = ({
+ **kwargs.get("headers", {}),
+ "Content-Type": "application/json"
+ }
+ if bool(json) else kwargs.get("headers", {}))
+ try:
+ request_data = {
+ **(data or {}),
+ **(json or {}),
+ "client_id": oauth2_clientid(),
+ "client_secret": oauth2_clientsecret()
+ }
+ resp = oauth2_client().post(
+ _uri,
+ data=(request_data if bool(data) else None),
+ json=(request_data if bool(json) else None),
+ **{**kwargs, "headers": _headers})
+ if resp.status_code in mrequests.SUCCESS_CODES:
+ return Right(resp.json())
+ return Left(resp)
+ except Exception as exc:#pylint: disable=[broad-except]
+ app.logger.error("Error retriving data from auth server: (POST %s)",
+ _uri,
+ exc_info=True)
+ return Left(exc)
+ return session.user_token().either(__no_token__, __post__)