aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/auth/authentication/oauth2/server.py
blob: a8109b790279f6b78bc4b127a85344a60b680396 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Initialise the OAuth2 Server"""
import uuid
from typing import Callable
from datetime import datetime

from flask import Flask, current_app
from authlib.jose import jwt, KeySet
from authlib.oauth2.rfc6749.errors import InvalidClientError
from authlib.integrations.flask_oauth2 import AuthorizationServer
from authlib.oauth2.rfc6749 import OAuth2Request
from authlib.integrations.flask_helpers import create_oauth_request

from gn_auth.auth.db import sqlite3 as db
from gn_auth.auth.jwks import (
    list_jwks,
    jwks_directory,
    newest_jwk_with_rotation)

from .models.oauth2client import client as fetch_client
from .models.oauth2token import OAuth2Token, save_token
from .models.jwtrefreshtoken import (
    JWTRefreshToken,
    link_child_token,
    save_refresh_token,
    load_refresh_token)

from .grants.password_grant import PasswordGrant
from .grants.refresh_token_grant import RefreshTokenGrant
from .grants.authorisation_code_grant import AuthorisationCodeGrant
from .grants.jwt_bearer_grant import JWTBearerGrant, JWTBearerTokenGenerator

from .endpoints.revocation import RevocationEndpoint
from .endpoints.introspection import IntrospectionEndpoint

from .resource_server import require_oauth, JWTBearerTokenValidator


def create_query_client_func() -> Callable:
    """Build and return the callback used to look up an OAuth2 client.

    The returned function takes a client ID, fetches the matching client from
    the authorisation database and returns it. It raises
    `InvalidClientError` when no (truthy) client is found.
    """
    def __query_client__(client_id: uuid.UUID):
        # The database URI is read from `current_app` at call time, instead of
        # being captured when this closure is created, so that configuration
        # changes (e.g. while testing) are picked up.
        with db.connection(current_app.config["AUTH_DB"]) as conn:
            found = fetch_client(conn, client_id).maybe(
                None, lambda clt: clt) # type: ignore[misc]
            if not found:
                raise InvalidClientError(
                    "No client found for the given CLIENT_ID and CLIENT_SECRET.")
            return found

    return __query_client__

def create_save_token_func(token_model: type, app: Flask) -> Callable:
    """Build and return the callback that persists newly issued tokens.

    `token_model` is the class instantiated to represent the access token.
    `app` provides the JWKS directory and the `JWKS_ROTATION_AGE_DAYS` setting
    used to obtain the key with which the issued access token is decoded.
    """
    def __save_token__(token, request):
        # Decode the access token so that its `jti` and `iat` claims can be
        # reused for the stored records.
        claims = jwt.decode(
            token["access_token"],
            newest_jwk_with_rotation(
                jwks_directory(app),
                int(app.config["JWKS_ROTATION_AGE_DAYS"])))
        token_id = uuid.UUID(claims["jti"])
        client, user = request.client, request.user
        # Defaults come first so that any matching key in `token` wins.
        token_fields = {
            "refresh_token": None,
            "revoked": False,
            "issued_at": datetime.now(),
            **token
        }
        access_token = token_model(
            token_id=token_id, client=client, user=user, **token_fields)

        with db.connection(current_app.config["AUTH_DB"]) as conn:
            save_token(conn, access_token)

            # "nosuchtoken" is the lookup key used when the request carries no
            # refresh token. The result is consumed through `.then`/`.maybe`,
            # so it is presumably a Maybe-like value -- confirm against
            # `load_refresh_token`.
            previous = load_refresh_token(
                conn, request.form.get("refresh_token", "nosuchtoken"))

            # A refresh token descended from an older one keeps the older
            # token's expiry; otherwise the expiry is counted from `iat`.
            fallback_expiry = (
                int(claims["iat"]) + RefreshTokenGrant.DEFAULT_EXPIRES_IN)
            expiry_timestamp = previous.then(
                lambda _tok: _tok.expires.timestamp()
            ).maybe(fallback_expiry, lambda _expires: _expires)

            refresh_token = JWTRefreshToken(
                token=access_token.refresh_token,
                client=client,
                user=user,
                issued_with=token_id,
                issued_at=datetime.fromtimestamp(claims["iat"]),
                expires=datetime.fromtimestamp(expiry_timestamp),
                scope=access_token.get_scope(),
                revoked=False,
                parent_of=None)
            save_refresh_token(conn, refresh_token)

            # Record the new refresh token as the child of the one it replaces.
            previous.then(lambda _tok: link_child_token(
                conn, _tok.token, refresh_token.token))

    return __save_token__

def make_jwt_token_generator(app):
    """Build and return a token generator function bound to `app`.

    The returned function obtains a key via `newest_jwk_with_rotation` on
    every call, using the app's JWKS directory and `JWKS_ROTATION_AGE_DAYS`
    setting, and delegates to a `JWTBearerTokenGenerator` built from it.
    """
    def __generator__(# pylint: disable=[too-many-arguments]
            grant_type,
            client,
            user=None,
            scope=None,
            expires_in=None,# pylint: disable=[unused-argument]
            include_refresh_token=True
    ):
        jwk = newest_jwk_with_rotation(
            jwks_directory(app),
            int(app.config["JWKS_ROTATION_AGE_DAYS"]))
        generate = JWTBearerTokenGenerator(jwk)
        # The caller-supplied `expires_in` is ignored; the generator's own
        # default lifetime is always passed instead.
        return generate(
            grant_type,
            client,
            user,
            scope,
            JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
            include_refresh_token)
    return __generator__



class JsonAuthorizationServer(AuthorizationServer):
    """An authorisation server that reads requests as JSON, not FORMDATA."""

    def create_oauth2_request(self, request):
        """Build an `OAuth2Request` from the given flask request."""
        # The trailing `True` is presumably the "use JSON" flag of
        # `create_oauth_request` -- confirm against the installed authlib.
        return create_oauth_request(request, OAuth2Request, True)


def setup_oauth2_server(app: Flask) -> None:
    """Set up the OAuth2 server for the flask application.

    Registers the grants, token generators and endpoints on a
    `JsonAuthorizationServer`, initialises it against `app`, stores it in
    `app.config["OAUTH2_SERVER"]`, and registers a JWT bearer token validator
    with `require_oauth`.
    """
    oauth2_server = JsonAuthorizationServer()

    # TODO: Figure out a common `code_verifier` for GN2 and GN3, then register
    # the authorisation code grant as
    # `register_grant(AuthorisationCodeGrant, [CodeChallenge(required=False)])`
    # instead of the plain registration done here.
    for grant in (PasswordGrant, AuthorisationCodeGrant, JWTBearerGrant):
        oauth2_server.register_grant(grant)

    # The same JWT generator serves both the jwt-bearer and refresh_token
    # grant types.
    token_generator = make_jwt_token_generator(app)
    for grant_type in (
            "urn:ietf:params:oauth:grant-type:jwt-bearer",
            "refresh_token"):
        oauth2_server.register_token_generator(grant_type, token_generator)
    oauth2_server.register_grant(RefreshTokenGrant)

    # Register endpoints.
    for endpoint in (RevocationEndpoint, IntrospectionEndpoint):
        oauth2_server.register_endpoint(endpoint)

    # Initialise the server.
    oauth2_server.init_app(
        app,
        query_client=create_query_client_func(),
        save_token=create_save_token_func(OAuth2Token, app))
    app.config["OAUTH2_SERVER"] = oauth2_server

    # Set up the token validators, using the keys listed in the JWKS directory
    # at the time this function runs.
    known_keys = KeySet(list_jwks(jwks_directory(app)))
    require_oauth.register_token_validator(JWTBearerTokenValidator(known_keys))