about summary refs log tree commit diff
path: root/gn_auth/auth/authentication/oauth2
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth/auth/authentication/oauth2')
-rw-r--r--gn_auth/auth/authentication/oauth2/endpoints/introspection.py1
-rw-r--r--gn_auth/auth/authentication/oauth2/endpoints/revocation.py1
-rw-r--r--gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py65
-rw-r--r--gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py10
-rw-r--r--gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py50
-rw-r--r--gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py12
-rw-r--r--gn_auth/auth/authentication/oauth2/models/oauth2client.py69
-rw-r--r--gn_auth/auth/authentication/oauth2/resource_server.py59
-rw-r--r--gn_auth/auth/authentication/oauth2/server.py118
-rw-r--r--gn_auth/auth/authentication/oauth2/views.py29
10 files changed, 321 insertions, 93 deletions
diff --git a/gn_auth/auth/authentication/oauth2/endpoints/introspection.py b/gn_auth/auth/authentication/oauth2/endpoints/introspection.py
index 572324e..200b25d 100644
--- a/gn_auth/auth/authentication/oauth2/endpoints/introspection.py
+++ b/gn_auth/auth/authentication/oauth2/endpoints/introspection.py
@@ -20,6 +20,7 @@ def get_token_user_sub(token: OAuth2Token) -> str:# pylint: disable=[unused-argu
 
 class IntrospectionEndpoint(_IntrospectionEndpoint):
     """Introspect token."""
+    CLIENT_AUTH_METHODS = ['client_secret_post']
     def query_token(self, token_string: str, token_type_hint: str):
         """Query the token."""
         return _query_token(self, token_string, token_type_hint)
diff --git a/gn_auth/auth/authentication/oauth2/endpoints/revocation.py b/gn_auth/auth/authentication/oauth2/endpoints/revocation.py
index 240ca30..80922f1 100644
--- a/gn_auth/auth/authentication/oauth2/endpoints/revocation.py
+++ b/gn_auth/auth/authentication/oauth2/endpoints/revocation.py
@@ -12,6 +12,7 @@ from .utilities import query_token as _query_token
 class RevocationEndpoint(_RevocationEndpoint):
     """Revoke the tokens"""
     ENDPOINT_NAME = "revoke"
+    CLIENT_AUTH_METHODS = ['client_secret_post']
     def query_token(self, token_string: str, token_type_hint: str):
         """Query the token."""
         return _query_token(self, token_string, token_type_hint)
diff --git a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
index b0f2cc7..c802091 100644
--- a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
+++ b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
@@ -1,15 +1,21 @@
 """JWT as Authorisation Grant"""
 import uuid
+import time
 
+from typing import Optional
 from flask import current_app as app
 
+from authlib.jose import jwt
+from authlib.common.encoding import to_native
 from authlib.common.security import generate_token
 from authlib.oauth2.rfc7523.jwt_bearer import JWTBearerGrant as _JWTBearerGrant
 from authlib.oauth2.rfc7523.token import (
     JWTBearerTokenGenerator as _JWTBearerTokenGenerator)
 
+from gn_auth.debug import __pk__
 from gn_auth.auth.db.sqlite3 import with_db_connection
-from gn_auth.auth.authentication.users import user_by_id
+from gn_auth.auth.authentication.users import User, user_by_id
+from gn_auth.auth.authentication.oauth2.models.oauth2client import OAuth2Client
 
 
 class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
@@ -19,23 +25,66 @@ class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
 
     DEFAULT_EXPIRES_IN = 300
 
