aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/auth/authentication
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth/auth/authentication')
-rw-r--r--gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py35
1 files changed, 26 insertions, 9 deletions
diff --git a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
index e178c27..dba1563 100644
--- a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
+++ b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
@@ -55,14 +55,30 @@ class JWTRefreshToken(TokenMixin):# pylint: disable=[too-many-instance-attribute
def revoke_refresh_token(conn: db.DbConnection, token: JWTRefreshToken) -> None:
- """Revoke a refresh token."""
- # TODO: this token has been used before - revoke tree.
- # TODO: Fetch all the children tokens
- # HINT:
- # SELECT t1.token, t1.parent_of FROM jwt_refresh_tokens AS t1
- # LEFT JOIN jwt_refresh_tokens AS t2 ON t1.parent_of=t2.token
- # TODO: Revoke all children tokens including the treeroot token
- raise NotImplementedError()
+ """Revoke a refresh token and all its children."""
+ tree_query = """
+ -- CTE: See https://codedamn.com/news/sql/recursive-sql-queries-hierarchical-data-management
+ WITH RECURSIVE token_tree (token, parent_of, revoked, level) AS (
+ -- anchor member
+ SELECT token, parent_of, revoked, 1 AS level
+ FROM jwt_refresh_tokens
+ WHERE token=:root
+ -- merge the anchor above to the recursive member below!
+ UNION ALL
+ -- recursive member
+ SELECT jrt.token, jrt.parent_of, jrt.revoked, tt.level + 1
+ FROM jwt_refresh_tokens AS jrt
+ INNER JOIN token_tree AS tt
+ ON tt.parent_of=jrt.token
+ ) SELECT * FROM token_tree;
+ """
+ with db.cursor(conn) as cursor:
+ cursor.execute(tree_query, {"root": token.token})
+ rows = cursor.fetchall()
+ if rows:
+ cursor.executemany(
+ "UPDATE jwt_refresh_tokens SET revoked=1 WHERE token=?",
+ tuple((row["token"],) for row in rows))
def save_refresh_token(conn: db.DbConnection, token: JWTRefreshToken) -> None:
@@ -138,7 +154,8 @@ def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str):
return Right(parent)
def __revoke_and_raise_error__(_error_msg_):
- revoke_refresh_token(conn, parenttoken)
+ load_refresh_token(conn, parenttoken).then(
+ lambda _tok: revoke_refresh_token(conn, _tok))
raise InvalidGrantError(_error_msg_)
load_refresh_token(conn, parenttoken).maybe(