diff options
author | Frederick Muriuki Muriithi | 2024-05-09 08:12:16 +0300 |
---|---|---|
committer | Frederick Muriuki Muriithi | 2024-05-13 06:27:46 +0300 |
commit | a22dbcba7b28b75c13aa25bdd36583ade5fe3747 (patch) | |
tree | 84ba00ba66bceda8136116d54988f46344ea1c8d /gn_auth/auth/authentication | |
parent | 0aa19052d00b2e7d0d6edd5905314484480e7ea2 (diff) | |
download | gn-auth-a22dbcba7b28b75c13aa25bdd36583ade5fe3747.tar.gz |
Link old refresh token to newly issued refresh token
We need to track the "lineage" of refresh tokens in order to detect
possible stolen tokens and mitigate damage.
Diffstat (limited to 'gn_auth/auth/authentication')
-rw-r--r-- | gn_auth/auth/authentication/oauth2/server.py | 53 |
1 files changed, 46 insertions, 7 deletions
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py index 8b65aa9..75d6e1b 100644 --- a/gn_auth/auth/authentication/oauth2/server.py +++ b/gn_auth/auth/authentication/oauth2/server.py @@ -13,7 +13,11 @@ from gn_auth.auth.db import sqlite3 as db from .models.oauth2client import client from .models.oauth2token import OAuth2Token, save_token -from .models.jwtrefreshtoken import JWTRefreshToken, save_refresh_token +from .models.jwtrefreshtoken import ( + JWTRefreshToken, + link_child_token, + save_refresh_token, + load_refresh_token) from .grants.password_grant import PasswordGrant from .grants.refresh_token_grant import RefreshTokenGrant @@ -25,6 +29,7 @@ from .endpoints.introspection import IntrospectionEndpoint from .resource_server import require_oauth, BearerTokenValidator + def create_query_client_func() -> Callable: """Create the function that loads the client.""" def __query_client__(client_id: uuid.UUID): @@ -56,20 +61,53 @@ def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable: }) with db.connection(current_app.config["AUTH_DB"]) as conn: save_token(conn, _token) - save_refresh_token( + old_refresh_token = load_refresh_token( conn, - JWTRefreshToken( + request.form.get("refresh_token", "nosuchtoken") + ) + new_refresh_token = JWTRefreshToken( token=_token.refresh_token, client=request.client, user=request.user, issued_with=uuid.UUID(_jwt["jti"]), issued_at=datetime.datetime.fromtimestamp(_jwt["iat"]), - expires=datetime.datetime.fromtimestamp(_jwt["iat"]), + expires=datetime.datetime.fromtimestamp( + old_refresh_token.then( + lambda _tok: _tok.expires.timestamp() + ).maybe((int(_jwt["iat"]) + + RefreshTokenGrant.DEFAULT_EXPIRES_IN), + lambda _expires: _expires)), + scope=_token.get_scope(), revoked=False, - parent_of=None)) + parent_of=None) + save_refresh_token(conn, new_refresh_token) + old_refresh_token.then(lambda _tok: link_child_token( + conn, _tok.token, new_refresh_token.token)) return __save_token__ + +def make_jwt_token_generator(app): + """Make token generator function.""" + _gen = JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"]) + def __generator__(# pylint: disable=[too-many-arguments] + grant_type, + client, + user=None, + scope=None, + expires_in=None,# pylint: disable=[unused-argument] + include_refresh_token=True + ): + return _gen.__call__( + grant_type, + client, + user, + scope, + JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN, + include_refresh_token) + return __generator__ + + def setup_oauth2_server(app: Flask) -> None: """Set's up the oauth2 server for the flask application.""" server = AuthorizationServer() @@ -81,9 +119,10 @@ def setup_oauth2_server(app: Flask) -> None: server.register_grant(AuthorisationCodeGrant) server.register_grant(JWTBearerGrant) + jwttokengenerator = make_jwt_token_generator(app) server.register_token_generator( - "urn:ietf:params:oauth:grant-type:jwt-bearer", - JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"])) + "urn:ietf:params:oauth:grant-type:jwt-bearer", jwttokengenerator) + server.register_token_generator("refresh_token", jwttokengenerator) server.register_grant(RefreshTokenGrant) # register endpoints |