diff options
Diffstat (limited to 'uploader/oauth2')
-rw-r--r-- | uploader/oauth2/__init__.py | 1 | ||||
-rw-r--r-- | uploader/oauth2/client.py | 230 | ||||
-rw-r--r-- | uploader/oauth2/jwks.py | 86 | ||||
-rw-r--r-- | uploader/oauth2/views.py | 138 |
4 files changed, 455 insertions, 0 deletions
diff --git a/uploader/oauth2/__init__.py b/uploader/oauth2/__init__.py new file mode 100644 index 0000000..aaea638 --- /dev/null +++ b/uploader/oauth2/__init__.py @@ -0,0 +1 @@ +"""Package to handle OAuth2 authentication/authorisation issues.""" diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py new file mode 100644 index 0000000..e7128de --- /dev/null +++ b/uploader/oauth2/client.py @@ -0,0 +1,230 @@ +"""OAuth2 client utilities.""" +import json +import time +import random +from datetime import datetime, timedelta +from urllib.parse import urljoin, urlparse + +import requests +from flask import request, current_app as app + +from pymonad.either import Left, Right, Either + +from authlib.common.urls import url_decode +from authlib.jose.errors import BadSignatureError +from authlib.jose import KeySet, JsonWebKey, JsonWebToken +from authlib.integrations.requests_client import OAuth2Session + +from uploader import session +import uploader.monadic_requests as mrequests + +SCOPE = ("profile group role resource register-client user masquerade " + "introspect migrate-data") + + +def authserver_uri(): + """Return URI to authorisation server.""" + return app.config["AUTH_SERVER_URL"] + + +def oauth2_clientid(): + """Return the client id.""" + return app.config["OAUTH2_CLIENT_ID"] + + +def oauth2_clientsecret(): + """Return the client secret.""" + return app.config["OAUTH2_CLIENT_SECRET"] + + +def __fetch_auth_server_jwks__() -> KeySet: + """Fetch the JWKs from the auth server.""" + return KeySet([ + JsonWebKey.import_key(key) + for key in requests.get( + urljoin(authserver_uri(), "auth/public-jwks") + ).json()["jwks"]]) + + +def __update_auth_server_jwks__(jwks) -> KeySet: + """Update the JWKs from the servers if necessary.""" + last_updated = jwks["last-updated"] + now = datetime.now().timestamp() + # Maybe the `two_hours` variable below can be made into a configuration + # variable and passed in to this function + two_hours = timedelta(hours=2).seconds + if bool(last_updated) and (now - last_updated) < two_hours: + return jwks["jwks"] + + return session.set_auth_server_jwks(__fetch_auth_server_jwks__()) + + +def auth_server_jwks() -> KeySet: + """Fetch the auth-server JSON Web Keys information.""" + _jwks = session.session_info().get("auth_server_jwks") or {} + if bool(_jwks): + return __update_auth_server_jwks__({ + "last-updated": _jwks["last-updated"], + "jwks": KeySet([ + JsonWebKey.import_key(key) for key in _jwks.get( + "jwks", {"keys": []})["keys"]]) + }) + + return __update_auth_server_jwks__({ + "last-updated": (datetime.now() - timedelta(hours=3)).timestamp() + }) + + +def oauth2_client(): + """Build the OAuth2 client for use fetching data.""" + def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument] + """Update the token when refreshed.""" + session.set_user_token(token) + + def __json_auth__(client, _method, uri, headers, body): + return ( + uri, + {**headers, "Content-Type": "application/json"}, + json.dumps({ + **dict(url_decode(body)), + "client_id": client.client_id, + "client_secret": client.client_secret + })) + + def __client__(token) -> OAuth2Session: + client = OAuth2Session( + oauth2_clientid(), + oauth2_clientsecret(), + scope=SCOPE, + token_endpoint=urljoin(authserver_uri(), "/auth/token"), + token_endpoint_auth_method="client_secret_post", + token=token, + update_token=__update_token__) + client.register_client_auth_method( + ("client_secret_post", __json_auth__)) + return client + + def __token_expired__(token): + """Check whether the token has expired.""" + jwks = auth_server_jwks() + if bool(jwks): + for jwk in jwks.keys: + try: + jwt = JsonWebToken(["RS256"]).decode( + token["access_token"], key=jwk) + return datetime.now().timestamp() > jwt["exp"] + except BadSignatureError as _bse: + pass + + return False + + def __delay__(): + """Do a tiny delay.""" + time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100)))) + + def __refresh_token__(token): + """Refresh the token if necessary — synchronise amongst threads.""" + if __token_expired__(token): + __delay__() + if session.is_token_refreshing(): + while session.is_token_refreshing(): + __delay__() + + return session.user_token().either(None, lambda _tok: _tok) + + session.toggle_token_refreshing() + _client = __client__(token) + _client.get(urljoin(authserver_uri(), "auth/user/")) + session.toggle_token_refreshing() + return _client.token + + return token + + return session.user_token().then(__refresh_token__).either( + lambda _notok: __client__(None), + __client__) + + +def user_logged_in(): + """Check whether the user has logged in.""" + suser = session.session_info()["user"] + return suser["logged_in"] and suser["token"].is_right() + + +def authserver_authorise_uri(): + """Build up the authorisation URI.""" + req_baseurl = urlparse(request.base_url, scheme=request.scheme) + host_uri = f"{req_baseurl.scheme}://{req_baseurl.netloc}/" + return urljoin( + authserver_uri(), + "auth/authorise?response_type=code" + f"&client_id={oauth2_clientid()}" + f"&redirect_uri={urljoin(host_uri, 'oauth2/code')}") + + +def __no_token__(_err) -> Left: + """Handle situation where request is attempted with no token.""" + resp = requests.models.Response() + resp._content = json.dumps({#pylint: disable=[protected-access] + "error": "AuthenticationError", + "error-description": ("You need to authenticate to access requested " + "information.")}).encode("utf-8") + resp.status_code = 400 + return Left(resp) + + +def oauth2_get(url, **kwargs) -> Either: + """Do a get request to the authentication/authorisation server.""" + def __get__(_token) -> Either: + _uri = urljoin(authserver_uri(), url) + try: + resp = oauth2_client().get( + _uri, + **{ + **kwargs, + "headers": { + **kwargs.get("headers", {}), + "Content-Type": "application/json" + } + }) + if resp.status_code in mrequests.SUCCESS_CODES: + return Right(resp.json()) + return Left(resp) + except Exception as exc:#pylint: disable=[broad-except] + app.logger.error("Error retrieving data from auth server: (GET %s)", + _uri, + exc_info=True) + return Left(exc) + return session.user_token().either(__no_token__, __get__) + + +def oauth2_post(url, data=None, json=None, **kwargs):#pylint: disable=[redefined-outer-name] + """Do a POST request to the authentication/authorisation server.""" + def __post__(_token) -> Either: + _uri = urljoin(authserver_uri(), url) + _headers = ({ + **kwargs.get("headers", {}), + "Content-Type": "application/json" + } + if bool(json) else kwargs.get("headers", {})) + try: + request_data = { + **(data or {}), + **(json or {}), + "client_id": oauth2_clientid(), + "client_secret": oauth2_clientsecret() + } + resp = oauth2_client().post( + _uri, + data=(request_data if bool(data) else None), + json=(request_data if bool(json) else None), + **{**kwargs, "headers": _headers}) + if resp.status_code in mrequests.SUCCESS_CODES: + return Right(resp.json()) + return Left(resp) + except Exception as exc:#pylint: disable=[broad-except] + app.logger.error("Error retrieving data from auth server: (POST %s)", + _uri, + exc_info=True) + return Left(exc) + return session.user_token().either(__no_token__, __post__) diff --git a/uploader/oauth2/jwks.py b/uploader/oauth2/jwks.py new file mode 100644 index 0000000..efd0499 --- /dev/null +++ b/uploader/oauth2/jwks.py @@ -0,0 +1,86 @@ +"""Utilities dealing with JSON Web Keys (JWK)""" +import os +from pathlib import Path +from typing import Any, Union +from datetime import datetime, timedelta + +from flask import Flask +from authlib.jose import JsonWebKey +from pymonad.either import Left, Right, Either + +def jwks_directory(app: Flask, configname: str) -> Path: + """Compute the directory where the JWKs are stored.""" + appsecretsdir = Path(app.config[configname]).parent + if appsecretsdir.exists() and appsecretsdir.is_dir(): + jwksdir = Path(appsecretsdir, "jwks/") + if not jwksdir.exists(): + jwksdir.mkdir() + return jwksdir + raise ValueError( + "The `appsecretsdir` value should be a directory that actually exists.") + + +def generate_and_save_private_key( + storagedir: Path, + kty: str = "RSA", + crv_or_size: Union[str, int] = 2048, + options: tuple[tuple[str, Any]] = (("iat", datetime.now().timestamp()),) +) -> JsonWebKey: + """Generate a private key and save to `storagedir`.""" + privatejwk = JsonWebKey.generate_key( + kty, crv_or_size, dict(options), is_private=True) + keyname = f"{privatejwk.thumbprint()}.private.pem" + with open(Path(storagedir, keyname), "wb") as pemfile: + pemfile.write(privatejwk.as_pem(is_private=True)) + + return privatejwk + + +def pem_to_jwk(filepath: Path) -> JsonWebKey: + """Parse a PEM file into a JWK object.""" + with open(filepath, "rb") as pemfile: + return JsonWebKey.import_key(pemfile.read()) + + +def __sorted_jwks_paths__(storagedir: Path) -> tuple[tuple[float, Path], ...]: + """A sorted list of the JWK file paths with their creation timestamps.""" + return tuple(sorted(((os.stat(keypath).st_ctime, keypath) + for keypath in (Path(storagedir, keyfile) + for keyfile in os.listdir(storagedir) + if keyfile.endswith(".pem"))), + key=lambda tpl: tpl[0])) + + +def list_jwks(storagedir: Path) -> tuple[JsonWebKey, ...]: + """ + List all the JWKs in a particular directory in the order they were created. + """ + return tuple(pem_to_jwk(keypath) for ctime,keypath in + __sorted_jwks_paths__(storagedir)) + + +def newest_jwk(storagedir: Path) -> Either: + """ + Return an Either monad with the newest JWK or a message if none exists. + """ + existingkeys = __sorted_jwks_paths__(storagedir) + if len(existingkeys) > 0: + return Right(pem_to_jwk(existingkeys[-1][1])) + return Left("No JWKs exist") + + +def newest_jwk_with_rotation(jwksdir: Path, keyage: int) -> JsonWebKey: + """ + Retrieve the latests JWK, creating a new one if older than `keyage` days. + """ + def newer_than_days(jwkey): + filestat = os.stat(Path( + jwksdir, f"{jwkey.as_dict()['kid']}.private.pem")) + oldesttimeallowed = (datetime.now() - timedelta(days=keyage)) + if filestat.st_ctime < (oldesttimeallowed.timestamp()): + return Left("JWK is too old!") + return jwkey + + return newest_jwk(jwksdir).then(newer_than_days).either( + lambda _errmsg: generate_and_save_private_key(jwksdir), + lambda key: key) diff --git a/uploader/oauth2/views.py b/uploader/oauth2/views.py new file mode 100644 index 0000000..61037f3 --- /dev/null +++ b/uploader/oauth2/views.py @@ -0,0 +1,138 @@ +"""Views for OAuth2 related functionality.""" +import uuid +from datetime import datetime, timedelta +from urllib.parse import urljoin, urlparse, urlunparse + +from authlib.jose import jwt +from flask import ( + flash, + jsonify, + url_for, + request, + redirect, + Blueprint, + current_app as app) + +from uploader import session +from uploader import monadic_requests as mrequests +from uploader.monadic_requests import make_error_handler + +from . import jwks +from .client import ( + SCOPE, + oauth2_get, + user_logged_in, + authserver_uri, + oauth2_clientid, + oauth2_clientsecret) + +oauth2 = Blueprint("oauth2", __name__) + +@oauth2.route("/code") +def authorisation_code(): + """Receive authorisation code from auth server and use it to get token.""" + def __process_error__(resp_or_exception): + app.logger.debug("ERROR: (%s)", resp_or_exception) + flash("There was an error retrieving the authorisation token.", + "alert-danger") + return redirect("/") + + def __fail_set_user_details__(_failure): + app.logger.debug("Fetching user details fails: %s", _failure) + flash("Could not retrieve the user details", "alert-danger") + return redirect("/") + + def __success_set_user_details__(_success): + app.logger.debug("Session info: %s", _success) + return redirect("/") + + def __success__(token): + session.set_user_token(token) + return oauth2_get("auth/user/").then( + lambda usrdets: session.set_user_details({ + "user_id": uuid.UUID(usrdets["user_id"]), + "name": usrdets["name"], + "email": usrdets["email"], + "token": session.user_token(), + "logged_in": True})).either( + __fail_set_user_details__, + __success_set_user_details__) + + code = request.args.get("code", "").strip() + if not bool(code): + flash("AuthorisationError: No code was provided.", "alert-danger") + return redirect("/") + + baseurl = urlparse(request.base_url, scheme=request.scheme) + issued = datetime.now() + jwtkey = jwks.newest_jwk_with_rotation( + jwks.jwks_directory(app, "UPLOADER_SECRETS"), + int(app.config["JWKS_ROTATION_AGE_DAYS"])) + return mrequests.post( + urljoin(authserver_uri(), "auth/token"), + json={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "code": code, + "scope": SCOPE, + "redirect_uri": urljoin( + urlunparse(baseurl), + url_for("oauth2.authorisation_code")), + "assertion": jwt.encode( + header={ + "alg": "RS256", + "typ": "JWT", + "kid": jwtkey.as_dict()["kid"] + }, + payload={ + "iss": str(oauth2_clientid()), + "sub": request.args["user_id"], + "aud": urljoin(authserver_uri(),"auth/token"), + "exp": (issued + timedelta(minutes=5)).timestamp(), + "nbf": int(issued.timestamp()), + "iat": int(issued.timestamp()), + "jti": str(uuid.uuid4()) + }, + key=jwtkey).decode("utf8"), + "client_id": oauth2_clientid() + }).either(__process_error__, __success__) + +@oauth2.route("/public-jwks") +def public_jwks(): + """List the available JWKs""" + return jsonify({ + "documentation": ( + "The keys are listed in order of creation, from the oldest (first) " + "to the newest (last)."), + "jwks": tuple(key.as_dict() for key + in jwks.list_jwks(jwks.jwks_directory( + app, "UPLOADER_SECRETS"))) + }) + + +@oauth2.route("/logout", methods=["GET"]) +def logout(): + """Log out of any active sessions.""" + def __unset_session__(session_info): + _user = session_info["user"] + _user_str = f"{_user['name']} ({_user['email']})" + session.clear_session_info() + flash("Successfully logged out.", "alert-success") + return redirect("/") + + if user_logged_in(): + return session.user_token().then( + lambda _tok: mrequests.post( + urljoin(authserver_uri(), "auth/revoke"), + json={ + "token": _tok["refresh_token"], + "token_type_hint": "refresh_token", + "client_id": oauth2_clientid(), + "client_secret": oauth2_clientsecret() + })).either( + make_error_handler( + redirect_to=redirect("/"), + cleanup_thunk=lambda: __unset_session__( + session.session_info())), + lambda res: __unset_session__(session.session_info())) + flash("There is no user that is currently logged in.", "alert-info") + return redirect("/") |