"""Initialise the OAuth2 Server"""
import os
import uuid
from pathlib import Path
from typing import Callable
from datetime import datetime, timedelta

from pymonad.either import Left
from flask import Flask, current_app
from authlib.jose import jwt, KeySet, JsonWebKey
from authlib.oauth2.rfc6749.errors import InvalidClientError
from authlib.integrations.flask_oauth2 import AuthorizationServer

from gn_auth.auth.db import sqlite3 as db
from gn_auth.auth.jwks import (
    list_jwks, newest_jwk, jwks_directory, generate_and_save_private_key)

from .models.oauth2client import client as fetch_client
from .models.oauth2token import OAuth2Token, save_token
from .models.jwtrefreshtoken import (
    JWTRefreshToken,
    link_child_token,
    save_refresh_token,
    load_refresh_token)

from .grants.password_grant import PasswordGrant
from .grants.refresh_token_grant import RefreshTokenGrant
from .grants.authorisation_code_grant import AuthorisationCodeGrant
from .grants.jwt_bearer_grant import JWTBearerGrant, JWTBearerTokenGenerator

from .endpoints.revocation import RevocationEndpoint
from .endpoints.introspection import IntrospectionEndpoint

from .resource_server import require_oauth, JWTBearerTokenValidator


def create_query_client_func() -> Callable:
    """
    Build the client-lookup callback used by the authorisation server.

    The returned function takes a client ID, looks the client up in the
    database configured under `AUTH_DB`, and returns it. It raises
    `InvalidClientError` when the lookup yields nothing (or a falsy value).
    """
    def __query_client__(client_id: uuid.UUID):
        # The database location is read from `current_app` at call time
        # instead of being captured here, so that configuration changes
        # (e.g. during testing) are always honoured.
        with db.connection(current_app.config["AUTH_DB"]) as conn:
            found = fetch_client(conn, client_id).maybe(
                None, lambda clt: clt) # type: ignore[misc]
            if not bool(found):
                raise InvalidClientError(
                    "No client found for the given CLIENT_ID and CLIENT_SECRET.")
            return found

    return __query_client__

def create_save_token_func(token_model: type, app: Flask) -> Callable:
    """
    Build the callback that persists newly issued tokens.

    `token_model` is the class instantiated for the access token, and `app`
    supplies the JWKs directory and the `JWKS_ROTATION_AGE_DAYS` setting
    used to pick the key that decodes the access token.

    The returned function saves the access token, then saves a refresh
    token tied to it. When the request form carries a refresh token that
    `load_refresh_token` can find, the new refresh token keeps that token's
    expiry and is linked to it as its child; otherwise the expiry is the
    JWT's `iat` plus `RefreshTokenGrant.DEFAULT_EXPIRES_IN`.
    """
    def __save_token__(token, request):
        encoded_access_token = token["access_token"]
        decoding_key = newest_jwk_with_rotation(
            jwks_directory(app),
            int(app.config["JWKS_ROTATION_AGE_DAYS"]))
        claims = jwt.decode(encoded_access_token, decoding_key)

        # The JWT's `jti` claim doubles as the access token's identifier.
        token_id = uuid.UUID(claims["jti"])
        # Defaults come first so that any value present in `token` wins.
        token_fields = {
            "refresh_token": None,
            "revoked": False,
            "issued_at": datetime.now(),
            **token
        }
        access_token = token_model(
            token_id=token_id,
            client=request.client,
            user=request.user,
            **token_fields)

        with db.connection(current_app.config["AUTH_DB"]) as conn:
            save_token(conn, access_token)

            # "nosuchtoken" is a placeholder used when the request carries
            # no refresh token; presumably it never matches a stored token.
            previous_refresh_token = load_refresh_token(
                conn,
                request.form.get("refresh_token", "nosuchtoken"))

            fallback_expiry = (int(claims["iat"]) +
                               RefreshTokenGrant.DEFAULT_EXPIRES_IN)
            expiry_timestamp = previous_refresh_token.then(
                lambda _tok: _tok.expires.timestamp()
            ).maybe(fallback_expiry, lambda _expires: _expires)

            replacement_refresh_token = JWTRefreshToken(
                token=access_token.refresh_token,
                client=request.client,
                user=request.user,
                issued_with=token_id,
                issued_at=datetime.fromtimestamp(claims["iat"]),
                expires=datetime.fromtimestamp(expiry_timestamp),
                scope=access_token.get_scope(),
                revoked=False,
                parent_of=None)
            save_refresh_token(conn, replacement_refresh_token)

            # Record the lineage only when there was a previous token.
            previous_refresh_token.then(lambda _tok: link_child_token(
                conn, _tok.token, replacement_refresh_token.token))

    return __save_token__

