diff options
Diffstat (limited to 'gn_auth/auth/authentication/oauth2/models')
-rw-r--r-- | gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py | 35 |
1 files changed, 26 insertions, 9 deletions
diff --git a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py index e178c27..dba1563 100644 --- a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py +++ b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py @@ -55,14 +55,30 @@ class JWTRefreshToken(TokenMixin):# pylint: disable=[too-many-instance-attribute def revoke_refresh_token(conn: db.DbConnection, token: JWTRefreshToken) -> None: - """Revoke a refresh token.""" - # TODO: this token has been used before - revoke tree. - # TODO: Fetch all the children tokens - # HINT: - # SELECT t1.token, t1.parent_of FROM jwt_refresh_tokens AS t1 - # LEFT JOIN jwt_refresh_tokens AS t2 ON t1.parent_of=t2.token - # TODO: Revoke all children tokens including the treeroot token - raise NotImplementedError() + """Revoke a refresh token and all its children.""" + tree_query = """ + -- CTE: See https://codedamn.com/news/sql/recursive-sql-queries-hierarchical-data-management + WITH RECURSIVE token_tree (token, parent_of, revoked, level) AS ( + -- anchor member + SELECT token, parent_of, revoked, 1 AS level + FROM jwt_refresh_tokens + WHERE token=:root + -- merge the anchor above to the recursive member below! + UNION ALL + -- recursive member + SELECT jrt.token, jrt.parent_of, jrt.revoked, tt.level + 1 + FROM jwt_refresh_tokens AS jrt + INNER JOIN token_tree AS tt + ON tt.parent_of=jrt.token + ) SELECT * FROM token_tree; + """ + with db.cursor(conn) as cursor: + cursor.execute(tree_query, {"root": token.token}) + rows = cursor.fetchall() + if rows: + cursor.executemany( + "UPDATE jwt_refresh_tokens SET revoked=1 WHERE token=?", + tuple((row["token"],) for row in rows)) def save_refresh_token(conn: db.DbConnection, token: JWTRefreshToken) -> None: @@ -138,7 +154,8 @@ def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str): return Right(parent) def __revoke_and_raise_error__(_error_msg_): - revoke_refresh_token(conn, parenttoken) + load_refresh_token(conn, parenttoken).then( + lambda _tok: revoke_refresh_token(conn, _tok)) raise InvalidGrantError(_error_msg_) load_refresh_token(conn, parenttoken).maybe( |