From 7eb26c8e0a01b61a0e79d2acc8ba010660aaa010 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Thu, 18 Jul 2024 16:56:33 -0500 Subject: Validate JWTs against all existing JWKs. --- .../auth/authentication/oauth2/resource_server.py | 35 ++++++++++++++++++++++ gn_auth/auth/authentication/oauth2/server.py | 10 +++---- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/gn_auth/auth/authentication/oauth2/resource_server.py b/gn_auth/auth/authentication/oauth2/resource_server.py index 2405ee2..6ebaecb 100644 --- a/gn_auth/auth/authentication/oauth2/resource_server.py +++ b/gn_auth/auth/authentication/oauth2/resource_server.py @@ -1,9 +1,14 @@ """Protect the resources endpoints""" +from datetime import datetime, timezone, timedelta from flask import current_app as app + +from authlib.jose import KeySet +from authlib.oauth2.rfc7523 import JWTBearerTokenValidator as _JWTBearerTokenValidator from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator from authlib.integrations.flask_oauth2 import ResourceProtector +from gn_auth.auth.jwks import list_jwks, jwks_directory from gn_auth.auth.db import sqlite3 as db from gn_auth.auth.authentication.oauth2.models.oauth2token import token_by_access_token @@ -14,4 +19,34 @@ class BearerTokenValidator(_BearerTokenValidator): return token_by_access_token(conn, token_string).maybe(# type: ignore[misc] None, lambda tok: tok) +class JWTBearerTokenValidator(_JWTBearerTokenValidator): + """Validate a token using all the keys""" + _local_attributes = ("jwt_refresh_frequency_hours",) + + def __init__(self, public_key, issuer=None, realm=None, **extra_attributes): + """Initialise the validator class.""" + # https://docs.authlib.org/en/latest/jose/jwt.html#use-dynamic-keys + # We can simply use the KeySet rather than a specific key. + super().__init__(public_key, + issuer, + realm, + **{ + key: value for key,value + in extra_attributes.items() + if key not in self._local_attributes + }) + self._last_jwks_update = datetime.now(tz=timezone.utc) + self._refresh_frequency = timedelta(hours=int( + extra_attributes.get("jwt_refresh_frequency_hours", 6))) + + def __refresh_jwks__(self): + now = datetime.now(tz=timezone.utc) + if (now - self._last_jwks_update) >= self._refresh_frequency: + self.public_key = KeySet(list_jwks(jwks_directory(app))) + + def authenticate_token(self, token_string: str): + self.__refresh_jwks__() + return super().authenticate_token(token_string) + + require_oauth = ResourceProtector() diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py index 5806da6..63cbf37 100644 --- a/gn_auth/auth/authentication/oauth2/server.py +++ b/gn_auth/auth/authentication/oauth2/server.py @@ -7,14 +7,13 @@ from datetime import datetime, timedelta from pymonad.either import Left from flask import Flask, current_app -from authlib.oauth2.rfc7523 import JWTBearerTokenValidator -from authlib.jose import jwk, jwt, JsonWebKey +from authlib.jose import jwt, KeySet, JsonWebKey from authlib.oauth2.rfc6749.errors import InvalidClientError from authlib.integrations.flask_oauth2 import AuthorizationServer from gn_auth.auth.db import sqlite3 as db from gn_auth.auth.jwks import ( - newest_jwk, jwks_directory, generate_and_save_private_key) + list_jwks, newest_jwk, jwks_directory, generate_and_save_private_key) from .models.oauth2client import client as fetch_client from .models.oauth2token import OAuth2Token, save_token @@ -32,7 +31,7 @@ from .grants.jwt_bearer_grant import JWTBearerGrant, JWTBearerTokenGenerator from .endpoints.revocation import RevocationEndpoint from .endpoints.introspection import IntrospectionEndpoint -from .resource_server import require_oauth, BearerTokenValidator +from .resource_server import require_oauth, JWTBearerTokenValidator def create_query_client_func() -> Callable: @@ -164,6 +163,5 @@ def setup_oauth2_server(app: Flask) -> None: app.config["OAUTH2_SERVER"] = server ## Set up the token validators - require_oauth.register_token_validator(BearerTokenValidator()) require_oauth.register_token_validator( - JWTBearerTokenValidator(app.config["SSL_PRIVATE_KEY"].get_public_key())) + JWTBearerTokenValidator(KeySet(list_jwks(jwks_directory(app))))) -- cgit v1.2.3