diff options
Diffstat (limited to 'gn_auth/auth/authentication/oauth2/models')
-rw-r--r-- | gn_auth/auth/authentication/oauth2/models/authorization_code.py | 10 | ||||
-rw-r--r-- | gn_auth/auth/authentication/oauth2/models/oauth2token.py | 60 |
2 files changed, 40 insertions, 30 deletions
diff --git a/gn_auth/auth/authentication/oauth2/models/authorization_code.py b/gn_auth/auth/authentication/oauth2/models/authorization_code.py index faffca1..7bce0ca 100644 --- a/gn_auth/auth/authentication/oauth2/models/authorization_code.py +++ b/gn_auth/auth/authentication/oauth2/models/authorization_code.py @@ -3,6 +3,7 @@ from uuid import UUID from datetime import datetime from typing import NamedTuple +from pymonad.tools import monad_from_none_or_value from pymonad.maybe import Just, Maybe, Nothing from gn_auth.auth.db import sqlite3 as db @@ -67,9 +68,11 @@ def authorisation_code(conn: db.DbConnection , "WHERE code=:code AND client_id=:client_id") cursor.execute( query, {"code": code, "client_id": str(client.client_id)}) - result = cursor.fetchone() - if result: - return Just(AuthorisationCode( + + return monad_from_none_or_value( + Nothing, Just, cursor.fetchone() + ).then( + lambda result: AuthorisationCode( code_id=UUID(result["code_id"]), code=result["code"], client=client, @@ -80,7 +83,6 @@ def authorisation_code(conn: db.DbConnection , code_challenge=result["code_challenge"], code_challenge_method=result["code_challenge_method"], user=user_by_id(conn, UUID(result["user_id"])))) - return Nothing def save_authorisation_code(conn: db.DbConnection, auth_code: AuthorisationCode) -> AuthorisationCode: diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2token.py b/gn_auth/auth/authentication/oauth2/models/oauth2token.py index bbcdc4c..f539a07 100644 --- a/gn_auth/auth/authentication/oauth2/models/oauth2token.py +++ b/gn_auth/auth/authentication/oauth2/models/oauth2token.py @@ -3,6 +3,7 @@ import uuid import datetime from typing import NamedTuple, Optional +from pymonad.tools import monad_from_none_or_value from pymonad.maybe import Just, Maybe, Nothing from gn_auth.auth.db import sqlite3 as db @@ -50,40 +51,45 @@ class OAuth2Token(NamedTuple): """Check whether the token has been revoked.""" return self.revoked + def __token_from_resultset__(conn: db.DbConnection, rset) -> Maybe: - def __identity__(value): - return value try: the_user = user_by_id(conn, uuid.UUID(rset["user_id"])) except NotFoundError as _nfe: the_user = None - the_client = client(conn, uuid.UUID(rset["client_id"]), the_user) - - if the_client.is_just() and bool(the_user): - return Just(OAuth2Token(token_id=uuid.UUID(rset["token_id"]), - client=the_client.maybe(None, __identity__), - token_type=rset["token_type"], - access_token=rset["access_token"], - refresh_token=rset["refresh_token"], - scope=rset["scope"], - revoked=(rset["revoked"] == 1), - issued_at=datetime.datetime.fromtimestamp( - rset["issued_at"]), - expires_in=rset["expires_in"], - user=the_user))# type: ignore[arg-type] - - return Nothing + return client( + conn, uuid.UUID(rset["client_id"]), the_user + ).then( + lambda client: OAuth2Token( + token_id=uuid.UUID(rset["token_id"]), + client=client, + token_type=rset["token_type"], + access_token=rset["access_token"], + refresh_token=rset["refresh_token"], + scope=rset["scope"], + revoked=(rset["revoked"] == 1), + issued_at=datetime.datetime.fromtimestamp( + rset["issued_at"]), + expires_in=rset["expires_in"], + user=the_user # type: ignore + ) if bool(the_user) else + Nothing + ) + def token_by_access_token(conn: db.DbConnection, token_str: str) -> Maybe: """Retrieve token by its token string""" with db.cursor(conn) as cursor: cursor.execute("SELECT * FROM oauth2_tokens WHERE access_token=?", (token_str,)) - res = cursor.fetchone() - if res: - return __token_from_resultset__(conn, res) + return monad_from_none_or_value( + Nothing, Just, cursor.fetchone() + ).then( + lambda res: __token_from_resultset__( + conn, res + ) + ) - return Nothing def token_by_refresh_token(conn: db.DbConnection, token_str: str) -> Maybe: """Retrieve token by its token string""" @@ -91,11 +97,12 @@ def token_by_refresh_token(conn: db.DbConnection, token_str: str) -> Maybe: cursor.execute( "SELECT * FROM oauth2_tokens WHERE refresh_token=?", (token_str,)) - res = cursor.fetchone() - if res: - return __token_from_resultset__(conn, res) + return monad_from_none_or_value( + Nothing, Just, cursor.fetchone() + ).then( + lambda res: __token_from_resultset__(conn, res) + ) - return Nothing def revoke_token(token: OAuth2Token) -> OAuth2Token: """ @@ -108,6 +115,7 @@ def revoke_token(token: OAuth2Token) -> OAuth2Token: refresh_token=token.refresh_token, scope=token.scope, revoked=True, issued_at=token.issued_at, expires_in=token.expires_in, user=token.user) + def save_token(conn: db.DbConnection, token: OAuth2Token) -> None: """Save/Update the token.""" with db.cursor(conn) as cursor: |