diff options
Diffstat (limited to 'gn3/auth/authentication/oauth2/grants')
-rw-r--r-- | gn3/auth/authentication/oauth2/grants/authorisation_code_grant.py | 48 |
1 files changed, 44 insertions, 4 deletions
diff --git a/gn3/auth/authentication/oauth2/grants/authorisation_code_grant.py b/gn3/auth/authentication/oauth2/grants/authorisation_code_grant.py index d398192..f80d02e 100644 --- a/gn3/auth/authentication/oauth2/grants/authorisation_code_grant.py +++ b/gn3/auth/authentication/oauth2/grants/authorisation_code_grant.py @@ -1,24 +1,45 @@ """Classes and function for Authorisation Code flow.""" import uuid +import string +import random from typing import Optional +from datetime import datetime from flask import current_app as app from authlib.oauth2.rfc6749 import grants +from authlib.oauth2.rfc7636 import create_s256_code_challenge from gn3.auth import db +from gn3.auth.db_utils import with_db_connection from gn3.auth.authentication.users import User +from ..models.oauth2client import OAuth2Client +from ..models.authorization_code import ( + AuthorisationCode, authorisation_code, save_authorisation_code) + class AuthorisationCodeGrant(grants.AuthorizationCodeGrant): """Implement the 'Authorisation Code' grant.""" - TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"] + TOKEN_ENDPOINT_AUTH_METHODS: list[str] = [ + "client_secret_basic", "client_secret_post"] + AUTHORIZATION_CODE_LENGTH: int = 48 + TOKEN_ENDPOINT_HTTP_METHODS = ['POST'] + GRANT_TYPE = "authorization_code" + RESPONSE_TYPES = {'code'} def save_authorization_code(self, code, request): """Persist the authorisation code to database.""" - raise Exception("NOT IMPLEMENTED!", self, code, request) + client = request.client + nonce = "".join(random.sample(string.ascii_letters + string.digits, + k=self.AUTHORIZATION_CODE_LENGTH)) + return __save_authorization_code__(AuthorisationCode( + uuid.uuid4(), code, client, request.redirect_uri, request.scope, + nonce, int(datetime.now().timestamp()), + create_s256_code_challenge(app.config["SECRET_KEY"]), + "S256", request.user)) def query_authorization_code(self, code, client): """Retrieve the code from the database.""" - raise Exception("NOT IMPLEMENTED!", self, code, client) + return __query_authorization_code__(code, client) def delete_authorization_code(self, authorization_code):# pylint: disable=[no-self-use] """Delete the authorisation code.""" @@ -36,10 +57,29 @@ class AuthorisationCodeGrant(grants.AuthorizationCodeGrant): "WHERE authorisation_code.code=?") with db.connection(app.config["AUTH_DB"]) as conn: with db.cursor(conn) as cursor: - cursor.execute(query, (str(authorization_code.user_id),)) + cursor.execute(query, (str(authorization_code.code),)) res = cursor.fetchone() if res: return User( uuid.UUID(res["user_id"]), res["email"], res["name"]) return None + +def __query_authorization_code__( + code: str, client: OAuth2Client) -> AuthorisationCode: + """A helper function that creates a new database connection. + + This is found to be necessary since the `AuthorizationCodeGrant` class(es) + do not have a way to pass the database connection.""" + def __auth_code__(conn) -> str: + the_code = authorisation_code(conn, code, client) + return the_code.maybe(None, lambda cde: cde) # type: ignore[misc, arg-type, return-value] + + return with_db_connection(__auth_code__) + +def __save_authorization_code__(code: AuthorisationCode) -> AuthorisationCode: + """A helper function that creates a new database connection. + + This is found to be necessary since the `AuthorizationCodeGrant` class(es) + do not have a way to pass the database connection.""" + return with_db_connection(lambda conn: save_authorisation_code(conn, code)) |