diff options
Diffstat (limited to 'gn_auth/auth/authentication/oauth2')
9 files changed, 190 insertions, 47 deletions
diff --git a/gn_auth/auth/authentication/oauth2/endpoints/introspection.py b/gn_auth/auth/authentication/oauth2/endpoints/introspection.py index 572324e..200b25d 100644 --- a/gn_auth/auth/authentication/oauth2/endpoints/introspection.py +++ b/gn_auth/auth/authentication/oauth2/endpoints/introspection.py @@ -20,6 +20,7 @@ def get_token_user_sub(token: OAuth2Token) -> str:# pylint: disable=[unused-argu class IntrospectionEndpoint(_IntrospectionEndpoint): """Introspect token.""" + CLIENT_AUTH_METHODS = ['client_secret_post'] def query_token(self, token_string: str, token_type_hint: str): """Query the token.""" return _query_token(self, token_string, token_type_hint) diff --git a/gn_auth/auth/authentication/oauth2/endpoints/revocation.py b/gn_auth/auth/authentication/oauth2/endpoints/revocation.py index 240ca30..80922f1 100644 --- a/gn_auth/auth/authentication/oauth2/endpoints/revocation.py +++ b/gn_auth/auth/authentication/oauth2/endpoints/revocation.py @@ -12,6 +12,7 @@ from .utilities import query_token as _query_token class RevocationEndpoint(_RevocationEndpoint): """Revoke the tokens""" ENDPOINT_NAME = "revoke" + CLIENT_AUTH_METHODS = ['client_secret_post'] def query_token(self, token_string: str, token_type_hint: str): """Query the token.""" return _query_token(self, token_string, token_type_hint) diff --git a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py index b0f2cc7..1f53186 100644 --- a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py +++ b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py @@ -74,7 +74,7 @@ class JWTBearerGrant(_JWTBearerGrant): def resolve_client_key(self, client, headers, payload): """Resolve client key to decode assertion data.""" - return app.config["SSL_PUBLIC_KEYS"].get(headers["kid"]) + return client.jwks().find_by_kid(headers["kid"]) def authenticate_user(self, subject): diff --git a/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py new file mode 100644 index 0000000..2606ac6 --- /dev/null +++ b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py @@ -0,0 +1,15 @@ +"""Implement model for JWTBearerToken""" +import uuid + +from authlib.oauth2.rfc7523 import JWTBearerToken as _JWTBearerToken + +from gn_auth.auth.db.sqlite3 import with_db_connection +from gn_auth.auth.authentication.users import user_by_id + +class JWTBearerToken(_JWTBearerToken): + """Overrides default JWTBearerToken class.""" + + def __init__(self, payload, header, options=None, params=None): + super().__init__(payload, header, options, params) + self.user = with_db_connection( + lambda conn:user_by_id(conn, uuid.UUID(payload["sub"]))) diff --git a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py index 31c9147..46515c8 100644 --- a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py +++ b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py @@ -142,7 +142,7 @@ def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str): "WHERE token=:parenttoken"), {"parenttoken": parent.token, "childtoken": childtoken}) - def __check_child__(parent): + def __check_child__(parent):#pylint: disable=[unused-variable] with db.cursor(conn) as cursor: cursor.execute( ("SELECT * FROM jwt_refresh_tokens WHERE token=:parenttoken"), @@ -154,15 +154,17 @@ def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str): "activity detected.") return Right(parent) - def __revoke_and_raise_error__(_error_msg_): + def __revoke_and_raise_error__(_error_msg_):#pylint: disable=[unused-variable] load_refresh_token(conn, parenttoken).then( lambda _tok: revoke_refresh_token(conn, _tok)) raise InvalidGrantError(_error_msg_) + def __handle_not_found__(_error_msg_): + raise InvalidGrantError(_error_msg_) + load_refresh_token(conn, parenttoken).maybe( - Left("Token not found"), Right).then( - __check_child__).either(__revoke_and_raise_error__, - __link_to_child__) + Left("Token not found"), Right).either( + __handle_not_found__, __link_to_child__) def is_refresh_token_valid(token: JWTRefreshToken, client: OAuth2Client) -> bool: diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py index d31faf6..2c36f45 100644 --- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py +++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py @@ -1,13 +1,14 @@ """OAuth2 Client model.""" import json +import logging import datetime -from pathlib import Path - from uuid import UUID -from dataclasses import dataclass from functools import cached_property -from typing import Sequence, Optional +from dataclasses import asdict, dataclass +from typing import Any, Sequence, Optional +import requests +from requests.exceptions import JSONDecodeError from authlib.jose import KeySet, JsonWebKey from authlib.oauth2.rfc6749 import ClientMixin from pymonad.maybe import Just, Maybe, Nothing @@ -57,16 +58,30 @@ class OAuth2Client(ClientMixin): """ return self.client_metadata.get("client_type", "public") - @cached_property + def jwks(self) -> KeySet: """Return this client's KeySet.""" - def __parse_key__(keypath: Path) -> JsonWebKey: - with open(keypath) as _key:# pylint: disable=[unspecified-encoding] - return JsonWebKey.import_key(_key.read()) + jwksuri = self.client_metadata.get("public-jwks-uri") + if not bool(jwksuri): + logging.debug("No Public JWKs URI set for client!") + return KeySet([]) + try: + ## IMPORTANT: This can cause a deadlock if the client is working in + ## single-threaded mode, i.e. can only serve one request + ## at a time. + return KeySet([JsonWebKey.import_key(key) + for key in requests.get(jwksuri).json()["jwks"]]) + except requests.ConnectionError as _connerr: + logging.debug( + "Could not connect to provided URI: %s", jwksuri, exc_info=True) + except JSONDecodeError as _jsonerr: + logging.debug( + "Could not convert response to JSON", exc_info=True) + except Exception as _exc:# pylint: disable=[broad-except] + logging.debug( + "Error retrieving the JWKs for the client.", exc_info=True) + return KeySet([]) - return KeySet([ - __parse_key__(Path(pth)) - for pth in self.client_metadata.get("public_keys", [])]) def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool: """ @@ -77,12 +92,9 @@ class OAuth2Client(ClientMixin): * client_secret_post: Client uses the HTTP POST parameters * client_secret_basic: Client uses HTTP Basic """ - if endpoint == "token": + if endpoint in ("token", "revoke", "introspection"): return (method in self.token_endpoint_auth_method and method == "client_secret_post") - if endpoint in ("introspection", "revoke"): - return (method in self.token_endpoint_auth_method - and method == "client_secret_basic") return False @cached_property @@ -277,3 +289,22 @@ def delete_client( cursor.execute("DELETE FROM oauth2_tokens WHERE client_id=?", params) cursor.execute("DELETE FROM oauth2_clients WHERE client_id=?", params) return the_client + + +def update_client_attribute( + client: OAuth2Client, attribute: str, value: Any) -> OAuth2Client: + """Return a new OAuth2Client with the given attribute updated/changed.""" + attrs = { + attr: type(value) + for attr, value in asdict(client).items() + if attr != "client_id" + } + assert ( + attribute in attrs.keys() and isinstance(value, attrs[attribute])), ( + "Invalid attribute/value provided!") + return OAuth2Client( + client_id=client.client_id, + **{ + attr: (value if attr==attribute else getattr(client, attr)) + for attr in attrs + }) diff --git a/gn_auth/auth/authentication/oauth2/resource_server.py b/gn_auth/auth/authentication/oauth2/resource_server.py index 2405ee2..9c885e2 100644 --- a/gn_auth/auth/authentication/oauth2/resource_server.py +++ b/gn_auth/auth/authentication/oauth2/resource_server.py @@ -1,11 +1,20 @@ """Protect the resources endpoints""" +from datetime import datetime, timezone, timedelta from flask import current_app as app + +from authlib.jose import jwt, KeySet, JoseError from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator +from authlib.oauth2.rfc7523 import ( + JWTBearerTokenValidator as _JWTBearerTokenValidator) from authlib.integrations.flask_oauth2 import ResourceProtector from gn_auth.auth.db import sqlite3 as db -from gn_auth.auth.authentication.oauth2.models.oauth2token import token_by_access_token +from gn_auth.auth.jwks import list_jwks, jwks_directory +from gn_auth.auth.authentication.oauth2.models.jwt_bearer_token import ( + JWTBearerToken) +from gn_auth.auth.authentication.oauth2.models.oauth2token import ( + token_by_access_token) class BearerTokenValidator(_BearerTokenValidator): """Extends `authlib.oauth2.rfc6750.BearerTokenValidator`""" @@ -14,4 +23,47 @@ class BearerTokenValidator(_BearerTokenValidator): return token_by_access_token(conn, token_string).maybe(# type: ignore[misc] None, lambda tok: tok) +class JWTBearerTokenValidator(_JWTBearerTokenValidator): + """Validate a token using all the keys""" + token_cls = JWTBearerToken + _local_attributes = ("jwt_refresh_frequency_hours",) + + def __init__(self, public_key, issuer=None, realm=None, **extra_attributes): + """Initialise the validator class.""" + # https://docs.authlib.org/en/latest/jose/jwt.html#use-dynamic-keys + # We can simply use the KeySet rather than a specific key. + super().__init__(public_key, + issuer, + realm, + **{ + key: value for key,value + in extra_attributes.items() + if key not in self._local_attributes + }) + self._last_jwks_update = datetime.now(tz=timezone.utc) + self._refresh_frequency = timedelta(hours=int( + extra_attributes.get("jwt_refresh_frequency_hours", 6))) + + def __refresh_jwks__(self): + now = datetime.now(tz=timezone.utc) + if (now - self._last_jwks_update) >= self._refresh_frequency: + self.public_key = KeySet(list_jwks(jwks_directory(app))) + + def authenticate_token(self, token_string: str): + self.__refresh_jwks__() + for key in self.public_key.keys: + try: + claims = jwt.decode( + token_string, key, + claims_options=self.claims_options, + claims_cls=self.token_cls, + ) + claims.validate() + return claims + except JoseError as error: + app.logger.debug('Authenticate token failed. %r', error) + + return None + + require_oauth = ResourceProtector() diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py index d845c60..a8109b7 100644 --- a/gn_auth/auth/authentication/oauth2/server.py +++ b/gn_auth/auth/authentication/oauth2/server.py @@ -1,15 +1,20 @@ """Initialise the OAuth2 Server""" import uuid -import datetime from typing import Callable +from datetime import datetime from flask import Flask, current_app -from authlib.jose import jwk, jwt -from authlib.oauth2.rfc7523 import JWTBearerTokenValidator +from authlib.jose import jwt, KeySet from authlib.oauth2.rfc6749.errors import InvalidClientError from authlib.integrations.flask_oauth2 import AuthorizationServer +from authlib.oauth2.rfc6749 import OAuth2Request +from authlib.integrations.flask_helpers import create_oauth_request from gn_auth.auth.db import sqlite3 as db +from gn_auth.auth.jwks import ( + list_jwks, + jwks_directory, + newest_jwk_with_rotation) from .models.oauth2client import client as fetch_client from .models.oauth2token import OAuth2Token, save_token @@ -27,7 +32,7 @@ from .grants.jwt_bearer_grant import JWTBearerGrant, JWTBearerTokenGenerator from .endpoints.revocation import RevocationEndpoint from .endpoints.introspection import IntrospectionEndpoint -from .resource_server import require_oauth, BearerTokenValidator +from .resource_server import require_oauth, JWTBearerTokenValidator def create_query_client_func() -> Callable: @@ -45,10 +50,14 @@ def create_query_client_func() -> Callable: return __query_client__ -def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable: +def create_save_token_func(token_model: type, app: Flask) -> Callable: """Create the function that saves the token.""" def __save_token__(token, request): - _jwt = jwt.decode(token["access_token"], jwtkey) + _jwt = jwt.decode( + token["access_token"], + newest_jwk_with_rotation( + jwks_directory(app), + int(app.config["JWKS_ROTATION_AGE_DAYS"]))) _token = token_model( token_id=uuid.UUID(_jwt["jti"]), client=request.client, @@ -56,7 +65,7 @@ def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable: **{ "refresh_token": None, "revoked": False, - "issued_at": datetime.datetime.now(), + "issued_at": datetime.now(), **token }) with db.connection(current_app.config["AUTH_DB"]) as conn: @@ -70,8 +79,8 @@ def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable: client=request.client, user=request.user, issued_with=uuid.UUID(_jwt["jti"]), - issued_at=datetime.datetime.fromtimestamp(_jwt["iat"]), - expires=datetime.datetime.fromtimestamp( + issued_at=datetime.fromtimestamp(_jwt["iat"]), + expires=datetime.fromtimestamp( old_refresh_token.then( lambda _tok: _tok.expires.timestamp() ).maybe((int(_jwt["iat"]) + @@ -86,10 +95,8 @@ def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable: return __save_token__ - def make_jwt_token_generator(app): """Make token generator function.""" - _gen = JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"]) def __generator__(# pylint: disable=[too-many-arguments] grant_type, client, @@ -98,19 +105,32 @@ def make_jwt_token_generator(app): expires_in=None,# pylint: disable=[unused-argument] include_refresh_token=True ): - return _gen.__call__( - grant_type, - client, - user, - scope, - JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN, - include_refresh_token) + return JWTBearerTokenGenerator( + newest_jwk_with_rotation( + jwks_directory(app), + int(app.config["JWKS_ROTATION_AGE_DAYS"]))).__call__( + grant_type, + client, + user, + scope, + JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN, + include_refresh_token) return __generator__ + +class JsonAuthorizationServer(AuthorizationServer): + """An authorisation server using JSON rather than FORMDATA.""" + + def create_oauth2_request(self, request): + """Create an OAuth2 Request from the flask request.""" + res = create_oauth_request(request, OAuth2Request, True) + return res + + def setup_oauth2_server(app: Flask) -> None: """Set's up the oauth2 server for the flask application.""" - server = AuthorizationServer() + server = JsonAuthorizationServer() server.register_grant(PasswordGrant) # Figure out a common `code_verifier` for GN2 and GN3 and set @@ -133,11 +153,9 @@ def setup_oauth2_server(app: Flask) -> None: server.init_app( app, query_client=create_query_client_func(), - save_token=create_save_token_func( - OAuth2Token, app.config["SSL_PRIVATE_KEY"])) + save_token=create_save_token_func(OAuth2Token, app)) app.config["OAUTH2_SERVER"] = server ## Set up the token validators - require_oauth.register_token_validator(BearerTokenValidator()) require_oauth.register_token_validator( - JWTBearerTokenValidator(app.config["SSL_PRIVATE_KEY"].get_public_key())) + JWTBearerTokenValidator(KeySet(list_jwks(jwks_directory(app))))) diff --git a/gn_auth/auth/authentication/oauth2/views.py b/gn_auth/auth/authentication/oauth2/views.py index 22437a2..d0b55b4 100644 --- a/gn_auth/auth/authentication/oauth2/views.py +++ b/gn_auth/auth/authentication/oauth2/views.py @@ -9,6 +9,7 @@ from flask import ( flash, request, url_for, + jsonify, redirect, Response, Blueprint, @@ -17,6 +18,7 @@ from flask import ( from gn_auth.auth.db import sqlite3 as db from gn_auth.auth.db.sqlite3 import with_db_connection +from gn_auth.auth.jwks import jwks_directory, list_jwks from gn_auth.auth.errors import NotFoundError, ForbiddenAccess from gn_auth.auth.authentication.users import valid_login, user_by_email @@ -45,6 +47,14 @@ def authorise(): flash("Invalid OAuth2 client.", "alert-danger") if request.method == "GET": + def __forgot_password_table_exists__(conn): + with db.cursor(conn) as cursor: + cursor.execute("SELECT name FROM sqlite_master " + "WHERE type='table' " + "AND name='forgot_password_tokens'") + return bool(cursor.fetchone()) + return False + client = server.query_client(request.args.get("client_id")) _src = urlparse(request.args["redirect_uri"]) return render_template( @@ -53,7 +63,9 @@ def authorise(): scope=client.scope, response_type=request.args["response_type"], redirect_uri=request.args["redirect_uri"], - source_uri=f"{_src.scheme}://{_src.netloc}/") + source_uri=f"{_src.scheme}://{_src.netloc}/", + display_forgot_password=with_db_connection( + __forgot_password_table_exists__)) form = request.form def __authorise__(conn: db.DbConnection): @@ -72,7 +84,8 @@ def authorise(): url_for("oauth2.users.handle_unverified", response_type=form["response_type"], client_id=client_id, - redirect_uri=form["redirect_uri"]), + redirect_uri=form["redirect_uri"], + email=email["email"]), code=307) return server.create_authorization_response(request=request, grant_user=user) flash(email_passwd_msg, "alert-danger") @@ -116,3 +129,13 @@ def introspect_token() -> Response: IntrospectionEndpoint.ENDPOINT_NAME) raise ForbiddenAccess("You cannot access this endpoint") + + +@auth.route("/public-jwks", methods=["GET"]) +def public_jwks(): + """Provide the JWK public keys used by this application.""" + return jsonify({ + "documentation": ( + "The keys are listed in order of creation, from the oldest (first) " + "to the newest (last)."), + "jwks": tuple(key.as_dict() for key in list_jwks(jwks_directory(app)))}) |