aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/auth/authentication/oauth2/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth/auth/authentication/oauth2/server.py')
-rw-r--r--gn_auth/auth/authentication/oauth2/server.py64
1 files changed, 41 insertions, 23 deletions
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py
index d845c60..a8109b7 100644
--- a/gn_auth/auth/authentication/oauth2/server.py
+++ b/gn_auth/auth/authentication/oauth2/server.py
@@ -1,15 +1,20 @@
"""Initialise the OAuth2 Server"""
import uuid
-import datetime
from typing import Callable
+from datetime import datetime
from flask import Flask, current_app
-from authlib.jose import jwk, jwt
-from authlib.oauth2.rfc7523 import JWTBearerTokenValidator
+from authlib.jose import jwt, KeySet
from authlib.oauth2.rfc6749.errors import InvalidClientError
from authlib.integrations.flask_oauth2 import AuthorizationServer
+from authlib.oauth2.rfc6749 import OAuth2Request
+from authlib.integrations.flask_helpers import create_oauth_request
from gn_auth.auth.db import sqlite3 as db
+from gn_auth.auth.jwks import (
+ list_jwks,
+ jwks_directory,
+ newest_jwk_with_rotation)
from .models.oauth2client import client as fetch_client
from .models.oauth2token import OAuth2Token, save_token
@@ -27,7 +32,7 @@ from .grants.jwt_bearer_grant import JWTBearerGrant, JWTBearerTokenGenerator
from .endpoints.revocation import RevocationEndpoint
from .endpoints.introspection import IntrospectionEndpoint
-from .resource_server import require_oauth, BearerTokenValidator
+from .resource_server import require_oauth, JWTBearerTokenValidator
def create_query_client_func() -> Callable:
@@ -45,10 +50,14 @@ def create_query_client_func() -> Callable:
return __query_client__
-def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable:
+def create_save_token_func(token_model: type, app: Flask) -> Callable:
"""Create the function that saves the token."""
def __save_token__(token, request):
- _jwt = jwt.decode(token["access_token"], jwtkey)
+ _jwt = jwt.decode(
+ token["access_token"],
+ newest_jwk_with_rotation(
+ jwks_directory(app),
+ int(app.config["JWKS_ROTATION_AGE_DAYS"])))
_token = token_model(
token_id=uuid.UUID(_jwt["jti"]),
client=request.client,
@@ -56,7 +65,7 @@ def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable:
**{
"refresh_token": None,
"revoked": False,
- "issued_at": datetime.datetime.now(),
+ "issued_at": datetime.now(),
**token
})
with db.connection(current_app.config["AUTH_DB"]) as conn:
@@ -70,8 +79,8 @@ def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable:
client=request.client,
user=request.user,
issued_with=uuid.UUID(_jwt["jti"]),
- issued_at=datetime.datetime.fromtimestamp(_jwt["iat"]),
- expires=datetime.datetime.fromtimestamp(
+ issued_at=datetime.fromtimestamp(_jwt["iat"]),
+ expires=datetime.fromtimestamp(
old_refresh_token.then(
lambda _tok: _tok.expires.timestamp()
).maybe((int(_jwt["iat"]) +
@@ -86,10 +95,8 @@ def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable:
return __save_token__
-
def make_jwt_token_generator(app):
"""Make token generator function."""
- _gen = JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"])
def __generator__(# pylint: disable=[too-many-arguments]
grant_type,
client,
@@ -98,19 +105,32 @@ def make_jwt_token_generator(app):
expires_in=None,# pylint: disable=[unused-argument]
include_refresh_token=True
):
- return _gen.__call__(
- grant_type,
- client,
- user,
- scope,
- JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
- include_refresh_token)
+ return JWTBearerTokenGenerator(
+ newest_jwk_with_rotation(
+ jwks_directory(app),
+ int(app.config["JWKS_ROTATION_AGE_DAYS"]))).__call__(
+ grant_type,
+ client,
+ user,
+ scope,
+ JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
+ include_refresh_token)
return __generator__
+
+class JsonAuthorizationServer(AuthorizationServer):
+ """An authorisation server using JSON rather than FORMDATA."""
+
+ def create_oauth2_request(self, request):
+ """Create an OAuth2 Request from the flask request."""
+ res = create_oauth_request(request, OAuth2Request, True)
+ return res
+
+
def setup_oauth2_server(app: Flask) -> None:
"""Set's up the oauth2 server for the flask application."""
- server = AuthorizationServer()
+ server = JsonAuthorizationServer()
server.register_grant(PasswordGrant)
# Figure out a common `code_verifier` for GN2 and GN3 and set
@@ -133,11 +153,9 @@ def setup_oauth2_server(app: Flask) -> None:
server.init_app(
app,
query_client=create_query_client_func(),
- save_token=create_save_token_func(
- OAuth2Token, app.config["SSL_PRIVATE_KEY"]))
+ save_token=create_save_token_func(OAuth2Token, app))
app.config["OAUTH2_SERVER"] = server
## Set up the token validators
- require_oauth.register_token_validator(BearerTokenValidator())
require_oauth.register_token_validator(
- JWTBearerTokenValidator(app.config["SSL_PRIVATE_KEY"].get_public_key()))
+ JWTBearerTokenValidator(KeySet(list_jwks(jwks_directory(app)))))