about summary refs log tree commit diff
path: root/gn_auth/auth/authentication/oauth2
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-07-18 16:56:33 -0500
committerFrederick Muriuki Muriithi2024-07-31 09:30:21 -0500
commit7eb26c8e0a01b61a0e79d2acc8ba010660aaa010 (patch)
treec31454006d505a11c23467bd3d05a1846cf963fa /gn_auth/auth/authentication/oauth2
parent8a3a16f25f6d87b6cf679c888eacba816415baa9 (diff)
downloadgn-auth-7eb26c8e0a01b61a0e79d2acc8ba010660aaa010.tar.gz
Validate JWTs against all existing JWKs.
Diffstat (limited to 'gn_auth/auth/authentication/oauth2')
-rw-r--r--gn_auth/auth/authentication/oauth2/resource_server.py35
-rw-r--r--gn_auth/auth/authentication/oauth2/server.py10
2 files changed, 39 insertions, 6 deletions
diff --git a/gn_auth/auth/authentication/oauth2/resource_server.py b/gn_auth/auth/authentication/oauth2/resource_server.py
index 2405ee2..6ebaecb 100644
--- a/gn_auth/auth/authentication/oauth2/resource_server.py
+++ b/gn_auth/auth/authentication/oauth2/resource_server.py
@@ -1,9 +1,14 @@
 """Protect the resources endpoints"""
+from datetime import datetime, timezone, timedelta
 
 from flask import current_app as app
+
+from authlib.jose import KeySet
+from authlib.oauth2.rfc7523 import JWTBearerTokenValidator as _JWTBearerTokenValidator
 from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator
 from authlib.integrations.flask_oauth2 import ResourceProtector
 
+from gn_auth.auth.jwks import list_jwks, jwks_directory
 from gn_auth.auth.db import sqlite3 as db
 from gn_auth.auth.authentication.oauth2.models.oauth2token import token_by_access_token
 
@@ -14,4 +19,34 @@ class BearerTokenValidator(_BearerTokenValidator):
             return token_by_access_token(conn, token_string).maybe(# type: ignore[misc]
                 None, lambda tok: tok)
 
+class JWTBearerTokenValidator(_JWTBearerTokenValidator):
+    """Validate a token using all the keys"""
+    _local_attributes = ("jwt_refresh_frequency_hours",)
+
+    def __init__(self, public_key, issuer=None, realm=None, **extra_attributes):
+        """Initialise the validator class."""
+        # https://docs.authlib.org/en/latest/jose/jwt.html#use-dynamic-keys
+        # We can simply use the KeySet rather than a specific key.
+        super().__init__(public_key,
+                         issuer,
+                         realm,
+                         **{
+                             key: value for key,value
+                             in extra_attributes.items()
+                             if key not in self._local_attributes
+                         })
+        self._last_jwks_update = datetime.now(tz=timezone.utc)
+        self._refresh_frequency = timedelta(hours=int(
+            extra_attributes.get("jwt_refresh_frequency_hours", 6)))
+
+    def __refresh_jwks__(self):
+        now = datetime.now(tz=timezone.utc)
+        if (now - self._last_jwks_update) >= self._refresh_frequency:
+            self.public_key = KeySet(list_jwks(jwks_directory(app)))
+
+    def authenticate_token(self, token_string: str):
+        self.__refresh_jwks__()
+        return super().authenticate_token(token_string)
+
+
 require_oauth = ResourceProtector()
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py
index 5806da6..63cbf37 100644
--- a/gn_auth/auth/authentication/oauth2/server.py
+++ b/gn_auth/auth/authentication/oauth2/server.py
@@ -7,14 +7,13 @@ from datetime import datetime, timedelta
 
 from pymonad.either import Left
 from flask import Flask, current_app
-from authlib.oauth2.rfc7523 import JWTBearerTokenValidator
-from authlib.jose import jwk, jwt, JsonWebKey
+from authlib.jose import jwt, KeySet, JsonWebKey
 from authlib.oauth2.rfc6749.errors import InvalidClientError
 from authlib.integrations.flask_oauth2 import AuthorizationServer
 
 from gn_auth.auth.db import sqlite3 as db
 from gn_auth.auth.jwks import (
-    newest_jwk, jwks_directory, generate_and_save_private_key)
+    list_jwks, newest_jwk, jwks_directory, generate_and_save_private_key)
 
 from .models.oauth2client import client as fetch_client
 from .models.oauth2token import OAuth2Token, save_token
@@ -32,7 +31,7 @@ from .grants.jwt_bearer_grant import JWTBearerGrant, JWTBearerTokenGenerator
 from .endpoints.revocation import RevocationEndpoint
 from .endpoints.introspection import IntrospectionEndpoint
 
-from .resource_server import require_oauth, BearerTokenValidator
+from .resource_server import require_oauth, JWTBearerTokenValidator
 
 
 def create_query_client_func() -> Callable:
@@ -164,6 +163,5 @@ def setup_oauth2_server(app: Flask) -> None:
     app.config["OAUTH2_SERVER"] = server
 
     ## Set up the token validators
-    require_oauth.register_token_validator(BearerTokenValidator())
     require_oauth.register_token_validator(
-        JWTBearerTokenValidator(app.config["SSL_PRIVATE_KEY"].get_public_key()))
+        JWTBearerTokenValidator(KeySet(list_jwks(jwks_directory(app)))))