1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
|
"""Initialise the OAuth2 Server"""
import os
import uuid
from pathlib import Path
from typing import Callable
from datetime import datetime, timedelta
from pymonad.either import Left
from flask import Flask, current_app
from authlib.jose import jwt, KeySet, JsonWebKey
from authlib.oauth2.rfc6749.errors import InvalidClientError
from authlib.integrations.flask_oauth2 import AuthorizationServer
from authlib.oauth2.rfc6749 import OAuth2Request
from authlib.integrations.flask_helpers import create_oauth_request
from gn_auth.auth.db import sqlite3 as db
from gn_auth.auth.jwks import (
list_jwks, newest_jwk_with_rotation, jwks_directory, generate_and_save_private_key)
from .models.oauth2client import client as fetch_client
from .models.oauth2token import OAuth2Token, save_token
from .models.jwtrefreshtoken import (
JWTRefreshToken,
link_child_token,
save_refresh_token,
load_refresh_token)
from .grants.password_grant import PasswordGrant
from .grants.refresh_token_grant import RefreshTokenGrant
from .grants.authorisation_code_grant import AuthorisationCodeGrant
from .grants.jwt_bearer_grant import JWTBearerGrant, JWTBearerTokenGenerator
from .endpoints.revocation import RevocationEndpoint
from .endpoints.introspection import IntrospectionEndpoint
from .resource_server import require_oauth, JWTBearerTokenValidator
def create_query_client_func() -> Callable:
    """Build the `query_client` callback handed to authlib.

    Returns:
        A function that looks up an OAuth2 client by its id in the auth
        database. It returns the client when found and raises
        `InvalidClientError` otherwise.
    """
    def __query_client__(client_id: uuid.UUID):
        # The database URI is read from `current_app` at call time instead of
        # being captured here, so configuration changes (e.g. while testing)
        # are picked up.
        with db.connection(current_app.config["AUTH_DB"]) as conn:
            found = fetch_client(conn, client_id).maybe(
                None, lambda clt: clt) # type: ignore[misc]
            if not found:
                raise InvalidClientError(
                    "No client found for the given CLIENT_ID and CLIENT_SECRET.")
            return found
    return __query_client__
def create_save_token_func(token_model: type, app: Flask) -> Callable:
    """Create the `save_token` callback handed to authlib.

    Parameters:
        token_model: class used to build the stored access-token record
            (`OAuth2Token` in `setup_oauth2_server`). It is called with
            `token_id`, `client`, `user` and the entries of the issued
            token dict.
        app: the Flask application. Its `JWKS_ROTATION_AGE_DAYS` config
            value and JWKS directory select the key used to decode the
            issued access token.

    Returns:
        A `__save_token__(token, request)` function. It persists the access
        token and, only when a refresh token was issued with it, persists
        that refresh token and links it to the refresh token it replaces.
    """
    def __save_token__(token, request):
        # Decode the issued access token to recover its `jti` (used as the
        # stored token's id) and `iat` (the refresh token's issue time).
        _jwt = jwt.decode(
            token["access_token"],
            newest_jwk_with_rotation(
                jwks_directory(app),
                int(app.config["JWKS_ROTATION_AGE_DAYS"])))
        token_id = uuid.UUID(_jwt["jti"])
        _token = token_model(
            token_id=token_id,
            client=request.client,
            user=request.user,
            **{
                # Defaults; any key present in the issued token overrides them.
                "refresh_token": None,
                "revoked": False,
                "issued_at": datetime.now(),
                **token
            })
        with db.connection(current_app.config["AUTH_DB"]) as conn:
            save_token(conn, _token)
            if not _token.refresh_token:
                # No refresh token was issued with this access token (e.g. the
                # generator ran with `include_refresh_token=False`). There is
                # no refresh-token record to save, and linking a `None` token
                # as a child of the presented one would be meaningless.
                return
            # The refresh token presented with this request, if any.
            # "nosuchtoken" is a placeholder lookup key, presumably never
            # found.
            old_refresh_token = load_refresh_token(
                conn,
                request.form.get("refresh_token", "nosuchtoken")
            )
            issued_at = int(_jwt["iat"])
            new_refresh_token = JWTRefreshToken(
                token=_token.refresh_token,
                client=request.client,
                user=request.user,
                issued_with=token_id,
                issued_at=datetime.fromtimestamp(_jwt["iat"]),
                # A replacement refresh token inherits its predecessor's
                # expiry, so rotation does not extend its lifetime. A new one
                # expires `DEFAULT_EXPIRES_IN` seconds after `iat`.
                expires=datetime.fromtimestamp(
                    old_refresh_token.then(
                        lambda _tok: _tok.expires.timestamp()
                    ).maybe(issued_at + RefreshTokenGrant.DEFAULT_EXPIRES_IN,
                            lambda _expires: _expires)),
                scope=_token.get_scope(),
                revoked=False,
                parent_of=None)
            save_refresh_token(conn, new_refresh_token)
            # Record the new refresh token as the child of the one it replaces.
            old_refresh_token.then(lambda _tok: link_child_token(
                conn, _tok.token, new_refresh_token.token))
    return __save_token__
