about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-05-24 14:23:16 -0500
committerFrederick Muriuki Muriithi2024-05-24 14:23:16 -0500
commitb21357e122280ef10bcbe464b27b652c802f4383 (patch)
tree173a3c64bb001a91d34548037c2e17a493bb1d6c
parent75ea3002799a6323c29da1ce36aa119b12469b61 (diff)
downloadgn-auth-b21357e122280ef10bcbe464b27b652c802f4383.tar.gz
Revoke refresh token, and all its children.
-rw-r--r--gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py35
1 files changed, 26 insertions, 9 deletions
diff --git a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
index e178c27..dba1563 100644
--- a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
+++ b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
@@ -55,14 +55,30 @@ class JWTRefreshToken(TokenMixin):# pylint: disable=[too-many-instance-attribute
 
 
 def revoke_refresh_token(conn: db.DbConnection, token: JWTRefreshToken) -> None:
-    """Revoke a refresh token."""
-    # TODO: this token has been used before - revoke tree.
-    # TODO: Fetch all the children tokens
-    #   HINT:
-    #     SELECT t1.token, t1.parent_of FROM jwt_refresh_tokens AS t1
-    #     LEFT JOIN jwt_refresh_tokens AS t2 ON t1.parent_of=t2.token
-    # TODO: Revoke all children tokens including the treeroot token
-    raise NotImplementedError()
+    """Revoke a refresh token and all its children."""
+    tree_query = """
+    -- CTE: See https://codedamn.com/news/sql/recursive-sql-queries-hierarchical-data-management
+    WITH RECURSIVE token_tree (token, parent_of, revoked, level) AS (
+        -- anchor member
+        SELECT token, parent_of, revoked, 1 AS level
+        FROM jwt_refresh_tokens
+        WHERE token=:root
+        -- merge the anchor above to the recursive member below!
+        UNION ALL
+        -- recursive member
+        SELECT jrt.token, jrt.parent_of, jrt.revoked, tt.level + 1
+        FROM jwt_refresh_tokens AS jrt
+        INNER JOIN token_tree AS tt
+        ON tt.parent_of=jrt.token
+    ) SELECT * FROM token_tree;
+    """
+    with db.cursor(conn) as cursor:
+        cursor.execute(tree_query, {"root": token.token})
+        rows = cursor.fetchall()
+        if rows:
+            cursor.executemany(
+                "UPDATE jwt_refresh_tokens SET revoked=1 WHERE token=?",
+                tuple((row["token"],) for row in rows))
 
 
 def save_refresh_token(conn: db.DbConnection, token: JWTRefreshToken) -> None:
@@ -138,7 +154,8 @@ def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str):
             return Right(parent)
 
     def __revoke_and_raise_error__(_error_msg_):
-        revoke_refresh_token(conn, parenttoken)
+        load_refresh_token(conn, parenttoken).then(
+            lambda _tok: revoke_refresh_token(conn, _tok))
         raise InvalidGrantError(_error_msg_)
 
     load_refresh_token(conn, parenttoken).maybe(