about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-05-09 08:12:16 +0300
committerFrederick Muriuki Muriithi2024-05-13 06:27:46 +0300
commita22dbcba7b28b75c13aa25bdd36583ade5fe3747 (patch)
tree84ba00ba66bceda8136116d54988f46344ea1c8d
parent0aa19052d00b2e7d0d6edd5905314484480e7ea2 (diff)
downloadgn-auth-a22dbcba7b28b75c13aa25bdd36583ade5fe3747.tar.gz
Link old refresh token to newly issued refresh token
We need to track the "lineage" of refresh tokens in order to detect
possible stolen tokens and mitigate damage.
-rw-r--r--gn_auth/auth/authentication/oauth2/server.py53
1 files changed, 46 insertions, 7 deletions
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py
index 8b65aa9..75d6e1b 100644
--- a/gn_auth/auth/authentication/oauth2/server.py
+++ b/gn_auth/auth/authentication/oauth2/server.py
@@ -13,7 +13,11 @@ from gn_auth.auth.db import sqlite3 as db
 
 from .models.oauth2client import client
 from .models.oauth2token import OAuth2Token, save_token
-from .models.jwtrefreshtoken import JWTRefreshToken, save_refresh_token
+from .models.jwtrefreshtoken import (
+    JWTRefreshToken,
+    link_child_token,
+    save_refresh_token,
+    load_refresh_token)
 
 from .grants.password_grant import PasswordGrant
 from .grants.refresh_token_grant import RefreshTokenGrant
@@ -25,6 +29,7 @@ from .endpoints.introspection import IntrospectionEndpoint
 
 from .resource_server import require_oauth, BearerTokenValidator
 
+
 def create_query_client_func() -> Callable:
     """Create the function that loads the client."""
     def __query_client__(client_id: uuid.UUID):
@@ -56,20 +61,53 @@ def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable:
             })
         with db.connection(current_app.config["AUTH_DB"]) as conn:
             save_token(conn, _token)
-            save_refresh_token(
+            old_refresh_token = load_refresh_token(
                 conn,
-                JWTRefreshToken(
+                request.form.get("refresh_token", "nosuchtoken")
+            )
+            new_refresh_token = JWTRefreshToken(
                     token=_token.refresh_token,
                     client=request.client,
                     user=request.user,
                     issued_with=uuid.UUID(_jwt["jti"]),
                     issued_at=datetime.datetime.fromtimestamp(_jwt["iat"]),
-                    expires=datetime.datetime.fromtimestamp(_jwt["iat"]),
+                    expires=datetime.datetime.fromtimestamp(
+                        old_refresh_token.then(
+                            lambda _tok: _tok.expires.timestamp()
+                        ).maybe((int(_jwt["iat"]) +
+                                 RefreshTokenGrant.DEFAULT_EXPIRES_IN),
+                                lambda _expires: _expires)),
+                    scope=_token.get_scope(),
                     revoked=False,
-                    parent_of=None))
+                    parent_of=None)
+            save_refresh_token(conn, new_refresh_token)
+            old_refresh_token.then(lambda _tok: link_child_token(
+                conn, _tok.token, new_refresh_token.token))
 
     return __save_token__
 
+
+def make_jwt_token_generator(app):
+    """Make token generator function."""
+    _gen = JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"])
+    def __generator__(# pylint: disable=[too-many-arguments]
+            grant_type,
+            client,
+            user=None,
+            scope=None,
+            expires_in=None,# pylint: disable=[unused-argument]
+            include_refresh_token=True
+    ):
+        return _gen.__call__(
+            grant_type,
+            client,
+            user,
+            scope,
+            JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
+            include_refresh_token)
+    return __generator__
+
+
 def setup_oauth2_server(app: Flask) -> None:
     """Set's up the oauth2 server for the flask application."""
     server = AuthorizationServer()
@@ -81,9 +119,10 @@ def setup_oauth2_server(app: Flask) -> None:
     server.register_grant(AuthorisationCodeGrant)
 
     server.register_grant(JWTBearerGrant)
+    jwttokengenerator = make_jwt_token_generator(app)
     server.register_token_generator(
-        "urn:ietf:params:oauth:grant-type:jwt-bearer",
-        JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"]))
+        "urn:ietf:params:oauth:grant-type:jwt-bearer", jwttokengenerator)
+    server.register_token_generator("refresh_token", jwttokengenerator)
     server.register_grant(RefreshTokenGrant)
 
     # register endpoints