def newest_jwk_with_rotation(jwksdir: Path, keyage: int) -> JsonWebKey:
    """
    Retrieve the latest JWK, creating a new one if older than `keyage` days.

    The age of a key is judged by the `st_ctime` of its
    `<kid>.private.pem` file inside `jwksdir`. A fresh key is generated and
    saved in `jwksdir` both when the newest key is too old and when
    `newest_jwk` itself does not yield a key.
    """
    def __reject_if_stale__(candidate):
        keyfile = Path(jwksdir, f"{candidate.as_dict()['kid']}.private.pem")
        cutoff = (datetime.now() - timedelta(days=keyage)).timestamp()
        if keyfile.stat().st_ctime >= cutoff:
            return candidate
        return Left("JWK is too old!")

    checked_key = newest_jwk(jwksdir).then(__reject_if_stale__)
    return checked_key.either(
        lambda _errmsg: generate_and_save_private_key(jwksdir),
        lambda key: key)


def make_jwt_token_generator(app):
    """
    Make the token generator function for JWT bearer tokens.

    `app` supplies the JWKs directory and the `JWKS_ROTATION_AGE_DAYS`
    setting. The signing key is looked up (and rotated if too old) on every
    call of the returned function rather than once at creation time.
    """
    def __generator__(# pylint: disable=[too-many-arguments]
            grant_type,
            client,
            user=None,
            scope=None,
            expires_in=None,# pylint: disable=[unused-argument]
            include_refresh_token=True
    ):
        signing_key = newest_jwk_with_rotation(
            jwks_directory(app),
            int(app.config["JWKS_ROTATION_AGE_DAYS"]))
        generate_token = JWTBearerTokenGenerator(signing_key)
        # `expires_in` is accepted but ignored: the generator's default
        # lifetime is always used.
        return generate_token(
            grant_type,
            client,
            user,
            scope,
            JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
            include_refresh_token)

    return __generator__


def setup_oauth2_server(app: Flask) -> None:
    """
    Set up the OAuth2 server for the flask application.

    Registers the grants, token generators and endpoints on a new
    `AuthorizationServer`, initialises it against `app`, stores it in
    `app.config["OAUTH2_SERVER"]`, and registers the JWT bearer token
    validator on `require_oauth`.
    """
    server = AuthorizationServer()

    # TODO: Figure out a common `code_verifier` for GN2 and GN3, then
    # register the authorisation code grant as
    #   server.register_grant(AuthorisationCodeGrant, [CodeChallenge(required=False)])
    # instead of the plain registration done here.
    for grant in (PasswordGrant, AuthorisationCodeGrant, JWTBearerGrant):
        server.register_grant(grant)

    # The same JWT generator serves both the JWT-bearer and the
    # refresh-token grant types.
    jwttokengenerator = make_jwt_token_generator(app)
    for grant_type in (
            "urn:ietf:params:oauth:grant-type:jwt-bearer", "refresh_token"):
        server.register_token_generator(grant_type, jwttokengenerator)
    server.register_grant(RefreshTokenGrant)

    # register endpoints
    for endpoint in (RevocationEndpoint, IntrospectionEndpoint):
        server.register_endpoint(endpoint)

    # init server
    server.init_app(
        app,
        query_client=create_query_client_func(),
        save_token=create_save_token_func(OAuth2Token, app))
    app.config["OAUTH2_SERVER"] = server

    ## Set up the token validators
    known_keys = KeySet(list_jwks(jwks_directory(app)))
    require_oauth.register_token_validator(
        JWTBearerTokenValidator(known_keys))