-    def get_token_data(#pylint: disable=[too-many-arguments]
+    def get_token_data(#pylint: disable=[too-many-arguments, too-many-positional-arguments]
             self, grant_type, client, expires_in=None, user=None, scope=None
     ):
         """Post process data to prevent JSON serialization problems."""
-        tokendata = super().get_token_data(
-            grant_type, client, expires_in, user, scope)
+        issued_at = int(time.time())
+        tokendata = {
+            "scope": self.get_allowed_scope(client, scope),
+            "grant_type": grant_type,
+            "iat": issued_at,
+            "client_id": client.get_client_id()
+        }
+        if isinstance(expires_in, int) and expires_in > 0:
+            tokendata["exp"] = issued_at + expires_in
+        if self.issuer:
+            tokendata["iss"] = self.issuer
+        if user:
+            tokendata["sub"] = self.get_sub_value(user)
+
         return {
             **{
                 key: str(value) if key.endswith("_id") else value
                 for key, value in tokendata.items()
             },
             "sub": str(tokendata["sub"]),
-            "jti": str(uuid.uuid4())
+            "jti": str(uuid.uuid4()),
+            "oauth2_client_id": str(client.client_id)
         }
 
+    def generate(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
+            self,
+            grant_type: str,
+            client: OAuth2Client,
+            user: Optional[User] = None,
+            scope: Optional[str] = None,
+            expires_in: Optional[int] = None
+    ) -> dict:
+        """Generate a bearer token for OAuth 2.0 authorization token endpoint.
+
+        :param client: the client that making the request.
+        :param grant_type: current requested grant_type.
+        :param user: current authorized user.
+        :param expires_in: if provided, use this value as expires_in.
+        :param scope: current requested scope.
+        :return: Token dict
+        """
+
+        token_data = self.get_token_data(grant_type, client, expires_in, user, scope)
+        access_token = jwt.encode({"alg": self.alg}, token_data, key=self.secret_key, check=False)
+        token = {
+            "token_type": "Bearer",
+            "access_token": to_native(access_token)
+        }
+        if expires_in:
+            token["expires_in"] = expires_in
+        if scope:
+            token["scope"] = scope
+        return token
+
 
-    def __call__(# pylint: disable=[too-many-arguments]
+    def __call__(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
             self, grant_type, client, user=None, scope=None, expires_in=None,
             include_refresh_token=True
     ):
@@ -74,7 +123,9 @@ class JWTBearerGrant(_JWTBearerGrant):
 
     def resolve_client_key(self, client, headers, payload):
         """Resolve client key to decode assertion data."""
-        return app.config["SSL_PUBLIC_KEYS"].get(headers["kid"])
+        keyset = client.jwks()
+        __pk__("THE KEYSET =======>", keyset.keys)
+        return keyset.find_by_kid(headers["kid"])
 
 
     def authenticate_user(self, subject):
diff --git a/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py b/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py
index fd6804d..f897d89 100644
--- a/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py
+++ b/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py
@@ -34,18 +34,18 @@ class RefreshTokenGrant(grants.RefreshTokenGrant):
                     else Nothing)
             ).maybe(None, lambda _tok: _tok)
 
-    def authenticate_user(self, credential):
+    def authenticate_user(self, refresh_token):
         """Check that user is valid for given token."""
         with connection(app.config["AUTH_DB"]) as conn:
             try:
-                return user_by_id(conn, credential.user.user_id)
+                return user_by_id(conn, refresh_token.user.user_id)
             except NotFoundError as _nfe:
                 return None
 
         return None
 
-    def revoke_old_credential(self, credential):
+    def revoke_old_credential(self, refresh_token):
         """Revoke any old refresh token after issuing new refresh token."""
         with connection(app.config["AUTH_DB"]) as conn:
-            if credential.parent_of is not None:
-                revoke_refresh_token(conn, credential)
+            if refresh_token.parent_of is not None:
+                revoke_refresh_token(conn, refresh_token)
diff --git a/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py
new file mode 100644
index 0000000..71769e1
--- /dev/null
+++ b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py
@@ -0,0 +1,50 @@
+"""Implement model for JWTBearerToken"""
+import uuid
+import time
+from typing import Optional
+
+from authlib.oauth2.rfc7523 import JWTBearerToken as _JWTBearerToken
+
+from gn_auth.auth.db.sqlite3 import with_db_connection
+from gn_auth.auth.authentication.users import user_by_id
+from gn_auth.auth.authentication.oauth2.models.oauth2client import (
+    client as fetch_client)
+
+class JWTBearerToken(_JWTBearerToken):
+    """Overrides default JWTBearerToken class."""
+
+    def __init__(self, payload, header, options=None, params=None):
+        """Initialise the bearer token."""
+        # TOD0: Maybe remove this init and make this a dataclass like the way
+        #       OAuth2Client is a dataclass
+        super().__init__(payload, header, options, params)
+        self.user = with_db_connection(
+            lambda conn:user_by_id(conn, uuid.UUID(payload["sub"])))
+        self.client = with_db_connection(
+            lambda conn: fetch_client(
+                conn, uuid.UUID(payload["oauth2_client_id"])
+            )
+        ).maybe(None, lambda _client: _client)
+
+
+    def check_client(self, client):
+        """Check that the client is right."""
+        return self.client.get_client_id() == client.get_client_id()
+
+
+    def get_expires_in(self) -> Optional[int]:
+        """Return the number of seconds the token is valid for since issue.
+
+        If `None`, the token never expires."""
+        if "exp" in self:
+            return self['exp'] - self['iat']
+        return None
+
+
+    def is_expired(self):
+        """Check whether the token is expired.
+
+        If there is no 'exp' member, assume this token will never expire."""
+        if "exp" in self:
+            return self["exp"] < time.time()
+        return False
diff --git a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
index 31c9147..46515c8 100644
--- a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
+++ b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
@@ -142,7 +142,7 @@ def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str):
                  "WHERE token=:parenttoken"),
                 {"parenttoken": parent.token, "childtoken": childtoken})
 
-    def __check_child__(parent):
+    def __check_child__(parent):#pylint: disable=[unused-variable]
         with db.cursor(conn) as cursor:
             cursor.execute(
                 ("SELECT * FROM jwt_refresh_tokens WHERE token=:parenttoken"),
@@ -154,15 +154,17 @@ def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str):
                     "activity detected.")
             return Right(parent)
 
-    def __revoke_and_raise_error__(_error_msg_):
+    def __revoke_and_raise_error__(_error_msg_):#pylint: disable=[unused-variable]
         load_refresh_token(conn, parenttoken).then(
             lambda _tok: revoke_refresh_token(conn, _tok))
         raise InvalidGrantError(_error_msg_)
 
+    def __handle_not_found__(_error_msg_):
+        raise InvalidGrantError(_error_msg_)
+
     load_refresh_token(conn, parenttoken).maybe(
-        Left("Token not found"), Right).then(
-            __check_child__).either(__revoke_and_raise_error__,
-                                    __link_to_child__)
+        Left("Token not found"), Right).either(
+            __handle_not_found__, __link_to_child__)
 
 
 def is_refresh_token_valid(token: JWTRefreshToken, client: OAuth2Client) -> bool:
diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
index d31faf6..1639e2e 100644
--- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py
+++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
@@ -1,17 +1,19 @@
 """OAuth2 Client model."""
 import json
 import datetime
-from pathlib import Path
-
 from uuid import UUID
-from dataclasses import dataclass
 from functools import cached_property
-from typing import Sequence, Optional
+from dataclasses import asdict, dataclass
+from typing import Any, Sequence, Optional
 
+import requests
+from flask import current_app as app
+from requests.exceptions import JSONDecodeError
 from authlib.jose import KeySet, JsonWebKey
 from authlib.oauth2.rfc6749 import ClientMixin
 from pymonad.maybe import Just, Maybe, Nothing
 
