aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/auth/authentication/oauth2/models
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth/auth/authentication/oauth2/models')
-rw-r--r--gn_auth/auth/authentication/oauth2/models/authorization_code.py10
-rw-r--r--gn_auth/auth/authentication/oauth2/models/oauth2token.py60
2 files changed, 40 insertions, 30 deletions
diff --git a/gn_auth/auth/authentication/oauth2/models/authorization_code.py b/gn_auth/auth/authentication/oauth2/models/authorization_code.py
index faffca1..7bce0ca 100644
--- a/gn_auth/auth/authentication/oauth2/models/authorization_code.py
+++ b/gn_auth/auth/authentication/oauth2/models/authorization_code.py
@@ -3,6 +3,7 @@ from uuid import UUID
from datetime import datetime
from typing import NamedTuple
+from pymonad.tools import monad_from_none_or_value
from pymonad.maybe import Just, Maybe, Nothing
from gn_auth.auth.db import sqlite3 as db
@@ -67,9 +68,11 @@ def authorisation_code(conn: db.DbConnection ,
"WHERE code=:code AND client_id=:client_id")
cursor.execute(
query, {"code": code, "client_id": str(client.client_id)})
- result = cursor.fetchone()
- if result:
- return Just(AuthorisationCode(
+
+ return monad_from_none_or_value(
+ Nothing, Just, cursor.fetchone()
+ ).then(
+ lambda result: AuthorisationCode(
code_id=UUID(result["code_id"]),
code=result["code"],
client=client,
@@ -80,7 +83,6 @@ def authorisation_code(conn: db.DbConnection ,
code_challenge=result["code_challenge"],
code_challenge_method=result["code_challenge_method"],
user=user_by_id(conn, UUID(result["user_id"]))))
- return Nothing
def save_authorisation_code(conn: db.DbConnection,
auth_code: AuthorisationCode) -> AuthorisationCode:
diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2token.py b/gn_auth/auth/authentication/oauth2/models/oauth2token.py
index bbcdc4c..f539a07 100644
--- a/gn_auth/auth/authentication/oauth2/models/oauth2token.py
+++ b/gn_auth/auth/authentication/oauth2/models/oauth2token.py
@@ -3,6 +3,7 @@ import uuid
import datetime
from typing import NamedTuple, Optional
+from pymonad.tools import monad_from_none_or_value
from pymonad.maybe import Just, Maybe, Nothing
from gn_auth.auth.db import sqlite3 as db
@@ -50,40 +51,45 @@ class OAuth2Token(NamedTuple):
"""Check whether the token has been revoked."""
return self.revoked
+
def __token_from_resultset__(conn: db.DbConnection, rset) -> Maybe:
- def __identity__(value):
- return value
try:
the_user = user_by_id(conn, uuid.UUID(rset["user_id"]))
except NotFoundError as _nfe:
the_user = None
- the_client = client(conn, uuid.UUID(rset["client_id"]), the_user)
-
- if the_client.is_just() and bool(the_user):
- return Just(OAuth2Token(token_id=uuid.UUID(rset["token_id"]),
- client=the_client.maybe(None, __identity__),
- token_type=rset["token_type"],
- access_token=rset["access_token"],
- refresh_token=rset["refresh_token"],
- scope=rset["scope"],
- revoked=(rset["revoked"] == 1),
- issued_at=datetime.datetime.fromtimestamp(
- rset["issued_at"]),
- expires_in=rset["expires_in"],
- user=the_user))# type: ignore[arg-type]
-
- return Nothing
+ return client(
+ conn, uuid.UUID(rset["client_id"]), the_user
+ ).then(
+ lambda client: OAuth2Token(
+ token_id=uuid.UUID(rset["token_id"]),
+ client=client,
+ token_type=rset["token_type"],
+ access_token=rset["access_token"],
+ refresh_token=rset["refresh_token"],
+ scope=rset["scope"],
+ revoked=(rset["revoked"] == 1),
+ issued_at=datetime.datetime.fromtimestamp(
+ rset["issued_at"]),
+ expires_in=rset["expires_in"],
+ user=the_user # type: ignore
+ ) if bool(the_user) else
+ Nothing
+ )
+
def token_by_access_token(conn: db.DbConnection, token_str: str) -> Maybe:
"""Retrieve token by its token string"""
with db.cursor(conn) as cursor:
cursor.execute("SELECT * FROM oauth2_tokens WHERE access_token=?",
(token_str,))
- res = cursor.fetchone()
- if res:
- return __token_from_resultset__(conn, res)
+ return monad_from_none_or_value(
+ Nothing, Just, cursor.fetchone()
+ ).then(
+ lambda res: __token_from_resultset__(
+ conn, res
+ )
+ )
- return Nothing
def token_by_refresh_token(conn: db.DbConnection, token_str: str) -> Maybe:
"""Retrieve token by its token string"""
@@ -91,11 +97,12 @@ def token_by_refresh_token(conn: db.DbConnection, token_str: str) -> Maybe:
cursor.execute(
"SELECT * FROM oauth2_tokens WHERE refresh_token=?",
(token_str,))
- res = cursor.fetchone()
- if res:
- return __token_from_resultset__(conn, res)
+ return monad_from_none_or_value(
+ Nothing, Just, cursor.fetchone()
+ ).then(
+ lambda res: __token_from_resultset__(conn, res)
+ )
- return Nothing
def revoke_token(token: OAuth2Token) -> OAuth2Token:
"""
@@ -108,6 +115,7 @@ def revoke_token(token: OAuth2Token) -> OAuth2Token:
refresh_token=token.refresh_token, scope=token.scope, revoked=True,
issued_at=token.issued_at, expires_in=token.expires_in, user=token.user)
+
def save_token(conn: db.DbConnection, token: OAuth2Token) -> None:
"""Save/Update the token."""
with db.cursor(conn) as cursor: