From a22dbcba7b28b75c13aa25bdd36583ade5fe3747 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Thu, 9 May 2024 08:12:16 +0300 Subject: Link old refresh token to newly issued refresh token We need to track the "lineage" of refresh tokens in order to detect possible stolen tokens and mitigate damage. --- gn_auth/auth/authentication/oauth2/server.py | 53 ++++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py index 8b65aa9..75d6e1b 100644 --- a/gn_auth/auth/authentication/oauth2/server.py +++ b/gn_auth/auth/authentication/oauth2/server.py @@ -13,7 +13,11 @@ from gn_auth.auth.db import sqlite3 as db from .models.oauth2client import client from .models.oauth2token import OAuth2Token, save_token -from .models.jwtrefreshtoken import JWTRefreshToken, save_refresh_token +from .models.jwtrefreshtoken import ( + JWTRefreshToken, + link_child_token, + save_refresh_token, + load_refresh_token) from .grants.password_grant import PasswordGrant from .grants.refresh_token_grant import RefreshTokenGrant @@ -25,6 +29,7 @@ from .endpoints.introspection import IntrospectionEndpoint from .resource_server import require_oauth, BearerTokenValidator + def create_query_client_func() -> Callable: """Create the function that loads the client.""" def __query_client__(client_id: uuid.UUID): @@ -56,20 +61,53 @@ def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable: }) with db.connection(current_app.config["AUTH_DB"]) as conn: save_token(conn, _token) - save_refresh_token( + old_refresh_token = load_refresh_token( conn, - JWTRefreshToken( + request.form.get("refresh_token", "nosuchtoken") + ) + new_refresh_token = JWTRefreshToken( token=_token.refresh_token, client=request.client, user=request.user, issued_with=uuid.UUID(_jwt["jti"]), issued_at=datetime.datetime.fromtimestamp(_jwt["iat"]), - expires=datetime.datetime.fromtimestamp(_jwt["iat"]), + expires=datetime.datetime.fromtimestamp( + old_refresh_token.then( + lambda _tok: _tok.expires.timestamp() + ).maybe((int(_jwt["iat"]) + + RefreshTokenGrant.DEFAULT_EXPIRES_IN), + lambda _expires: _expires)), + scope=_token.get_scope(), revoked=False, - parent_of=None)) + parent_of=None) + save_refresh_token(conn, new_refresh_token) + old_refresh_token.then(lambda _tok: link_child_token( + conn, _tok.token, new_refresh_token.token)) return __save_token__ + +def make_jwt_token_generator(app): + """Make token generator function.""" + _gen = JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"]) + def __generator__(# pylint: disable=[too-many-arguments] + grant_type, + client, + user=None, + scope=None, + expires_in=None,# pylint: disable=[unused-argument] + include_refresh_token=True + ): + return _gen.__call__( + grant_type, + client, + user, + scope, + JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN, + include_refresh_token) + return __generator__ + + def setup_oauth2_server(app: Flask) -> None: """Set's up the oauth2 server for the flask application.""" server = AuthorizationServer() @@ -81,9 +119,10 @@ def setup_oauth2_server(app: Flask) -> None: server.register_grant(AuthorisationCodeGrant) server.register_grant(JWTBearerGrant) + jwttokengenerator = make_jwt_token_generator(app) server.register_token_generator( - "urn:ietf:params:oauth:grant-type:jwt-bearer", - JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"])) + "urn:ietf:params:oauth:grant-type:jwt-bearer", jwttokengenerator) + server.register_token_generator("refresh_token", jwttokengenerator) server.register_grant(RefreshTokenGrant) # register endpoints -- cgit v1.2.3