about summary refs log tree commit diff
path: root/gn_auth/auth/authentication
diff options
context:
space:
mode:
authorMunyoki Kilyungi2024-03-06 15:37:57 +0300
committerMunyoki Kilyungi2024-03-08 15:03:52 +0300
commit82f76b879be1baafeddc80bc361744e59e47e42b (patch)
tree622ce1385f840c3a9a10062a091d3ed08af15917 /gn_auth/auth/authentication
parentcce9a772d6fe55c62368f6895492a8683e7b82ca (diff)
downloadgn-auth-82f76b879be1baafeddc80bc361744e59e47e42b.tar.gz
Replace "if" branching with "monad_from_none_or_value".
* gn_auth/auth/authentication/oauth2/models/authorization_code.py:
Import "monad_from_none_or_value".
(authorisation_code): Replace if branching for Nothing/Just check with "monad_from_none_or_value".
* gn_auth/auth/authentication/oauth2/models/oauth2token.py: Import
"monad_from_none_or_value".
(__token_from_resultset__): Replace if branching for Nothing/Just
check with "monad_from_none_or_value".
(token_by_access_token): Ditto.
(token_by_refresh_token): Ditto.

Signed-off-by: Munyoki Kilyungi <me@bonfacemunyoki.com>
Diffstat (limited to 'gn_auth/auth/authentication')
-rw-r--r--gn_auth/auth/authentication/oauth2/models/authorization_code.py10
-rw-r--r--gn_auth/auth/authentication/oauth2/models/oauth2token.py60
2 files changed, 40 insertions, 30 deletions
diff --git a/gn_auth/auth/authentication/oauth2/models/authorization_code.py b/gn_auth/auth/authentication/oauth2/models/authorization_code.py
index faffca1..7bce0ca 100644
--- a/gn_auth/auth/authentication/oauth2/models/authorization_code.py
+++ b/gn_auth/auth/authentication/oauth2/models/authorization_code.py
@@ -3,6 +3,7 @@ from uuid import UUID
 from datetime import datetime
 from typing import NamedTuple
 
+from pymonad.tools import monad_from_none_or_value
 from pymonad.maybe import Just, Maybe, Nothing
 
 from gn_auth.auth.db import sqlite3 as db
@@ -67,9 +68,11 @@ def authorisation_code(conn: db.DbConnection ,
                  "WHERE code=:code AND client_id=:client_id")
         cursor.execute(
             query, {"code": code, "client_id": str(client.client_id)})
-        result = cursor.fetchone()
-        if result:
-            return Just(AuthorisationCode(
+
+        return monad_from_none_or_value(
+            Nothing, Just, cursor.fetchone()
+        ).then(
+            lambda result: AuthorisationCode(
                 code_id=UUID(result["code_id"]),
                 code=result["code"],
                 client=client,
@@ -80,7 +83,6 @@ def authorisation_code(conn: db.DbConnection ,
                 code_challenge=result["code_challenge"],
                 code_challenge_method=result["code_challenge_method"],
                 user=user_by_id(conn, UUID(result["user_id"]))))
-        return Nothing
 
 def save_authorisation_code(conn: db.DbConnection,
                             auth_code: AuthorisationCode) -> AuthorisationCode:
diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2token.py b/gn_auth/auth/authentication/oauth2/models/oauth2token.py
index bbcdc4c..f539a07 100644
--- a/gn_auth/auth/authentication/oauth2/models/oauth2token.py
+++ b/gn_auth/auth/authentication/oauth2/models/oauth2token.py
@@ -3,6 +3,7 @@ import uuid
 import datetime
 from typing import NamedTuple, Optional
 
+from pymonad.tools import monad_from_none_or_value
 from pymonad.maybe import Just, Maybe, Nothing
 
 from gn_auth.auth.db import sqlite3 as db
@@ -50,40 +51,45 @@ class OAuth2Token(NamedTuple):
         """Check whether the token has been revoked."""
         return self.revoked
 
+
 def __token_from_resultset__(conn: db.DbConnection, rset) -> Maybe:
-    def __identity__(value):
-        return value
     try:
         the_user = user_by_id(conn, uuid.UUID(rset["user_id"]))
     except NotFoundError as _nfe:
         the_user = None
-    the_client = client(conn, uuid.UUID(rset["client_id"]), the_user)
-
-    if the_client.is_just() and bool(the_user):
-        return Just(OAuth2Token(token_id=uuid.UUID(rset["token_id"]),
-                                client=the_client.maybe(None, __identity__),
-                                token_type=rset["token_type"],
-                                access_token=rset["access_token"],
-                                refresh_token=rset["refresh_token"],
-                                scope=rset["scope"],
-                                revoked=(rset["revoked"] == 1),
-                                issued_at=datetime.datetime.fromtimestamp(
-                                    rset["issued_at"]),
-                                expires_in=rset["expires_in"],
-                                user=the_user))# type: ignore[arg-type]
-
-    return Nothing
+    return client(
+        conn, uuid.UUID(rset["client_id"]), the_user
+    ).then(
+        lambda client: OAuth2Token(
+            token_id=uuid.UUID(rset["token_id"]),
+            client=client,
+            token_type=rset["token_type"],
+            access_token=rset["access_token"],
+            refresh_token=rset["refresh_token"],
+            scope=rset["scope"],
+            revoked=(rset["revoked"] == 1),
+            issued_at=datetime.datetime.fromtimestamp(
+                rset["issued_at"]),
+            expires_in=rset["expires_in"],
+            user=the_user  # type: ignore
+        ) if bool(the_user) else
+        Nothing
+    )
+
 
 def token_by_access_token(conn: db.DbConnection, token_str: str) -> Maybe:
     """Retrieve token by its token string"""
     with db.cursor(conn) as cursor:
         cursor.execute("SELECT * FROM oauth2_tokens WHERE access_token=?",
                        (token_str,))
-        res = cursor.fetchone()
-        if res:
-            return __token_from_resultset__(conn, res)
+        return monad_from_none_or_value(
+            Nothing, Just, cursor.fetchone()
+        ).then(
+            lambda res: __token_from_resultset__(
+                conn, res
+            )
+        )
 
-    return Nothing
 
 def token_by_refresh_token(conn: db.DbConnection, token_str: str) -> Maybe:
     """Retrieve token by its token string"""
@@ -91,11 +97,12 @@ def token_by_refresh_token(conn: db.DbConnection, token_str: str) -> Maybe:
         cursor.execute(
             "SELECT * FROM oauth2_tokens WHERE refresh_token=?",
             (token_str,))
-        res = cursor.fetchone()
-        if res:
-            return __token_from_resultset__(conn, res)
+        return monad_from_none_or_value(
+            Nothing, Just, cursor.fetchone()
+        ).then(
+            lambda res: __token_from_resultset__(conn, res)
+        )
 
-    return Nothing
 
 def revoke_token(token: OAuth2Token) -> OAuth2Token:
     """
@@ -108,6 +115,7 @@ def revoke_token(token: OAuth2Token) -> OAuth2Token:
         refresh_token=token.refresh_token, scope=token.scope, revoked=True,
         issued_at=token.issued_at, expires_in=token.expires_in, user=token.user)
 
+
 def save_token(conn: db.DbConnection, token: OAuth2Token) -> None:
     """Save/Update the token."""
     with db.cursor(conn) as cursor: