about summary refs log tree commit diff
path: root/gn2/wqflask/oauth2
diff options
context:
space:
mode:
Diffstat (limited to 'gn2/wqflask/oauth2')
-rw-r--r--gn2/wqflask/oauth2/client.py121
-rw-r--r--gn2/wqflask/oauth2/jwks.py86
-rw-r--r--gn2/wqflask/oauth2/resources.py6
-rw-r--r--gn2/wqflask/oauth2/session.py18
-rw-r--r--gn2/wqflask/oauth2/toplevel.py15
5 files changed, 232 insertions, 14 deletions
diff --git a/gn2/wqflask/oauth2/client.py b/gn2/wqflask/oauth2/client.py
index 770777b5..a7d20f6b 100644
--- a/gn2/wqflask/oauth2/client.py
+++ b/gn2/wqflask/oauth2/client.py
@@ -1,12 +1,17 @@
 """Common oauth2 client utilities."""
 import json
+import time
+import random
 import requests
 from typing import Optional
 from urllib.parse import urljoin
+from datetime import datetime, timedelta
 
 from flask import current_app as app
 from pymonad.either import Left, Right, Either
-from authlib.jose import jwt
+from authlib.common.urls import url_decode
+from authlib.jose.errors import BadSignatureError
+from authlib.jose import KeySet, JsonWebKey, JsonWebToken
 from authlib.integrations.requests_client import OAuth2Session
 
 from gn2.wqflask.oauth2 import session
@@ -34,24 +39,130 @@ def user_logged_in():
     return suser["logged_in"] and suser["token"].is_right()
 
 
+def __make_token_validator__(keys: KeySet):
+    """Make a token validator function."""
+    def __validator__(token: dict):
+        for key in keys.keys:
+            try:
+                # Fixes CVE-2016-10555. See
+                # https://docs.authlib.org/en/latest/jose/jwt.html
+                jwt = JsonWebToken(["RS256"])
+                jwt.decode(token["access_token"], key)
+                return Right(token)
+            except BadSignatureError:
+                pass
+
+        return Left("INVALID-TOKEN")
+
+    return __validator__
+
+
+def auth_server_jwks() -> Optional[KeySet]:
+    """Fetch the auth-server JSON Web Keys information."""
+    _jwks = session.session_info().get("auth_server_jwks")
+    if bool(_jwks):
+        return {
+            "last-updated": _jwks["last-updated"],
+            "jwks": KeySet([
+                JsonWebKey.import_key(key) for key in _jwks.get(
+                    "auth_server_jwks", {}).get(
+                        "jwks", {"keys": []})["keys"]])}
+
+
+def __validate_token__(keys):
+    """Validate that the token is really from the auth server."""
+    def __index__(_sess):
+        return _sess
+    return session.user_token().then(__make_token_validator__(keys)).then(
+        session.set_user_token).either(__index__, __index__)
+
+
+def __update_auth_server_jwks__():
+    """Updates the JWKs every 2 hours or so."""
+    jwks = auth_server_jwks()
+    if bool(jwks):
+        last_updated = jwks.get("last-updated")
+        now = datetime.now().timestamp()
+        if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds:
+            return __validate_token__(jwks["jwks"])
+
+    jwksuri = urljoin(authserver_uri(), "auth/public-jwks")
+    jwks = KeySet([
+        JsonWebKey.import_key(key)
+        for key in requests.get(jwksuri).json()["jwks"]])
+    return __validate_token__(jwks)
+
+
+def is_token_expired(token):
+    """Check whether the token has expired."""
+    __update_auth_server_jwks__()
+    jwks = auth_server_jwks()
+    if bool(jwks):
+        for jwk in jwks["jwks"].keys:
+            try:
+                jwt = JsonWebToken(["RS256"]).decode(
+                    token["access_token"], key=jwk)
+                return datetime.now().timestamp() > jwt["exp"]
+            except BadSignatureError as _bse:
+                pass
+
+    return False
+
+
 def oauth2_client():
     def __update_token__(token, refresh_token=None, access_token=None):
         """Update the token when refreshed."""
         session.set_user_token(token)
+        return token
+
+    def __delay__():
+        """Do a tiny delay."""
+        time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100))))
+
+    def __refresh_token__(token):
+        """Synchronise token refresh."""
+        if is_token_expired(token):
+            __delay__()
+            if session.is_token_refreshing():
+                while session.is_token_refreshing():
+                    __delay__()
+
+                _token = session.user_token().either(None, lambda _tok: _tok)
+                return _token
+
+            session.toggle_token_refreshing()
+            _client = __client__(token)
+            _client.get(urljoin(authserver_uri(), "auth/user/"))
+            session.toggle_token_refreshing()
+            return _client.token
+
+        return token
+
+    def __json_auth__(client, method, uri, headers, body):
+        return (
+            uri,
+            {**headers, "Content-Type": "application/json"},
+            json.dumps({
+                **dict(url_decode(body)),
+                "client_id": oauth2_clientid(),
+                "client_secret": oauth2_clientsecret()
+            }))
 
     def __client__(token) -> OAuth2Session:
-        _jwt = jwt.decode(token["access_token"],
-                          app.config["AUTH_SERVER_SSL_PUBLIC_KEY"])
         client = OAuth2Session(
             oauth2_clientid(),
             oauth2_clientsecret(),
             scope=SCOPE,
-            token_endpoint=urljoin(authserver_uri(), "/auth/token"),
+            token_endpoint=urljoin(authserver_uri(), "auth/token"),
             token_endpoint_auth_method="client_secret_post",
             token=token,
             update_token=__update_token__)
+        client.register_client_auth_method(
+            ("client_secret_post", __json_auth__))
         return client
-    return session.user_token().either(
+
+    __update_auth_server_jwks__()
+    return session.user_token().then(__refresh_token__).either(
         lambda _notok: __client__(None),
         lambda token: __client__(token))
 
diff --git a/gn2/wqflask/oauth2/jwks.py b/gn2/wqflask/oauth2/jwks.py
new file mode 100644
index 00000000..efd04997
--- /dev/null
+++ b/gn2/wqflask/oauth2/jwks.py
@@ -0,0 +1,86 @@
+"""Utilities dealing with JSON Web Keys (JWK)"""
+import os
+from pathlib import Path
+from typing import Any, Union
+from datetime import datetime, timedelta
+
+from flask import Flask
+from authlib.jose import JsonWebKey
+from pymonad.either import Left, Right, Either
+
+def jwks_directory(app: Flask, configname: str) -> Path:
+    """Compute the directory where the JWKs are stored."""
+    appsecretsdir = Path(app.config[configname]).parent
+    if appsecretsdir.exists() and appsecretsdir.is_dir():
+        jwksdir = Path(appsecretsdir, "jwks/")
+        if not jwksdir.exists():
+            jwksdir.mkdir()
+        return jwksdir
+    raise ValueError(
+        "The `appsecretsdir` value should be a directory that actually exists.")
+
+
+def generate_and_save_private_key(
+        storagedir: Path,
+        kty: str = "RSA",
+        crv_or_size: Union[str, int] = 2048,
+        options: tuple[tuple[str, Any]] = (("iat", datetime.now().timestamp()),)
+) -> JsonWebKey:
+    """Generate a private key and save to `storagedir`."""
+    privatejwk = JsonWebKey.generate_key(
+        kty, crv_or_size, dict(options), is_private=True)
+    keyname = f"{privatejwk.thumbprint()}.private.pem"
+    with open(Path(storagedir, keyname), "wb") as pemfile:
+        pemfile.write(privatejwk.as_pem(is_private=True))
+
+    return privatejwk
+
+
+def pem_to_jwk(filepath: Path) -> JsonWebKey:
+    """Parse a PEM file into a JWK object."""
+    with open(filepath, "rb") as pemfile:
+        return JsonWebKey.import_key(pemfile.read())
+
+
+def __sorted_jwks_paths__(storagedir: Path) -> tuple[tuple[float, Path], ...]:
+    """A sorted list of the JWK file paths with their creation timestamps."""
+    return tuple(sorted(((os.stat(keypath).st_ctime, keypath)
+                         for keypath in (Path(storagedir, keyfile)
+                                         for keyfile in os.listdir(storagedir)
+                                         if keyfile.endswith(".pem"))),
+                        key=lambda tpl: tpl[0]))
+
+
+def list_jwks(storagedir: Path) -> tuple[JsonWebKey, ...]:
+    """
+    List all the JWKs in a particular directory in the order they were created.
+    """
+    return tuple(pem_to_jwk(keypath) for ctime,keypath in
+                 __sorted_jwks_paths__(storagedir))
+
+
+def newest_jwk(storagedir: Path) -> Either:
+    """
+    Return an Either monad with the newest JWK or a message if none exists.
+    """
+    existingkeys = __sorted_jwks_paths__(storagedir)
+    if len(existingkeys) > 0:
+        return Right(pem_to_jwk(existingkeys[-1][1]))
+    return Left("No JWKs exist")
+
+
+def newest_jwk_with_rotation(jwksdir: Path, keyage: int) -> JsonWebKey:
+    """
+    Retrieve the latests JWK, creating a new one if older than `keyage` days.
+    """
+    def newer_than_days(jwkey):
+        filestat = os.stat(Path(
+            jwksdir, f"{jwkey.as_dict()['kid']}.private.pem"))
+        oldesttimeallowed = (datetime.now() - timedelta(days=keyage))
+        if filestat.st_ctime < (oldesttimeallowed.timestamp()):
+            return Left("JWK is too old!")
+        return jwkey
+
+    return newest_jwk(jwksdir).then(newer_than_days).either(
+        lambda _errmsg: generate_and_save_private_key(jwksdir),
+        lambda key: key)
diff --git a/gn2/wqflask/oauth2/resources.py b/gn2/wqflask/oauth2/resources.py
index b8d804f9..7ea7fe38 100644
--- a/gn2/wqflask/oauth2/resources.py
+++ b/gn2/wqflask/oauth2/resources.py
@@ -129,8 +129,10 @@ def view_resource(resource_id: UUID):
         dataset_type = resource["resource_category"]["resource_category_key"]
         return oauth2_get(f"auth/group/{dataset_type}/unlinked-data").either(
             lambda err: render_ui(
-                "oauth2/view-resource.html", resource=resource,
-                unlinked_error=process_error(err)),
+                "oauth2/view-resource.html",
+                resource=resource,
+                unlinked_error=process_error(err),
+                count_per_page=count_per_page),
             lambda unlinked: __unlinked_success__(resource, unlinked))
 
     def __fetch_resource_data__(resource):
diff --git a/gn2/wqflask/oauth2/session.py b/gn2/wqflask/oauth2/session.py
index eec48a7f..b91534b0 100644
--- a/gn2/wqflask/oauth2/session.py
+++ b/gn2/wqflask/oauth2/session.py
@@ -22,6 +22,8 @@ class SessionInfo(TypedDict):
     user_agent: str
     ip_addr: str
     masquerade: Optional[UserDetails]
+    refreshing_token: bool
+    auth_server_jwks: Optional[dict[str, Any]]
 
 __SESSION_KEY__ = "GN::2::session_info" # Do not use this outside this module!!
 
@@ -61,7 +63,8 @@ def session_info() -> SessionInfo:
             "user_agent": request.headers.get("User-Agent"),
             "ip_addr": request.environ.get("HTTP_X_FORWARDED_FOR",
                                            request.remote_addr),
-            "masquerading": None
+            "masquerading": None,
+            "token_refreshing": False
         }))
 
 
@@ -102,3 +105,16 @@ def unset_masquerading():
         "user": the_session["masquerading"],
         "masquerading": None
     })
+
+
+def toggle_token_refreshing():
+    """Toggle the state of the token_refreshing variable."""
+    _session = session_info()
+    return save_session_info({
+        **_session,
+        "token_refreshing": not _session.get("token_refreshing", False)})
+
+
+def is_token_refreshing():
+    """Returns whether the token is being refreshed or not."""
+    return session_info().get("token_refreshing", False)
diff --git a/gn2/wqflask/oauth2/toplevel.py b/gn2/wqflask/oauth2/toplevel.py
index 210b0756..24d60311 100644
--- a/gn2/wqflask/oauth2/toplevel.py
+++ b/gn2/wqflask/oauth2/toplevel.py
@@ -13,6 +13,7 @@ from flask import (flash,
                    render_template,
                    current_app as app)
 
+from . import jwks
 from . import session
 from .checks import require_oauth2
 from .request_utils import user_details, process_error
@@ -34,7 +35,9 @@ def authorisation_code():
     code = request.args.get("code", "")
     if bool(code):
         base_url = urlparse(request.base_url, scheme=request.scheme)
-        jwtkey = app.config["SSL_PRIVATE_KEY"]
+        jwtkey = jwks.newest_jwk_with_rotation(
+            jwks.jwks_directory(app, "GN2_SECRETS"),
+            int(app.config["JWKS_ROTATION_AGE_DAYS"]))
         issued = datetime.datetime.now()
         request_data = {
             "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
@@ -80,9 +83,8 @@ def authorisation_code():
             })
             return redirect("/")
 
-        return no_token_post(
-            "auth/token", json=request_data).either(
-                lambda err: __error__(process_error(err)), __success__)
+        return no_token_post("auth/token", json=request_data).either(
+            lambda err: __error__(process_error(err)), __success__)
     flash("AuthorisationError: No code was provided.", "alert-danger")
     return redirect("/")
 
@@ -91,6 +93,7 @@ def authorisation_code():
 def public_jwks():
     """Provide endpoint that returns the public keys."""
     return jsonify({
-        "documentation": "Returns a static key for the time being. This will change.",
-        "jwks": KeySet([app.config["SSL_PRIVATE_KEY"]]).as_dict().get("keys")
+        "documentation": "The keys are listed in order of creation.",
+        "jwks": KeySet(jwks.list_jwks(
+            jwks.jwks_directory(app, "GN2_SECRETS"))).as_dict().get("keys")
     })