def make_jwt_token_generator(app):
    """Build the token-generator callback registered with the OAuth2 server.

    On every call, the returned function builds a `JWTBearerTokenGenerator`
    from the newest JWK in the app's JWKS directory, as selected by
    `newest_jwk_with_rotation` with the `JWKS_ROTATION_AGE_DAYS` config
    value, and delegates to it. The caller-supplied `expires_in` is ignored:
    `JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN` is always passed instead.
    """
    def __generator__(# pylint: disable=[too-many-arguments]
            grant_type,
            client,
            user=None,
            scope=None,
            expires_in=None,# pylint: disable=[unused-argument]
            include_refresh_token=True
    ):
        key_directory = jwks_directory(app)
        rotation_age_days = int(app.config["JWKS_ROTATION_AGE_DAYS"])
        generate = JWTBearerTokenGenerator(
            newest_jwk_with_rotation(key_directory, rotation_age_days))
        return generate(
            grant_type,
            client,
            user,
            scope,
            JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
            include_refresh_token)
    return __generator__
class JsonAuthorizationServer(AuthorizationServer):
    """Authorisation server that reads request bodies as JSON, not FORMDATA."""
    def create_oauth2_request(self, request):
        """Wrap the incoming Flask request as an authlib `OAuth2Request`.

        The trailing `True` flag is what makes this server read the body as
        JSON instead of form data.
        """
        return create_oauth_request(request, OAuth2Request, True)
def setup_oauth2_server(app: Flask) -> None:
    """Sets up the OAuth2 authorisation server for the Flask application.

    Registers the supported grants, the JWT token generator and the
    revocation/introspection endpoints on a `JsonAuthorizationServer`, then
    binds it to `app` and stores it in `app.config["OAUTH2_SERVER"]`.
    Finally, registers a JWT bearer-token validator on `require_oauth`.
    """
    server = JsonAuthorizationServer()
    # TODO: Figure out a common `code_verifier` for GN2 and GN3, then register
    # the authorisation-code grant as
    # `server.register_grant(AuthorisationCodeGrant, [CodeChallenge(required=False)])`.
    for grant in (PasswordGrant, AuthorisationCodeGrant, JWTBearerGrant):
        server.register_grant(grant)

    # NOTE(review): only these two grant types get the JWT generator. Other
    # grants presumably fall back to authlib's "default" generator, yet the
    # save-token callback decodes every access token as a JWT. Confirm the
    # default generator is configured to emit JWTs.
    token_generator = make_jwt_token_generator(app)
    for grant_type in ("urn:ietf:params:oauth:grant-type:jwt-bearer",
                       "refresh_token"):
        server.register_token_generator(grant_type, token_generator)
    server.register_grant(RefreshTokenGrant)

    for endpoint in (RevocationEndpoint, IntrospectionEndpoint):
        server.register_endpoint(endpoint)

    server.init_app(
        app,
        query_client=create_query_client_func(),
        save_token=create_save_token_func(OAuth2Token, app))
    app.config["OAUTH2_SERVER"] = server

    # NOTE(review): the validator's KeySet is a snapshot of the keys on disk
    # at setup time. Confirm that tokens signed with a key created later by
    # `newest_jwk_with_rotation` are still accepted.
    require_oauth.register_token_validator(
        JWTBearerTokenValidator(KeySet(list_jwks(jwks_directory(app)))))
|