aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-05-24 09:58:07 -0500
committerFrederick Muriuki Muriithi2024-05-24 09:58:07 -0500
commit707aefbf5e2c82f9d2504f0eed8548d0e177ee96 (patch)
treed9b30d6d136d04cce43c305ef8349cada8e828dc
parent0582565fa7db4b95e86fb0dde8d83e3170e566a7 (diff)
downloadgn-auth-707aefbf5e2c82f9d2504f0eed8548d0e177ee96.tar.gz
Use monads consistently to reduce chances of errors.
-rw-r--r--gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py23
1 files changed, 14 insertions, 9 deletions
diff --git a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
index 40b1554..04908bc 100644
--- a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
+++ b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
@@ -11,6 +11,7 @@ from dataclasses import dataclass
from authlib.oauth2.rfc6749 import TokenMixin, InvalidGrantError
+from pymonad.either import Left, Right
from pymonad.maybe import Just, Maybe, Nothing
from pymonad.tools import monad_from_none_or_value
@@ -117,15 +118,19 @@ def load_refresh_token(conn: db.DbConnection, token: str) -> Maybe:
def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str):
"""Link child token."""
- _parent = load_refresh_token(conn, parenttoken).maybe(
- None, lambda _tok: _tok)
- if _parent is None:
- raise InvalidGrantError("Token not found.")
-
- with db.cursor(conn) as cursor:
- cursor.execute(("UPDATE jwt_refresh_tokens SET parent_of=:childtoken "
- "WHERE token=:parenttoken"),
- {"parenttoken": parenttoken, "childtoken": childtoken})
+ def __link_to_child__(parent):
+ with db.cursor(conn) as cursor:
+ cursor.execute(
+ ("UPDATE jwt_refresh_tokens SET parent_of=:childtoken "
+ "WHERE token=:parenttoken"),
+ {"parenttoken": parent.token, "childtoken": childtoken})
+
+ def __raise_error__(_error_msg_):
+ raise InvalidGrantError(_error_msg_)
+
+ load_refresh_token(conn, parenttoken).maybe(
+ Left("Token not found"), Right).either(
+ __raise_error__, __link_to_child__)
def is_refresh_token_valid(token: JWTRefreshToken, client: OAuth2Client) -> bool: