about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2025-01-10 13:12:15 -0600
committerFrederick Muriuki Muriithi2025-01-10 13:12:15 -0600
commit877a7a3d7862c3b7e3a4aeae5f4ca69b6c880807 (patch)
treedf2b4cfeb89cb959e29a1c601d159ad15b131ca1
parent989711fe1843cb8085883ef7389af1cbe32bb661 (diff)
downloadgn-auth-877a7a3d7862c3b7e3a4aeae5f4ca69b6c880807.tar.gz
Update server to support non-expiring JWTs.
-rw-r--r--gn_auth/auth/authentication/oauth2/server.py82
1 files changed, 30 insertions, 52 deletions
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py
index a8109b7..e6d1e01 100644
--- a/gn_auth/auth/authentication/oauth2/server.py
+++ b/gn_auth/auth/authentication/oauth2/server.py
@@ -3,8 +3,8 @@ import uuid
 from typing import Callable
 from datetime import datetime
 
-from flask import Flask, current_app
-from authlib.jose import jwt, KeySet
+from flask import Flask, current_app, request as flask_request
+from authlib.jose import KeySet
 from authlib.oauth2.rfc6749.errors import InvalidClientError
 from authlib.integrations.flask_oauth2 import AuthorizationServer
 from authlib.oauth2.rfc6749 import OAuth2Request
@@ -16,13 +16,9 @@ from gn_auth.auth.jwks import (
     jwks_directory,
     newest_jwk_with_rotation)
 
+from .models.jwt_bearer_token import JWTBearerToken
 from .models.oauth2client import client as fetch_client
 from .models.oauth2token import OAuth2Token, save_token
-from .models.jwtrefreshtoken import (
-    JWTRefreshToken,
-    link_child_token,
-    save_refresh_token,
-    load_refresh_token)
 
 from .grants.password_grant import PasswordGrant
 from .grants.refresh_token_grant import RefreshTokenGrant
@@ -34,6 +30,8 @@ from .endpoints.introspection import IntrospectionEndpoint
 
 from .resource_server import require_oauth, JWTBearerTokenValidator
 
+_TWO_HOURS_ = 2 * 60 * 60
+
 
 def create_query_client_func() -> Callable:
     """Create the function that loads the client."""
@@ -50,50 +48,28 @@ def create_query_client_func() -> Callable:
 
     return __query_client__
 
-def create_save_token_func(token_model: type, app: Flask) -> Callable:
+def create_save_token_func(token_model: type) -> Callable:
     """Create the function that saves the token."""
+    def __ignore_token__(token, request):# pylint: disable=[unused-argument]
+        """Ignore the token: i.e. Do not save it."""
+
     def __save_token__(token, request):
-        _jwt = jwt.decode(
-            token["access_token"],
-            newest_jwk_with_rotation(
-                jwks_directory(app),
-                int(app.config["JWKS_ROTATION_AGE_DAYS"])))
-        _token = token_model(
-            token_id=uuid.UUID(_jwt["jti"]),
-            client=request.client,
-            user=request.user,
-            **{
-                "refresh_token": None,
-                "revoked": False,
-                "issued_at": datetime.now(),
-                **token
-            })
         with db.connection(current_app.config["AUTH_DB"]) as conn:
-            save_token(conn, _token)
-            old_refresh_token = load_refresh_token(
+            save_token(
                 conn,
-                request.form.get("refresh_token", "nosuchtoken")
-            )
-            new_refresh_token = JWTRefreshToken(
-                    token=_token.refresh_token,
+                token_model(
+                    **token,
+                    token_id=uuid.uuid4(),
                     client=request.client,
                     user=request.user,
-                    issued_with=uuid.UUID(_jwt["jti"]),
-                    issued_at=datetime.fromtimestamp(_jwt["iat"]),
-                    expires=datetime.fromtimestamp(
-                        old_refresh_token.then(
-                            lambda _tok: _tok.expires.timestamp()
-                        ).maybe((int(_jwt["iat"]) +
-                                 RefreshTokenGrant.DEFAULT_EXPIRES_IN),
-                                lambda _expires: _expires)),
-                    scope=_token.get_scope(),
+                    issued_at=datetime.now(),
                     revoked=False,
-                    parent_of=None)
-            save_refresh_token(conn, new_refresh_token)
-            old_refresh_token.then(lambda _tok: link_child_token(
-                conn, _tok.token, new_refresh_token.token))
+                    expires_in=_TWO_HOURS_))
 
-    return __save_token__
+    return {
+        OAuth2Token: __save_token__,
+        JWTBearerToken: __ignore_token__
+    }[token_model]
 
 def make_jwt_token_generator(app):
     """Make token generator function."""
@@ -106,15 +82,17 @@ def make_jwt_token_generator(app):
             include_refresh_token=True
     ):
         return JWTBearerTokenGenerator(
-            newest_jwk_with_rotation(
+            secret_key=newest_jwk_with_rotation(
                 jwks_directory(app),
-                int(app.config["JWKS_ROTATION_AGE_DAYS"]))).__call__(
-                        grant_type,
-                        client,
-                        user,
-                        scope,
-                        JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
-                        include_refresh_token)
+                int(app.config["JWKS_ROTATION_AGE_DAYS"])),
+            issuer=flask_request.host_url,
+            alg="RS256").__call__(
+                grant_type=grant_type,
+                client=client,
+                user=user,
+                scope=scope,
+                expires_in=expires_in,
+                include_refresh_token=include_refresh_token)
     return __generator__
 
 
@@ -153,7 +131,7 @@ def setup_oauth2_server(app: Flask) -> None:
     server.init_app(
         app,
         query_client=create_query_client_func(),
-        save_token=create_save_token_func(OAuth2Token, app))
+        save_token=create_save_token_func(JWTBearerToken))
     app.config["OAUTH2_SERVER"] = server
 
     ## Set up the token validators