aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/auth/authentication/oauth2/server.py
blob: 8b65aa923fc87f9c7cf5a66d377c3de21a9021c0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Initialise the OAuth2 Server"""
import uuid
import datetime
from typing import Callable

from flask import Flask, current_app
from authlib.jose import jwk, jwt
from authlib.oauth2.rfc7523 import JWTBearerTokenValidator
from authlib.oauth2.rfc6749.errors import InvalidClientError
from authlib.integrations.flask_oauth2 import AuthorizationServer

from gn_auth.auth.db import sqlite3 as db

from .models.oauth2client import client
from .models.oauth2token import OAuth2Token, save_token
from .models.jwtrefreshtoken import JWTRefreshToken, save_refresh_token

from .grants.password_grant import PasswordGrant
from .grants.refresh_token_grant import RefreshTokenGrant
from .grants.authorisation_code_grant import AuthorisationCodeGrant
from .grants.jwt_bearer_grant import JWTBearerGrant, JWTBearerTokenGenerator

from .endpoints.revocation import RevocationEndpoint
from .endpoints.introspection import IntrospectionEndpoint

from .resource_server import require_oauth, BearerTokenValidator

def create_query_client_func() -> Callable:
    """Build the `query_client` callback used by the authorisation server.

    The returned function looks up an OAuth2 client by its ID in the
    database named by `current_app.config["AUTH_DB"]` and returns it.

    The returned function raises `InvalidClientError` when the lookup
    yields no client.
    """
    def __query_client__(client_id: uuid.UUID):
        # The DB location is read from `current_app` on every call instead of
        # being captured here, so config changes (e.g. while testing) apply.
        with db.connection(current_app.config["AUTH_DB"]) as conn:
            # Unwrap the lookup result, falling back to None when absent.
            found = client(conn, client_id).maybe(
                None, lambda the_client: the_client) # type: ignore[misc]
            if not found:
                # NOTE(review): message mentions CLIENT_SECRET although only
                # the client ID is looked up here.
                raise InvalidClientError(
                    "No client found for the given CLIENT_ID and CLIENT_SECRET.")
            return found

    return __query_client__

def create_save_token_func(
        token_model: type,
        jwtkey: jwk,
        *,
        refresh_token_expires_in: int = 30 * 24 * 60 * 60
) -> Callable:
    """Create the function that saves the token.

    The returned callable takes `(token, request)` and:

    * decodes `token["access_token"]` as a JWT using `jwtkey`, taking the
      token ID from its `jti` claim;
    * builds a `token_model` instance from the `token` dict, defaulting
      `refresh_token` to None, `revoked` to False and `issued_at` to now
      (values present in `token` override these defaults);
    * saves that record, and a `JWTRefreshToken` for the issued refresh token
      (if any), to the database at `current_app.config["AUTH_DB"]`.

    :param token_model: class used to build the stored access-token record.
    :param jwtkey: key used to decode the access-token JWT.
    :param refresh_token_expires_in: lifetime in seconds of a newly saved
        refresh token, counted from the access token's `iat` claim. The
        30-day default is a reviewer-chosen value; TODO confirm the intended
        lifetime.
    """
    def __save_token__(token, request):
        claims = jwt.decode(token["access_token"], jwtkey)
        token_id = uuid.UUID(claims["jti"])
        _token = token_model(
            token_id=token_id,
            client=request.client,
            user=request.user,
            **{
                "refresh_token": None,
                "revoked": False,
                "issued_at": datetime.datetime.now(),
                **token
            })
        with db.connection(current_app.config["AUTH_DB"]) as conn:
            save_token(conn, _token)
            if _token.refresh_token is None:
                # `token` carried no refresh token (the None default above
                # applied), so there is no refresh-token record to save.
                return
            save_refresh_token(
                conn,
                JWTRefreshToken(
                    token=_token.refresh_token,
                    client=request.client,
                    user=request.user,
                    issued_with=token_id,
                    issued_at=datetime.datetime.fromtimestamp(claims["iat"]),
                    # Expire a fixed lifetime after issue; deriving this from
                    # `iat` alone would make the token expire on issue.
                    expires=datetime.datetime.fromtimestamp(
                        claims["iat"] + refresh_token_expires_in),
                    revoked=False,
                    parent_of=None))

    return __save_token__

def setup_oauth2_server(app: Flask) -> None:
    """Set up the OAuth2 authorisation server for the Flask application.

    On a fresh `AuthorizationServer` this function:

    * registers the password, authorisation-code, JWT-bearer and
      refresh-token grants;
    * registers the JWT-bearer token generator;
    * registers the revocation and introspection endpoints.

    It then initialises the server against `app` and stores it in
    `app.config["OAUTH2_SERVER"]`. Finally it registers the token validators
    on the module-level `require_oauth` protector.

    Reads `app.config["SSL_PRIVATE_KEY"]`. The key is used for token
    generation and token saving, and its public half is used by the
    JWT-bearer token validator.
    """
    signing_key = app.config["SSL_PRIVATE_KEY"]
    oauth2_server = AuthorizationServer()

    # Grants.
    oauth2_server.register_grant(PasswordGrant)
    # TODO: once GN2 and GN3 agree on a common `code_verifier`, register this
    # grant as
    #   register_grant(AuthorisationCodeGrant, [CodeChallenge(required=False)])
    oauth2_server.register_grant(AuthorisationCodeGrant)
    oauth2_server.register_grant(JWTBearerGrant)
    oauth2_server.register_token_generator(
        "urn:ietf:params:oauth:grant-type:jwt-bearer",
        JWTBearerTokenGenerator(signing_key))
    oauth2_server.register_grant(RefreshTokenGrant)

    # Endpoints.
    for endpoint_cls in (RevocationEndpoint, IntrospectionEndpoint):
        oauth2_server.register_endpoint(endpoint_cls)

    oauth2_server.init_app(
        app,
        query_client=create_query_client_func(),
        save_token=create_save_token_func(OAuth2Token, signing_key))
    app.config["OAUTH2_SERVER"] = oauth2_server

    # Token validators on the shared resource protector.
    # NOTE(review): both validators go onto the same protector; if they share
    # a TOKEN_TYPE, Authlib keeps only one validator per type. Confirm that
    # both are actually consulted.
    require_oauth.register_token_validator(BearerTokenValidator())
    require_oauth.register_token_validator(
        JWTBearerTokenValidator(signing_key.get_public_key()))