about summary refs log tree commit diff
path: root/gn_auth/auth/authentication/oauth2/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth/auth/authentication/oauth2/server.py')
-rw-r--r--gn_auth/auth/authentication/oauth2/server.py118
1 files changed, 61 insertions, 57 deletions
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py
index d845c60..8ac5106 100644
--- a/gn_auth/auth/authentication/oauth2/server.py
+++ b/gn_auth/auth/authentication/oauth2/server.py
@@ -1,23 +1,24 @@
 """Initialise the OAuth2 Server"""
 import uuid
-import datetime
 from typing import Callable
+from datetime import datetime
 
-from flask import Flask, current_app
-from authlib.jose import jwk, jwt
-from authlib.oauth2.rfc7523 import JWTBearerTokenValidator
+from flask import Flask, current_app, request as flask_request
+from authlib.jose import KeySet
+from authlib.oauth2.rfc6749 import OAuth2Request
 from authlib.oauth2.rfc6749.errors import InvalidClientError
 from authlib.integrations.flask_oauth2 import AuthorizationServer
+from authlib.integrations.flask_oauth2.requests import FlaskOAuth2Request
 
 from gn_auth.auth.db import sqlite3 as db
+from gn_auth.auth.jwks import (
+    list_jwks,
+    jwks_directory,
+    newest_jwk_with_rotation)
 
+from .models.jwt_bearer_token import JWTBearerToken
 from .models.oauth2client import client as fetch_client
 from .models.oauth2token import OAuth2Token, save_token
-from .models.jwtrefreshtoken import (
-    JWTRefreshToken,
-    link_child_token,
-    save_refresh_token,
-    load_refresh_token)
 
 from .grants.password_grant import PasswordGrant
 from .grants.refresh_token_grant import RefreshTokenGrant
@@ -27,7 +28,9 @@ from .grants.jwt_bearer_grant import JWTBearerGrant, JWTBearerTokenGenerator
 from .endpoints.revocation import RevocationEndpoint
 from .endpoints.introspection import IntrospectionEndpoint
 
-from .resource_server import require_oauth, BearerTokenValidator
+from .resource_server import require_oauth, JWTBearerTokenValidator
+
+_TWO_HOURS_ = 2 * 60 * 60
 
 
 def create_query_client_func() -> Callable:
@@ -45,52 +48,32 @@ def create_query_client_func() -> Callable:
 
     return __query_client__
 
-def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable:
+def create_save_token_func(token_model: type) -> Callable:
     """Create the function that saves the token."""
+    def __ignore_token__(token, request):# pylint: disable=[unused-argument]
+        """Ignore the token: i.e. Do not save it."""
+
     def __save_token__(token, request):
-        _jwt = jwt.decode(token["access_token"], jwtkey)
-        _token = token_model(
-            token_id=uuid.UUID(_jwt["jti"]),
-            client=request.client,
-            user=request.user,
-            **{
-                "refresh_token": None,
-                "revoked": False,
-                "issued_at": datetime.datetime.now(),
-                **token
-            })
         with db.connection(current_app.config["AUTH_DB"]) as conn:
-            save_token(conn, _token)
-            old_refresh_token = load_refresh_token(
+            save_token(
                 conn,
-                request.form.get("refresh_token", "nosuchtoken")
-            )
-            new_refresh_token = JWTRefreshToken(
-                    token=_token.refresh_token,
+                token_model(
+                    **token,
+                    token_id=uuid.uuid4(),
                     client=request.client,
                     user=request.user,
-                    issued_with=uuid.UUID(_jwt["jti"]),
-                    issued_at=datetime.datetime.fromtimestamp(_jwt["iat"]),
-                    expires=datetime.datetime.fromtimestamp(
-                        old_refresh_token.then(
-                            lambda _tok: _tok.expires.timestamp()
-                        ).maybe((int(_jwt["iat"]) +
-                                 RefreshTokenGrant.DEFAULT_EXPIRES_IN),
-                                lambda _expires: _expires)),
-                    scope=_token.get_scope(),
+                    issued_at=datetime.now(),
                     revoked=False,
-                    parent_of=None)
-            save_refresh_token(conn, new_refresh_token)
-            old_refresh_token.then(lambda _tok: link_child_token(
-                conn, _tok.token, new_refresh_token.token))
-
-    return __save_token__
+                    expires_in=_TWO_HOURS_))
 
+    return {
+        OAuth2Token: __save_token__,
+        JWTBearerToken: __ignore_token__
+    }[token_model]
 
 def make_jwt_token_generator(app):
     """Make token generator function."""
-    _gen = JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"])
-    def __generator__(# pylint: disable=[too-many-arguments]
+    def __generator__(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
             grant_type,
             client,
             user=None,
@@ -98,19 +81,42 @@ def make_jwt_token_generator(app):
             expires_in=None,# pylint: disable=[unused-argument]
             include_refresh_token=True
     ):
-        return _gen.__call__(
-            grant_type,
-            client,
-            user,
-            scope,
-            JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
-            include_refresh_token)
+        return JWTBearerTokenGenerator(
+            secret_key=newest_jwk_with_rotation(
+                jwks_directory(app),
+                int(app.config["JWKS_ROTATION_AGE_DAYS"])),
+            issuer=flask_request.host_url,
+            alg="RS256").__call__(
+                grant_type=grant_type,
+                client=client,
+                user=user,
+                scope=scope,
+                expires_in=expires_in,
+                include_refresh_token=include_refresh_token)
     return __generator__
 
 
+
+class JsonAuthorizationServer(AuthorizationServer):
+    """An authorisation server using JSON rather than FORMDATA."""
+
+    def create_oauth2_request(self, request):
+        """Create an OAuth2 Request from the flask request."""
+        match flask_request.headers.get("Content-Type"):
+            case "application/json":
+                req = OAuth2Request(flask_request.method,
+                                     flask_request.url,
+                                     flask_request.get_json(),
+                                     flask_request.headers)
+            case _:
+                req = FlaskOAuth2Request(flask_request)
+
+        return req
+
+
 def setup_oauth2_server(app: Flask) -> None:
     """Set's up the oauth2 server for the flask application."""
-    server = AuthorizationServer()
+    server = JsonAuthorizationServer()
     server.register_grant(PasswordGrant)
 
     # Figure out a common `code_verifier` for GN2 and GN3 and set
@@ -133,11 +139,9 @@ def setup_oauth2_server(app: Flask) -> None:
     server.init_app(
         app,
         query_client=create_query_client_func(),
-        save_token=create_save_token_func(
-            OAuth2Token, app.config["SSL_PRIVATE_KEY"]))
+        save_token=create_save_token_func(JWTBearerToken))
     app.config["OAUTH2_SERVER"] = server
 
     ## Set up the token validators
-    require_oauth.register_token_validator(BearerTokenValidator())
     require_oauth.register_token_validator(
-        JWTBearerTokenValidator(app.config["SSL_PRIVATE_KEY"].get_public_key()))
+        JWTBearerTokenValidator(KeySet(list_jwks(jwks_directory(app)))))