aboutsummaryrefslogtreecommitdiff
path: root/uploader/oauth2
diff options
context:
space:
mode:
Diffstat (limited to 'uploader/oauth2')
-rw-r--r--uploader/oauth2/__init__.py1
-rw-r--r--uploader/oauth2/client.py230
-rw-r--r--uploader/oauth2/jwks.py86
-rw-r--r--uploader/oauth2/views.py138
4 files changed, 455 insertions, 0 deletions
diff --git a/uploader/oauth2/__init__.py b/uploader/oauth2/__init__.py
new file mode 100644
index 0000000..aaea638
--- /dev/null
+++ b/uploader/oauth2/__init__.py
@@ -0,0 +1 @@
+"""Package to handle OAuth2 authentication/authorisation issues."""
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py
new file mode 100644
index 0000000..e7128de
--- /dev/null
+++ b/uploader/oauth2/client.py
@@ -0,0 +1,230 @@
+"""OAuth2 client utilities."""
+import json
+import time
+import random
+from datetime import datetime, timedelta
+from urllib.parse import urljoin, urlparse
+
+import requests
+from flask import request, current_app as app
+
+from pymonad.either import Left, Right, Either
+
+from authlib.common.urls import url_decode
+from authlib.jose.errors import BadSignatureError
+from authlib.jose import KeySet, JsonWebKey, JsonWebToken
+from authlib.integrations.requests_client import OAuth2Session
+
+from uploader import session
+import uploader.monadic_requests as mrequests
+
+SCOPE = ("profile group role resource register-client user masquerade "
+ "introspect migrate-data")
+
+
+def authserver_uri():
+ """Return URI to authorisation server."""
+ return app.config["AUTH_SERVER_URL"]
+
+
+def oauth2_clientid():
+ """Return the client id."""
+ return app.config["OAUTH2_CLIENT_ID"]
+
+
+def oauth2_clientsecret():
+ """Return the client secret."""
+ return app.config["OAUTH2_CLIENT_SECRET"]
+
+
+def __fetch_auth_server_jwks__() -> KeySet:
+ """Fetch the JWKs from the auth server."""
+ return KeySet([
+ JsonWebKey.import_key(key)
+ for key in requests.get(
+ urljoin(authserver_uri(), "auth/public-jwks")
+ ).json()["jwks"]])
+
+
+def __update_auth_server_jwks__(jwks) -> KeySet:
+ """Update the JWKs from the servers if necessary."""
+ last_updated = jwks["last-updated"]
+ now = datetime.now().timestamp()
+ # Maybe the `two_hours` variable below can be made into a configuration
+ # variable and passed in to this function
+ two_hours = timedelta(hours=2).seconds
+ if bool(last_updated) and (now - last_updated) < two_hours:
+ return jwks["jwks"]
+
+ return session.set_auth_server_jwks(__fetch_auth_server_jwks__())
+
+
+def auth_server_jwks() -> KeySet:
+ """Fetch the auth-server JSON Web Keys information."""
+ _jwks = session.session_info().get("auth_server_jwks") or {}
+ if bool(_jwks):
+ return __update_auth_server_jwks__({
+ "last-updated": _jwks["last-updated"],
+ "jwks": KeySet([
+ JsonWebKey.import_key(key) for key in _jwks.get(
+ "jwks", {"keys": []})["keys"]])
+ })
+
+ return __update_auth_server_jwks__({
+ "last-updated": (datetime.now() - timedelta(hours=3)).timestamp()
+ })
+
+
+def oauth2_client():
+ """Build the OAuth2 client for use fetching data."""
+ def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument]
+ """Update the token when refreshed."""
+ session.set_user_token(token)
+
+ def __json_auth__(client, _method, uri, headers, body):
+ return (
+ uri,
+ {**headers, "Content-Type": "application/json"},
+ json.dumps({
+ **dict(url_decode(body)),
+ "client_id": client.client_id,
+ "client_secret": client.client_secret
+ }))
+
+ def __client__(token) -> OAuth2Session:
+ client = OAuth2Session(
+ oauth2_clientid(),
+ oauth2_clientsecret(),
+ scope=SCOPE,
+ token_endpoint=urljoin(authserver_uri(), "/auth/token"),
+ token_endpoint_auth_method="client_secret_post",
+ token=token,
+ update_token=__update_token__)
+ client.register_client_auth_method(
+ ("client_secret_post", __json_auth__))
+ return client
+
+ def __token_expired__(token):
+ """Check whether the token has expired."""
+ jwks = auth_server_jwks()
+ if bool(jwks):
+ for jwk in jwks.keys:
+ try:
+ jwt = JsonWebToken(["RS256"]).decode(
+ token["access_token"], key=jwk)
+ return datetime.now().timestamp() > jwt["exp"]
+ except BadSignatureError as _bse:
+ pass
+
+ return False
+
+ def __delay__():
+ """Do a tiny delay."""
+ time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100))))
+
+ def __refresh_token__(token):
+ """Refresh the token if necessary — synchronise amongst threads."""
+ if __token_expired__(token):
+ __delay__()
+ if session.is_token_refreshing():
+ while session.is_token_refreshing():
+ __delay__()
+
+ return session.user_token().either(None, lambda _tok: _tok)
+
+ session.toggle_token_refreshing()
+ _client = __client__(token)
+ _client.get(urljoin(authserver_uri(), "auth/user/"))
+ session.toggle_token_refreshing()
+ return _client.token
+
+ return token
+
+ return session.user_token().then(__refresh_token__).either(
+ lambda _notok: __client__(None),
+ __client__)
+
+
+def user_logged_in():
+ """Check whether the user has logged in."""
+ suser = session.session_info()["user"]
+ return suser["logged_in"] and suser["token"].is_right()
+
+
+def authserver_authorise_uri():
+ """Build up the authorisation URI."""
+ req_baseurl = urlparse(request.base_url, scheme=request.scheme)
+ host_uri = f"{req_baseurl.scheme}://{req_baseurl.netloc}/"
+ return urljoin(
+ authserver_uri(),
+ "auth/authorise?response_type=code"
+ f"&client_id={oauth2_clientid()}"
+ f"&redirect_uri={urljoin(host_uri, 'oauth2/code')}")
+
+
+def __no_token__(_err) -> Left:
+ """Handle situation where request is attempted with no token."""
+ resp = requests.models.Response()
+ resp._content = json.dumps({#pylint: disable=[protected-access]
+ "error": "AuthenticationError",
+ "error-description": ("You need to authenticate to access requested "
+ "information.")}).encode("utf-8")
+ resp.status_code = 400
+ return Left(resp)
+
+
+def oauth2_get(url, **kwargs) -> Either:
+ """Do a get request to the authentication/authorisation server."""
+ def __get__(_token) -> Either:
+ _uri = urljoin(authserver_uri(), url)
+ try:
+ resp = oauth2_client().get(
+ _uri,
+ **{
+ **kwargs,
+ "headers": {
+ **kwargs.get("headers", {}),
+ "Content-Type": "application/json"
+ }
+ })
+ if resp.status_code in mrequests.SUCCESS_CODES:
+ return Right(resp.json())
+ return Left(resp)
+ except Exception as exc:#pylint: disable=[broad-except]
+ app.logger.error("Error retrieving data from auth server: (GET %s)",
+ _uri,
+ exc_info=True)
+ return Left(exc)
+ return session.user_token().either(__no_token__, __get__)
+
+
+def oauth2_post(url, data=None, json=None, **kwargs):#pylint: disable=[redefined-outer-name]
+ """Do a POST request to the authentication/authorisation server."""
+ def __post__(_token) -> Either:
+ _uri = urljoin(authserver_uri(), url)
+ _headers = ({
+ **kwargs.get("headers", {}),
+ "Content-Type": "application/json"
+ }
+ if bool(json) else kwargs.get("headers", {}))
+ try:
+ request_data = {
+ **(data or {}),
+ **(json or {}),
+ "client_id": oauth2_clientid(),
+ "client_secret": oauth2_clientsecret()
+ }
+ resp = oauth2_client().post(
+ _uri,
+ data=(request_data if bool(data) else None),
+ json=(request_data if bool(json) else None),
+ **{**kwargs, "headers": _headers})
+ if resp.status_code in mrequests.SUCCESS_CODES:
+ return Right(resp.json())
+ return Left(resp)
+ except Exception as exc:#pylint: disable=[broad-except]
+ app.logger.error("Error retrieving data from auth server: (POST %s)",
+ _uri,
+ exc_info=True)
+ return Left(exc)
+ return session.user_token().either(__no_token__, __post__)
diff --git a/uploader/oauth2/jwks.py b/uploader/oauth2/jwks.py
new file mode 100644
index 0000000..efd0499
--- /dev/null
+++ b/uploader/oauth2/jwks.py
@@ -0,0 +1,86 @@
+"""Utilities dealing with JSON Web Keys (JWK)"""
+import os
+from pathlib import Path
+from typing import Any, Union
+from datetime import datetime, timedelta
+
+from flask import Flask
+from authlib.jose import JsonWebKey
+from pymonad.either import Left, Right, Either
+
+def jwks_directory(app: Flask, configname: str) -> Path:
+ """Compute the directory where the JWKs are stored."""
+ appsecretsdir = Path(app.config[configname]).parent
+ if appsecretsdir.exists() and appsecretsdir.is_dir():
+ jwksdir = Path(appsecretsdir, "jwks/")
+ if not jwksdir.exists():
+ jwksdir.mkdir()
+ return jwksdir
+ raise ValueError(
+ "The `appsecretsdir` value should be a directory that actually exists.")
+
+
+def generate_and_save_private_key(
+ storagedir: Path,
+ kty: str = "RSA",
+ crv_or_size: Union[str, int] = 2048,
+ options: tuple[tuple[str, Any]] = (("iat", datetime.now().timestamp()),)
+) -> JsonWebKey:
+ """Generate a private key and save to `storagedir`."""
+ privatejwk = JsonWebKey.generate_key(
+ kty, crv_or_size, dict(options), is_private=True)
+ keyname = f"{privatejwk.thumbprint()}.private.pem"
+ with open(Path(storagedir, keyname), "wb") as pemfile:
+ pemfile.write(privatejwk.as_pem(is_private=True))
+
+ return privatejwk
+
+
+def pem_to_jwk(filepath: Path) -> JsonWebKey:
+ """Parse a PEM file into a JWK object."""
+ with open(filepath, "rb") as pemfile:
+ return JsonWebKey.import_key(pemfile.read())
+
+
+def __sorted_jwks_paths__(storagedir: Path) -> tuple[tuple[float, Path], ...]:
+ """A sorted list of the JWK file paths with their creation timestamps."""
+ return tuple(sorted(((os.stat(keypath).st_ctime, keypath)
+ for keypath in (Path(storagedir, keyfile)
+ for keyfile in os.listdir(storagedir)
+ if keyfile.endswith(".pem"))),
+ key=lambda tpl: tpl[0]))
+
+
+def list_jwks(storagedir: Path) -> tuple[JsonWebKey, ...]:
+ """
+ List all the JWKs in a particular directory in the order they were created.
+ """
+ return tuple(pem_to_jwk(keypath) for ctime,keypath in
+ __sorted_jwks_paths__(storagedir))
+
+
+def newest_jwk(storagedir: Path) -> Either:
+ """
+ Return an Either monad with the newest JWK or a message if none exists.
+ """
+ existingkeys = __sorted_jwks_paths__(storagedir)
+ if len(existingkeys) > 0:
+ return Right(pem_to_jwk(existingkeys[-1][1]))
+ return Left("No JWKs exist")
+
+
+def newest_jwk_with_rotation(jwksdir: Path, keyage: int) -> JsonWebKey:
+ """
+ Retrieve the latests JWK, creating a new one if older than `keyage` days.
+ """
+ def newer_than_days(jwkey):
+ filestat = os.stat(Path(
+ jwksdir, f"{jwkey.as_dict()['kid']}.private.pem"))
+ oldesttimeallowed = (datetime.now() - timedelta(days=keyage))
+ if filestat.st_ctime < (oldesttimeallowed.timestamp()):
+ return Left("JWK is too old!")
+ return jwkey
+
+ return newest_jwk(jwksdir).then(newer_than_days).either(
+ lambda _errmsg: generate_and_save_private_key(jwksdir),
+ lambda key: key)
diff --git a/uploader/oauth2/views.py b/uploader/oauth2/views.py
new file mode 100644
index 0000000..61037f3
--- /dev/null
+++ b/uploader/oauth2/views.py
@@ -0,0 +1,138 @@
+"""Views for OAuth2 related functionality."""
+import uuid
+from datetime import datetime, timedelta
+from urllib.parse import urljoin, urlparse, urlunparse
+
+from authlib.jose import jwt
+from flask import (
+ flash,
+ jsonify,
+ url_for,
+ request,
+ redirect,
+ Blueprint,
+ current_app as app)
+
+from uploader import session
+from uploader import monadic_requests as mrequests
+from uploader.monadic_requests import make_error_handler
+
+from . import jwks
+from .client import (
+ SCOPE,
+ oauth2_get,
+ user_logged_in,
+ authserver_uri,
+ oauth2_clientid,
+ oauth2_clientsecret)
+
+oauth2 = Blueprint("oauth2", __name__)
+
+@oauth2.route("/code")
+def authorisation_code():
+ """Receive authorisation code from auth server and use it to get token."""
+ def __process_error__(resp_or_exception):
+ app.logger.debug("ERROR: (%s)", resp_or_exception)
+ flash("There was an error retrieving the authorisation token.",
+ "alert-danger")
+ return redirect("/")
+
+ def __fail_set_user_details__(_failure):
+ app.logger.debug("Fetching user details fails: %s", _failure)
+ flash("Could not retrieve the user details", "alert-danger")
+ return redirect("/")
+
+ def __success_set_user_details__(_success):
+ app.logger.debug("Session info: %s", _success)
+ return redirect("/")
+
+ def __success__(token):
+ session.set_user_token(token)
+ return oauth2_get("auth/user/").then(
+ lambda usrdets: session.set_user_details({
+ "user_id": uuid.UUID(usrdets["user_id"]),
+ "name": usrdets["name"],
+ "email": usrdets["email"],
+ "token": session.user_token(),
+ "logged_in": True})).either(
+ __fail_set_user_details__,
+ __success_set_user_details__)
+
+ code = request.args.get("code", "").strip()
+ if not bool(code):
+ flash("AuthorisationError: No code was provided.", "alert-danger")
+ return redirect("/")
+
+ baseurl = urlparse(request.base_url, scheme=request.scheme)
+ issued = datetime.now()
+ jwtkey = jwks.newest_jwk_with_rotation(
+ jwks.jwks_directory(app, "UPLOADER_SECRETS"),
+ int(app.config["JWKS_ROTATION_AGE_DAYS"]))
+ return mrequests.post(
+ urljoin(authserver_uri(), "auth/token"),
+ json={
+ "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+ "code": code,
+ "scope": SCOPE,
+ "redirect_uri": urljoin(
+ urlunparse(baseurl),
+ url_for("oauth2.authorisation_code")),
+ "assertion": jwt.encode(
+ header={
+ "alg": "RS256",
+ "typ": "JWT",
+ "kid": jwtkey.as_dict()["kid"]
+ },
+ payload={
+ "iss": str(oauth2_clientid()),
+ "sub": request.args["user_id"],
+ "aud": urljoin(authserver_uri(),"auth/token"),
+ "exp": (issued + timedelta(minutes=5)).timestamp(),
+ "nbf": int(issued.timestamp()),
+ "iat": int(issued.timestamp()),
+ "jti": str(uuid.uuid4())
+ },
+ key=jwtkey).decode("utf8"),
+ "client_id": oauth2_clientid()
+ }).either(__process_error__, __success__)
+
+@oauth2.route("/public-jwks")
+def public_jwks():
+ """List the available JWKs"""
+ return jsonify({
+ "documentation": (
+ "The keys are listed in order of creation, from the oldest (first) "
+ "to the newest (last)."),
+ "jwks": tuple(key.as_dict() for key
+ in jwks.list_jwks(jwks.jwks_directory(
+ app, "UPLOADER_SECRETS")))
+ })
+
+
+@oauth2.route("/logout", methods=["GET"])
+def logout():
+ """Log out of any active sessions."""
+ def __unset_session__(session_info):
+ _user = session_info["user"]
+ _user_str = f"{_user['name']} ({_user['email']})"
+ session.clear_session_info()
+ flash("Successfully logged out.", "alert-success")
+ return redirect("/")
+
+ if user_logged_in():
+ return session.user_token().then(
+ lambda _tok: mrequests.post(
+ urljoin(authserver_uri(), "auth/revoke"),
+ json={
+ "token": _tok["refresh_token"],
+ "token_type_hint": "refresh_token",
+ "client_id": oauth2_clientid(),
+ "client_secret": oauth2_clientsecret()
+ })).either(
+ make_error_handler(
+ redirect_to=redirect("/"),
+ cleanup_thunk=lambda: __unset_session__(
+ session.session_info())),
+ lambda res: __unset_session__(session.session_info()))
+ flash("There is no user that is currently logged in.", "alert-info")
+ return redirect("/")