diff options
Diffstat (limited to 'gn2/wqflask/oauth2')
-rw-r--r-- | gn2/wqflask/oauth2/client.py | 121 | ||||
-rw-r--r-- | gn2/wqflask/oauth2/jwks.py | 86 | ||||
-rw-r--r-- | gn2/wqflask/oauth2/resources.py | 6 | ||||
-rw-r--r-- | gn2/wqflask/oauth2/session.py | 18 | ||||
-rw-r--r-- | gn2/wqflask/oauth2/toplevel.py | 15 |
5 files changed, 232 insertions, 14 deletions
diff --git a/gn2/wqflask/oauth2/client.py b/gn2/wqflask/oauth2/client.py index 770777b5..a7d20f6b 100644 --- a/gn2/wqflask/oauth2/client.py +++ b/gn2/wqflask/oauth2/client.py @@ -1,12 +1,17 @@ """Common oauth2 client utilities.""" import json +import time +import random import requests from typing import Optional from urllib.parse import urljoin +from datetime import datetime, timedelta from flask import current_app as app from pymonad.either import Left, Right, Either -from authlib.jose import jwt +from authlib.common.urls import url_decode +from authlib.jose.errors import BadSignatureError +from authlib.jose import KeySet, JsonWebKey, JsonWebToken from authlib.integrations.requests_client import OAuth2Session from gn2.wqflask.oauth2 import session @@ -34,24 +39,130 @@ def user_logged_in(): return suser["logged_in"] and suser["token"].is_right() +def __make_token_validator__(keys: KeySet): + """Make a token validator function.""" + def __validator__(token: dict): + for key in keys.keys: + try: + # Fixes CVE-2016-10555. See + # https://docs.authlib.org/en/latest/jose/jwt.html + jwt = JsonWebToken(["RS256"]) + jwt.decode(token["access_token"], key) + return Right(token) + except BadSignatureError: + pass + + return Left("INVALID-TOKEN") + + return __validator__ + + +def auth_server_jwks() -> Optional[KeySet]: + """Fetch the auth-server JSON Web Keys information.""" + _jwks = session.session_info().get("auth_server_jwks") + if bool(_jwks): + return { + "last-updated": _jwks["last-updated"], + "jwks": KeySet([ + JsonWebKey.import_key(key) for key in _jwks.get( + "auth_server_jwks", {}).get( + "jwks", {"keys": []})["keys"]])} + + +def __validate_token__(keys): + """Validate that the token is really from the auth server.""" + def __index__(_sess): + return _sess + return session.user_token().then(__make_token_validator__(keys)).then( + session.set_user_token).either(__index__, __index__) + + +def __update_auth_server_jwks__(): + """Updates the JWKs every 2 hours or so.""" + jwks = auth_server_jwks() + if bool(jwks): + last_updated = jwks.get("last-updated") + now = datetime.now().timestamp() + if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds: + return __validate_token__(jwks["jwks"]) + + jwksuri = urljoin(authserver_uri(), "auth/public-jwks") + jwks = KeySet([ + JsonWebKey.import_key(key) + for key in requests.get(jwksuri).json()["jwks"]]) + return __validate_token__(jwks) + + +def is_token_expired(token): + """Check whether the token has expired.""" + __update_auth_server_jwks__() + jwks = auth_server_jwks() + if bool(jwks): + for jwk in jwks["jwks"].keys: + try: + jwt = JsonWebToken(["RS256"]).decode( + token["access_token"], key=jwk) + return datetime.now().timestamp() > jwt["exp"] + except BadSignatureError as _bse: + pass + + return False + + def oauth2_client(): def __update_token__(token, refresh_token=None, access_token=None): """Update the token when refreshed.""" session.set_user_token(token) + return token + + def __delay__(): + """Do a tiny delay.""" + time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100)))) + + def __refresh_token__(token): + """Synchronise token refresh.""" + if is_token_expired(token): + __delay__() + if session.is_token_refreshing(): + while session.is_token_refreshing(): + __delay__() + + _token = session.user_token().either(None, lambda _tok: _tok) + return _token + + session.toggle_token_refreshing() + _client = __client__(token) + _client.get(urljoin(authserver_uri(), "auth/user/")) + session.toggle_token_refreshing() + return _client.token + + return token + + def __json_auth__(client, method, uri, headers, body): + return ( + uri, + {**headers, "Content-Type": "application/json"}, + json.dumps({ + **dict(url_decode(body)), + "client_id": oauth2_clientid(), + "client_secret": oauth2_clientsecret() + })) def __client__(token) -> OAuth2Session: - _jwt = jwt.decode(token["access_token"], - app.config["AUTH_SERVER_SSL_PUBLIC_KEY"]) client = OAuth2Session( oauth2_clientid(), oauth2_clientsecret(), scope=SCOPE, - token_endpoint=urljoin(authserver_uri(), "/auth/token"), + token_endpoint=urljoin(authserver_uri(), "auth/token"), token_endpoint_auth_method="client_secret_post", token=token, update_token=__update_token__) + client.register_client_auth_method( + ("client_secret_post", __json_auth__)) return client - return session.user_token().either( + + __update_auth_server_jwks__() + return session.user_token().then(__refresh_token__).either( lambda _notok: __client__(None), lambda token: __client__(token)) diff --git a/gn2/wqflask/oauth2/jwks.py b/gn2/wqflask/oauth2/jwks.py new file mode 100644 index 00000000..efd04997 --- /dev/null +++ b/gn2/wqflask/oauth2/jwks.py @@ -0,0 +1,86 @@ +"""Utilities dealing with JSON Web Keys (JWK)""" +import os +from pathlib import Path +from typing import Any, Union +from datetime import datetime, timedelta + +from flask import Flask +from authlib.jose import JsonWebKey +from pymonad.either import Left, Right, Either + +def jwks_directory(app: Flask, configname: str) -> Path: + """Compute the directory where the JWKs are stored.""" + appsecretsdir = Path(app.config[configname]).parent + if appsecretsdir.exists() and appsecretsdir.is_dir(): + jwksdir = Path(appsecretsdir, "jwks/") + if not jwksdir.exists(): + jwksdir.mkdir() + return jwksdir + raise ValueError( + "The `appsecretsdir` value should be a directory that actually exists.") + + +def generate_and_save_private_key( + storagedir: Path, + kty: str = "RSA", + crv_or_size: Union[str, int] = 2048, + options: tuple[tuple[str, Any]] = (("iat", datetime.now().timestamp()),) +) -> JsonWebKey: + """Generate a private key and save to `storagedir`.""" + privatejwk = JsonWebKey.generate_key( + kty, crv_or_size, dict(options), is_private=True) + keyname = f"{privatejwk.thumbprint()}.private.pem" + with open(Path(storagedir, keyname), "wb") as pemfile: + pemfile.write(privatejwk.as_pem(is_private=True)) + + return privatejwk + + +def pem_to_jwk(filepath: Path) -> JsonWebKey: + """Parse a PEM file into a JWK object.""" + with open(filepath, "rb") as pemfile: + return JsonWebKey.import_key(pemfile.read()) + + +def __sorted_jwks_paths__(storagedir: Path) -> tuple[tuple[float, Path], ...]: + """A sorted list of the JWK file paths with their creation timestamps.""" + return tuple(sorted(((os.stat(keypath).st_ctime, keypath) + for keypath in (Path(storagedir, keyfile) + for keyfile in os.listdir(storagedir) + if keyfile.endswith(".pem"))), + key=lambda tpl: tpl[0])) + + +def list_jwks(storagedir: Path) -> tuple[JsonWebKey, ...]: + """ + List all the JWKs in a particular directory in the order they were created. + """ + return tuple(pem_to_jwk(keypath) for ctime,keypath in + __sorted_jwks_paths__(storagedir)) + + +def newest_jwk(storagedir: Path) -> Either: + """ + Return an Either monad with the newest JWK or a message if none exists. + """ + existingkeys = __sorted_jwks_paths__(storagedir) + if len(existingkeys) > 0: + return Right(pem_to_jwk(existingkeys[-1][1])) + return Left("No JWKs exist") + + +def newest_jwk_with_rotation(jwksdir: Path, keyage: int) -> JsonWebKey: + """ + Retrieve the latests JWK, creating a new one if older than `keyage` days. + """ + def newer_than_days(jwkey): + filestat = os.stat(Path( + jwksdir, f"{jwkey.as_dict()['kid']}.private.pem")) + oldesttimeallowed = (datetime.now() - timedelta(days=keyage)) + if filestat.st_ctime < (oldesttimeallowed.timestamp()): + return Left("JWK is too old!") + return jwkey + + return newest_jwk(jwksdir).then(newer_than_days).either( + lambda _errmsg: generate_and_save_private_key(jwksdir), + lambda key: key) diff --git a/gn2/wqflask/oauth2/resources.py b/gn2/wqflask/oauth2/resources.py index b8d804f9..7ea7fe38 100644 --- a/gn2/wqflask/oauth2/resources.py +++ b/gn2/wqflask/oauth2/resources.py @@ -129,8 +129,10 @@ def view_resource(resource_id: UUID): dataset_type = resource["resource_category"]["resource_category_key"] return oauth2_get(f"auth/group/{dataset_type}/unlinked-data").either( lambda err: render_ui( - "oauth2/view-resource.html", resource=resource, - unlinked_error=process_error(err)), + "oauth2/view-resource.html", + resource=resource, + unlinked_error=process_error(err), + count_per_page=count_per_page), lambda unlinked: __unlinked_success__(resource, unlinked)) def __fetch_resource_data__(resource): diff --git a/gn2/wqflask/oauth2/session.py b/gn2/wqflask/oauth2/session.py index eec48a7f..b91534b0 100644 --- a/gn2/wqflask/oauth2/session.py +++ b/gn2/wqflask/oauth2/session.py @@ -22,6 +22,8 @@ class SessionInfo(TypedDict): user_agent: str ip_addr: str masquerade: Optional[UserDetails] + refreshing_token: bool + auth_server_jwks: Optional[dict[str, Any]] __SESSION_KEY__ = "GN::2::session_info" # Do not use this outside this module!! @@ -61,7 +63,8 @@ def session_info() -> SessionInfo: "user_agent": request.headers.get("User-Agent"), "ip_addr": request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr), - "masquerading": None + "masquerading": None, + "token_refreshing": False })) @@ -102,3 +105,16 @@ def unset_masquerading(): "user": the_session["masquerading"], "masquerading": None }) + + +def toggle_token_refreshing(): + """Toggle the state of the token_refreshing variable.""" + _session = session_info() + return save_session_info({ + **_session, + "token_refreshing": not _session.get("token_refreshing", False)}) + + +def is_token_refreshing(): + """Returns whether the token is being refreshed or not.""" + return session_info().get("token_refreshing", False) diff --git a/gn2/wqflask/oauth2/toplevel.py b/gn2/wqflask/oauth2/toplevel.py index 210b0756..24d60311 100644 --- a/gn2/wqflask/oauth2/toplevel.py +++ b/gn2/wqflask/oauth2/toplevel.py @@ -13,6 +13,7 @@ from flask import (flash, render_template, current_app as app) +from . import jwks from . import session from .checks import require_oauth2 from .request_utils import user_details, process_error @@ -34,7 +35,9 @@ def authorisation_code(): code = request.args.get("code", "") if bool(code): base_url = urlparse(request.base_url, scheme=request.scheme) - jwtkey = app.config["SSL_PRIVATE_KEY"] + jwtkey = jwks.newest_jwk_with_rotation( + jwks.jwks_directory(app, "GN2_SECRETS"), + int(app.config["JWKS_ROTATION_AGE_DAYS"])) issued = datetime.datetime.now() request_data = { "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", @@ -80,9 +83,8 @@ def authorisation_code(): }) return redirect("/") - return no_token_post( - "auth/token", json=request_data).either( - lambda err: __error__(process_error(err)), __success__) + return no_token_post("auth/token", json=request_data).either( + lambda err: __error__(process_error(err)), __success__) flash("AuthorisationError: No code was provided.", "alert-danger") return redirect("/") @@ -91,6 +93,7 @@ def authorisation_code(): def public_jwks(): """Provide endpoint that returns the public keys.""" return jsonify({ - "documentation": "Returns a static key for the time being. This will change.", - "jwks": KeySet([app.config["SSL_PRIVATE_KEY"]]).as_dict().get("keys") + "documentation": "The keys are listed in order of creation.", + "jwks": KeySet(jwks.list_jwks( + jwks.jwks_directory(app, "GN2_SECRETS"))).as_dict().get("keys") }) |