+from gn_auth.debug import __pk__
 from gn_auth.auth.db import sqlite3 as db
 from gn_auth.auth.errors import NotFoundError
 from gn_auth.auth.authentication.users import (User,
@@ -57,16 +59,34 @@ class OAuth2Client(ClientMixin):
         """
         return self.client_metadata.get("client_type", "public")
 
-    @cached_property
+
     def jwks(self) -> KeySet:
         """Return this client's KeySet."""
-        def __parse_key__(keypath: Path) -> JsonWebKey:
-            with open(keypath) as _key:# pylint: disable=[unspecified-encoding]
-                return JsonWebKey.import_key(_key.read())
+        jwksuri = self.client_metadata.get("public-jwks-uri")
+        __pk__(f"PUBLIC JWKs link for client {self.client_id}", jwksuri)
+        if not bool(jwksuri):
+            app.logger.debug("No Public JWKs URI set for client!")
+            return KeySet([])
+        try:
+            ## IMPORTANT: This can cause a deadlock if the client is working in
+            ##            single-threaded mode, i.e. can only serve one request
+            ##            at a time.
+            return KeySet([JsonWebKey.import_key(key)
+                           for key in requests.get(
+                                   jwksuri,
+                                   timeout=300,
+                                   allow_redirects=True).json()["jwks"]])
+        except requests.ConnectionError as _connerr:
+            app.logger.debug(
+                "Could not connect to provided URI: %s", jwksuri, exc_info=True)
+        except JSONDecodeError as _jsonerr:
+            app.logger.debug(
+                "Could not convert response to JSON", exc_info=True)
+        except Exception as _exc:# pylint: disable=[broad-except]
+            app.logger.debug(
+                "Error retrieving the JWKs for the client.", exc_info=True)
+        return KeySet([])
 
-        return KeySet([
-            __parse_key__(Path(pth))
-            for pth in self.client_metadata.get("public_keys", [])])
 
     def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool:
         """
@@ -77,12 +97,9 @@ class OAuth2Client(ClientMixin):
         * client_secret_post: Client uses the HTTP POST parameters
         * client_secret_basic: Client uses HTTP Basic
         """
-        if endpoint == "token":
+        if endpoint in ("token", "revoke", "introspection"):
             return (method in self.token_endpoint_auth_method
                     and method == "client_secret_post")
-        if endpoint in ("introspection", "revoke"):
-            return (method in self.token_endpoint_auth_method
-                    and method == "client_secret_basic")
         return False
 
     @cached_property
@@ -277,3 +294,25 @@ def delete_client(
         cursor.execute("DELETE FROM oauth2_tokens WHERE client_id=?", params)
         cursor.execute("DELETE FROM oauth2_clients WHERE client_id=?", params)
         return the_client
+
+
+def update_client_attribute(
+        client: OAuth2Client,# pylint: disable=[redefined-outer-name]
+        attribute: str,
+        value: Any
+) -> OAuth2Client:
+    """Return a new OAuth2Client with the given attribute updated/changed."""
+    attrs = {
+        attr: type(value)
+        for attr, value in asdict(client).items()
+        if attr != "client_id"
+    }
+    assert (
+        attribute in attrs.keys() and isinstance(value, attrs[attribute])), (
+            "Invalid attribute/value provided!")
+    return OAuth2Client(
+        client_id=client.client_id,
+        **{
+            attr: (value if attr==attribute else getattr(client, attr))
+            for attr in attrs
+        })
diff --git a/gn_auth/auth/authentication/oauth2/resource_server.py b/gn_auth/auth/authentication/oauth2/resource_server.py
index 2405ee2..8ecf923 100644
--- a/gn_auth/auth/authentication/oauth2/resource_server.py
+++ b/gn_auth/auth/authentication/oauth2/resource_server.py
@@ -1,11 +1,20 @@
 """Protect the resources endpoints"""
+from datetime import datetime, timezone, timedelta
 
 from flask import current_app as app
+
+from authlib.jose import jwt, KeySet, JoseError
 from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator
+from authlib.oauth2.rfc7523 import (
+    JWTBearerTokenValidator as _JWTBearerTokenValidator)
 from authlib.integrations.flask_oauth2 import ResourceProtector
 
 from gn_auth.auth.db import sqlite3 as db
-from gn_auth.auth.authentication.oauth2.models.oauth2token import token_by_access_token
+from gn_auth.auth.jwks import list_jwks, jwks_directory
+from gn_auth.auth.authentication.oauth2.models.jwt_bearer_token import (
+    JWTBearerToken)
+from gn_auth.auth.authentication.oauth2.models.oauth2token import (
+    token_by_access_token)
 
 class BearerTokenValidator(_BearerTokenValidator):
     """Extends `authlib.oauth2.rfc6750.BearerTokenValidator`"""
@@ -14,4 +23,52 @@ class BearerTokenValidator(_BearerTokenValidator):
             return token_by_access_token(conn, token_string).maybe(# type: ignore[misc]
                 None, lambda tok: tok)
 
+class JWTBearerTokenValidator(_JWTBearerTokenValidator):
+    """Validate a token using all the keys"""
+    token_cls = JWTBearerToken
+    _local_attributes = ("jwt_refresh_frequency_hours",)
+
+    def __init__(self, public_key, issuer=None, realm=None, **extra_attributes):
+        """Initialise the validator class."""
+        # https://docs.authlib.org/en/latest/jose/jwt.html#use-dynamic-keys
+        # We can simply use the KeySet rather than a specific key.
+        super().__init__(public_key,
+                         issuer,
+                         realm,
+                         **{
+                             key: value for key,value
+                             in extra_attributes.items()
+                             if key not in self._local_attributes
+                         })
+        self._last_jwks_update = datetime.now(tz=timezone.utc)
+        self._refresh_frequency = timedelta(hours=int(
+            extra_attributes.get("jwt_refresh_frequency_hours", 6)))
+        self.claims_options = {
+            'exp': {'essential': False},
+            'client_id': {'essential': True},
+            'grant_type': {'essential': True},
+        }
+
+    def __refresh_jwks__(self):
+        now = datetime.now(tz=timezone.utc)
+        if (now - self._last_jwks_update) >= self._refresh_frequency:
+            self.public_key = KeySet(list_jwks(jwks_directory(app)))
+
+    def authenticate_token(self, token_string: str):
+        self.__refresh_jwks__()
+        for key in self.public_key.keys:
+            try:
+                claims = jwt.decode(
+                    token_string, key,
+                    claims_options=self.claims_options,
+                    claims_cls=self.token_cls,
+                )
+                claims.validate()
+                return claims
+            except JoseError as error:
+                app.logger.debug('Authenticate token failed. %r', error)
+
+        return None
+
+
 require_oauth = ResourceProtector()
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py
index d845c60..8ac5106 100644
--- a/gn_auth/auth/authentication/oauth2/server.py
+++ b/gn_auth/auth/authentication/oauth2/server.py
@@ -1,23 +1,24 @@
 """Initialise the OAuth2 Server"""
 import uuid
-import datetime
 from typing import Callable
+from datetime import datetime
 
-from flask import Flask, current_app
-from authlib.jose import jwk, jwt
-from authlib.oauth2.rfc7523 import JWTBearerTokenValidator
+from flask import Flask, current_app, request as flask_request
+from authlib.jose import KeySet
+from authlib.oauth2.rfc6749 import OAuth2Request
 from authlib.oauth2.rfc6749.errors import InvalidClientError
 from authlib.integrations.flask_oauth2 import AuthorizationServer
+from authlib.integrations.flask_oauth2.requests import FlaskOAuth2Request
 
 from gn_auth.auth.db import sqlite3 as db
+from gn_auth.auth.jwks import (
+    list_jwks,
+    jwks_directory,
+    newest_jwk_with_rotation)
 
+from .models.jwt_bearer_token import JWTBearerToken
 from .models.oauth2client import client as fetch_client
 from .models.oauth2token import OAuth2Token, save_token
-from .models.jwtrefreshtoken import (
-    JWTRefreshToken,
-    link_child_token,
-    save_refresh_token,
-    load_refresh_token)
 
 from .grants.password_grant import PasswordGrant
 from .grants.refresh_token_grant import RefreshTokenGrant
@@ -27,7 +28,9 @@ from .grants.jwt_bearer_grant import JWTBearerGrant, JWTBearerTokenGenerator
 from .endpoints.revocation import RevocationEndpoint
 from .endpoints.introspection import IntrospectionEndpoint
 
-from .resource_server import require_oauth, BearerTokenValidator
+from .resource_server import require_oauth, JWTBearerTokenValidator
+
+_TWO_HOURS_ = 2 * 60 * 60
 
 
 def create_query_client_func() -> Callable:
@@ -45,52 +48,32 @@ def create_query_client_func() -> Callable:
 
     return __query_client__
 
-def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable:
+def create_save_token_func(token_model: type) -> Callable:
     """Create the function that saves the token."""
+    def __ignore_token__(token, request):# pylint: disable=[unused-argument]
+        """Ignore the token: i.e. Do not save it."""
+
     def __save_token__(token, request):
-        _jwt = jwt.decode(token["access_token"], jwtkey)
-        _token = token_model(
-            token_id=uuid.UUID(_jwt["jti"]),
-            client=request.client,
-            user=request.user,
-            **{
-                "refresh_token": None,
-                "revoked": False,
-                "issued_at": datetime.datetime.now(),
-                **token
-            })
         with db.connection(current_app.config["AUTH_DB"]) as conn:
-            save_token(conn, _token)
-            old_refresh_token = load_refresh_token(
+            save_token(
                 conn,
-                request.form.get("refresh_token", "nosuchtoken")
-            )
-            new_refresh_token = JWTRefreshToken(
-                    token=_token.refresh_token,
+                token_model(
+                    **token,
+                    token_id=uuid.uuid4(),
                     client=request.client,
                     user=request.user,
-                    issued_with=uuid.UUID(_jwt["jti"]),
-                    issued_at=datetime.datetime.fromtimestamp(_jwt["iat"]),
-                    expires=datetime.datetime.fromtimestamp(
-                        old_refresh_token.then(
-                            lambda _tok: _tok.expires.timestamp()
-                        ).maybe((int(_jwt["iat"]) +
-                                 RefreshTokenGrant.DEFAULT_EXPIRES_IN),
-                                lambda _expires: _expires)),
-                    scope=_token.get_scope(),
+                    issued_at=datetime.now(),
                     revoked=False,
-                    parent_of=None)
-            save_refresh_token(conn, new_refresh_token)
-            old_refresh_token.then(lambda _tok: link_child_token(
-                conn, _tok.token, new_refresh_token.token))
-
-    return __save_token__
+                    expires_in=_TWO_HOURS_))
 
+    return {
+        OAuth2Token: __save_token__,
+        JWTBearerToken: __ignore_token__
+    }[token_model]
 
 def make_jwt_token_generator(app):
     """Make token generator function."""
-    _gen = JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"])
-    def __generator__(# pylint: disable=[too-many-arguments]
+    def __generator__(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
             grant_type,
             client,
             user=None,
@@ -98,19 +81,42 @@ def make_jwt_token_generator(app):
             expires_in=None,# pylint: disable=[unused-argument]
             include_refresh_token=True
     ):
-        return _gen.__call__(
-            grant_type,
-            client,
-            user,
-            scope,
-            JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
-            include_refresh_token)
+        return JWTBearerTokenGenerator(
+            secret_key=newest_jwk_with_rotation(
+                jwks_directory(app),
+                int(app.config["JWKS_ROTATION_AGE_DAYS"])),
+            issuer=flask_request.host_url,
+            alg="RS256").__call__(
+                grant_type=grant_type,
+                client=client,
+                user=user,
+                scope=scope,
+                expires_in=expires_in,
+                include_refresh_token=include_refresh_token)
     return __generator__
 
 
+
+class JsonAuthorizationServer(AuthorizationServer):
+    """An authorisation server using JSON rather than FORMDATA."""
+
+    def create_oauth2_request(self, request):
+        """Create an OAuth2 Request from the flask request."""
+        match flask_request.headers.get("Content-Type"):
+            case "application/json":
+                req = OAuth2Request(flask_request.method,
+                                     flask_request.url,
+                                     flask_request.get_json(),
+                                     flask_request.headers)
+            case _:
+                req = FlaskOAuth2Request(flask_request)
+
+        return req
+
+
 def setup_oauth2_server(app: Flask) -> None:
     """Set's up the oauth2 server for the flask application."""
-    server = AuthorizationServer()
+    server = JsonAuthorizationServer()
     server.register_grant(PasswordGrant)
 
     # Figure out a common `code_verifier` for GN2 and GN3 and set
@@ -133,11 +139,9 @@ def setup_oauth2_server(app: Flask) -> None:
     server.init_app(
         app,
         query_client=create_query_client_func(),
-        save_token=create_save_token_func(
-            OAuth2Token, app.config["SSL_PRIVATE_KEY"]))
+        save_token=create_save_token_func(JWTBearerToken))
     app.config["OAUTH2_SERVER"] = server
 
     ## Set up the token validators
-    require_oauth.register_token_validator(BearerTokenValidator())
     require_oauth.register_token_validator(
-        JWTBearerTokenValidator(app.config["SSL_PRIVATE_KEY"].get_public_key()))
+        JWTBearerTokenValidator(KeySet(list_jwks(jwks_directory(app)))))
diff --git a/gn_auth/auth/authentication/oauth2/views.py b/gn_auth/auth/authentication/oauth2/views.py
index 22437a2..0e2c4eb 100644
--- a/gn_auth/auth/authentication/oauth2/views.py
+++ b/gn_auth/auth/authentication/oauth2/views.py
@@ -9,6 +9,7 @@ from flask import (
     flash,
     request,
     url_for,
+    jsonify,
     redirect,
     Response,
     Blueprint,
@@ -17,6 +18,7 @@ from flask import (
 
 from gn_auth.auth.db import sqlite3 as db
 from gn_auth.auth.db.sqlite3 import with_db_connection
+from gn_auth.auth.jwks import jwks_directory, list_jwks
 from gn_auth.auth.errors import NotFoundError, ForbiddenAccess
 from gn_auth.auth.authentication.users import valid_login, user_by_email
 
@@ -45,6 +47,14 @@ def authorise():
             flash("Invalid OAuth2 client.", "alert-danger")
 
         if request.method == "GET":
+            def __forgot_password_table_exists__(conn):
+                with db.cursor(conn) as cursor:
+                    cursor.execute("SELECT name FROM sqlite_master "
+                                   "WHERE type='table' "
+                                   "AND name='forgot_password_tokens'")
+                    return bool(cursor.fetchone())
+                return False
+
             client = server.query_client(request.args.get("client_id"))
             _src = urlparse(request.args["redirect_uri"])
             return render_template(
@@ -53,7 +63,9 @@ def authorise():
                 scope=client.scope,
                 response_type=request.args["response_type"],
                 redirect_uri=request.args["redirect_uri"],
-                source_uri=f"{_src.scheme}://{_src.netloc}/")
+                source_uri=f"{_src.scheme}://{_src.netloc}/",
+                display_forgot_password=with_db_connection(
+                    __forgot_password_table_exists__))
 
         form = request.form
         def __authorise__(conn: db.DbConnection):
@@ -65,14 +77,15 @@ def authorise():
             try:
                 email = validate_email(
                     form.get("user:email"), check_deliverability=False)
-                user = user_by_email(conn, email["email"])
+                user = user_by_email(conn, email["email"])  # type: ignore
                 if valid_login(conn, user, form.get("user:password", "")):
                     if not user.verified:
                         return redirect(
                             url_for("oauth2.users.handle_unverified",
                                     response_type=form["response_type"],
                                     client_id=client_id,
-                                    redirect_uri=form["redirect_uri"]),
+                                    redirect_uri=form["redirect_uri"],
+                                    email=email["email"]),
                             code=307)
                     return server.create_authorization_response(request=request, grant_user=user)
                 flash(email_passwd_msg, "alert-danger")
@@ -116,3 +129,13 @@ def introspect_token() -> Response:
                 IntrospectionEndpoint.ENDPOINT_NAME)
 
     raise ForbiddenAccess("You cannot access this endpoint")
+
+
+@auth.route("/public-jwks", methods=["GET"])
+def public_jwks():
+    """Provide the JWK public keys used by this application."""
+    return jsonify({
+        "documentation": (
+            "The keys are listed in order of creation, from the oldest (first) "
+            "to the newest (last)."),
+        "jwks": tuple(key.as_dict() for key in list_jwks(jwks_directory(app)))})