about summary refs log tree commit diff
path: root/gn_auth/auth/authentication
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-07-18 16:43:17 -0500
committerFrederick Muriuki Muriithi2024-07-31 09:30:20 -0500
commit945c70b238ec3cc31613b3c17d5ad57c5a2eedee (patch)
treedbcf10f72959259f55e1f62534635e9c1c3dab2f /gn_auth/auth/authentication
parent284b5baaffdd26599224d9c69ecd8f202b7277cb (diff)
downloadgn-auth-945c70b238ec3cc31613b3c17d5ad57c5a2eedee.tar.gz
Retrieve newest JWK, creating a new JWK where necessary.
To help with key rotation, we fetch the latest key, creating a new JWK
in any of the following 2 conditions:
* There is no JWK in the first place
* The "newest" key is older than a specified number of days
Diffstat (limited to 'gn_auth/auth/authentication')
-rw-r--r--gn_auth/auth/authentication/oauth2/server.py43
1 files changed, 33 insertions, 10 deletions
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py
index d845c60..d1aa69e 100644
--- a/gn_auth/auth/authentication/oauth2/server.py
+++ b/gn_auth/auth/authentication/oauth2/server.py
@@ -1,15 +1,20 @@
 """Initialise the OAuth2 Server"""
+import os
 import uuid
-import datetime
+from pathlib import Path
 from typing import Callable
+from datetime import datetime, timedelta
 
+from pymonad.either import Left
 from flask import Flask, current_app
-from authlib.jose import jwk, jwt
 from authlib.oauth2.rfc7523 import JWTBearerTokenValidator
+from authlib.jose import jwk, jwt, JsonWebKey
 from authlib.oauth2.rfc6749.errors import InvalidClientError
 from authlib.integrations.flask_oauth2 import AuthorizationServer
 
 from gn_auth.auth.db import sqlite3 as db
+from gn_auth.auth.jwks import (
+    newest_jwk, jwks_directory, generate_and_save_private_key)
 
 from .models.oauth2client import client as fetch_client
 from .models.oauth2token import OAuth2Token, save_token
@@ -86,10 +91,25 @@ def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable:
 
     return __save_token__
 
+def newest_jwk_with_rotation(jwksdir: Path, keyage: int) -> JsonWebKey:
+    """
+    Retrieve the latests JWK, creating a new one if older than `keyage` days.
+    """
+    def newer_than_days(jwkey):
+        filestat = os.stat(Path(
+            jwksdir, f"{jwkey.as_dict()['kid']}.private.pem"))
+        oldesttimeallowed = (datetime.now() - timedelta(days=keyage))
+        if filestat.st_ctime < (oldesttimeallowed.timestamp()):
+            return Left("JWK is too old!")
+        return jwkey
+
+    return newest_jwk(jwksdir).then(newer_than_days).either(
+        lambda _errmsg: generate_and_save_private_key(jwksdir),
+        lambda key: key)
+
 
 def make_jwt_token_generator(app):
     """Make token generator function."""
-    _gen = JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"])
     def __generator__(# pylint: disable=[too-many-arguments]
             grant_type,
             client,
@@ -98,13 +118,16 @@ def make_jwt_token_generator(app):
             expires_in=None,# pylint: disable=[unused-argument]
             include_refresh_token=True
     ):
-        return _gen.__call__(
-            grant_type,
-            client,
-            user,
-            scope,
-            JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
-            include_refresh_token)
+        return JWTBearerTokenGenerator(
+            newest_jwk_with_rotation(
+                jwks_directory(app),
+                int(app.config["JWKS_ROTATION_AGE_DAYS"]))).__call__(
+                        grant_type,
+                        client,
+                        user,
+                        scope,
+                        JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
+                        include_refresh_token)
     return __generator__