1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
|
"""Initialise the OAuth2 Server"""
import uuid
import datetime
from typing import Callable
from flask import Flask, current_app
from authlib.oauth2.rfc6749.errors import InvalidClientError
from authlib.integrations.flask_oauth2 import AuthorizationServer
# from authlib.oauth2.rfc7636 import CodeChallenge
from gn_auth.auth.db import sqlite3 as db
from .models.oauth2client import client
from .models.oauth2token import OAuth2Token, save_token
from .grants.password_grant import PasswordGrant
from .grants.authorisation_code_grant import AuthorisationCodeGrant
from .endpoints.revocation import RevocationEndpoint
from .endpoints.introspection import IntrospectionEndpoint
def create_query_client_func() -> Callable:
    """Build the ``query_client`` callback handed to authlib's server.

    The returned callable opens a connection to the database named by
    ``current_app.config["AUTH_DB"]``, looks up the client with the given
    ``client_id`` and returns it.

    Raises (from the returned callable):
        InvalidClientError: when no client matches ``client_id``.
    """
    def __query_client__(client_id: uuid.UUID):
        # Read the DB URI from ``current_app`` at call time instead of
        # capturing it when the factory runs, so config changes (e.g. while
        # testing) are picked up.
        # NOTE(review): authlib may hand over the raw client_id string from
        # the request rather than a UUID -- confirm `client()` accepts both.
        with db.connection(current_app.config["AUTH_DB"]) as conn:
            found = client(conn, client_id).maybe(
                None, lambda clt: clt)  # type: ignore[misc]
            if not found:
                raise InvalidClientError(
                    "No client found for the given CLIENT_ID and CLIENT_SECRET.")
            return found
    return __query_client__
def create_save_token_func(token_model: type) -> Callable:
    """Build the ``save_token`` callback handed to authlib's server.

    Args:
        token_model: class instantiated with the token's fields before being
            persisted via ``save_token``.

    The returned callable takes authlib's generated ``token`` mapping and the
    current ``request``. It builds a ``token_model`` instance with a fresh
    UUID, the request's client and user, and defaults for
    ``refresh_token``/``revoked``/``issued_at``, which any same-named keys in
    ``token`` override. It then saves that instance using a connection to
    ``current_app.config["AUTH_DB"]``.
    """
    def __save_token__(token, request):
        with db.connection(current_app.config["AUTH_DB"]) as conn:
            new_id = uuid.uuid4()
            owner_client = request.client
            owner_user = request.user
            # Defaults first so that values authlib supplied in ``token`` win.
            # NOTE(review): ``issued_at`` is naive local time -- confirm this
            # matches how expiry is computed elsewhere.
            fields = {
                "refresh_token": None,
                "revoked": False,
                "issued_at": datetime.datetime.now(),
                **token,
            }
            record = token_model(
                token_id=new_id,
                client=owner_client,
                user=owner_user,
                **fields)
            save_token(conn, record)
    return __save_token__
def setup_oauth2_server(app: Flask) -> None:
    """Set up the OAuth2 authorisation server for the Flask application.

    Registers the password and authorisation-code grants and the revocation
    and introspection endpoints. It then binds the server to ``app`` with the
    client-lookup and token-saving callbacks and stores the server in
    ``app.config["OAUTH2_SERVER"]``.
    """
    server = AuthorizationServer()

    # TODO: settle on a common `code_verifier` for GN2 and GN3, then register
    # the authorisation-code grant with PKCE support instead:
    #     server.register_grant(
    #         AuthorisationCodeGrant, [CodeChallenge(required=False)])
    for grant in (PasswordGrant, AuthorisationCodeGrant):
        server.register_grant(grant)

    for endpoint in (RevocationEndpoint, IntrospectionEndpoint):
        server.register_endpoint(endpoint)

    server.init_app(
        app,
        query_client=create_query_client_func(),
        save_token=create_save_token_func(OAuth2Token))
    app.config["OAUTH2_SERVER"] = server
|