diff options
Diffstat (limited to 'gn_auth/auth/authentication')
7 files changed, 182 insertions, 78 deletions
diff --git a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py index 1f53186..c802091 100644 --- a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py +++ b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py @@ -1,15 +1,21 @@ """JWT as Authorisation Grant""" import uuid +import time +from typing import Optional from flask import current_app as app +from authlib.jose import jwt +from authlib.common.encoding import to_native from authlib.common.security import generate_token from authlib.oauth2.rfc7523.jwt_bearer import JWTBearerGrant as _JWTBearerGrant from authlib.oauth2.rfc7523.token import ( JWTBearerTokenGenerator as _JWTBearerTokenGenerator) +from gn_auth.debug import __pk__ from gn_auth.auth.db.sqlite3 import with_db_connection -from gn_auth.auth.authentication.users import user_by_id +from gn_auth.auth.authentication.users import User, user_by_id +from gn_auth.auth.authentication.oauth2.models.oauth2client import OAuth2Client class JWTBearerTokenGenerator(_JWTBearerTokenGenerator): @@ -19,23 +25,66 @@ class JWTBearerTokenGenerator(_JWTBearerTokenGenerator): DEFAULT_EXPIRES_IN = 300 - def get_token_data(#pylint: disable=[too-many-arguments] + def get_token_data(#pylint: disable=[too-many-arguments, too-many-positional-arguments] self, grant_type, client, expires_in=None, user=None, scope=None ): """Post process data to prevent JSON serialization problems.""" - tokendata = super().get_token_data( - grant_type, client, expires_in, user, scope) + issued_at = int(time.time()) + tokendata = { + "scope": self.get_allowed_scope(client, scope), + "grant_type": grant_type, + "iat": issued_at, + "client_id": client.get_client_id() + } + if isinstance(expires_in, int) and expires_in > 0: + tokendata["exp"] = issued_at + expires_in + if self.issuer: + tokendata["iss"] = self.issuer + if user: + tokendata["sub"] = self.get_sub_value(user) + return { **{ key: str(value) if key.endswith("_id") else value for key, value in tokendata.items() }, "sub": str(tokendata["sub"]), - "jti": str(uuid.uuid4()) + "jti": str(uuid.uuid4()), + "oauth2_client_id": str(client.client_id) } + def generate(# pylint: disable=[too-many-arguments, too-many-positional-arguments] + self, + grant_type: str, + client: OAuth2Client, + user: Optional[User] = None, + scope: Optional[str] = None, + expires_in: Optional[int] = None + ) -> dict: + """Generate a bearer token for OAuth 2.0 authorization token endpoint. + + :param client: the client that making the request. + :param grant_type: current requested grant_type. + :param user: current authorized user. + :param expires_in: if provided, use this value as expires_in. + :param scope: current requested scope. + :return: Token dict + """ + + token_data = self.get_token_data(grant_type, client, expires_in, user, scope) + access_token = jwt.encode({"alg": self.alg}, token_data, key=self.secret_key, check=False) + token = { + "token_type": "Bearer", + "access_token": to_native(access_token) + } + if expires_in: + token["expires_in"] = expires_in + if scope: + token["scope"] = scope + return token + - def __call__(# pylint: disable=[too-many-arguments] + def __call__(# pylint: disable=[too-many-arguments, too-many-positional-arguments] self, grant_type, client, user=None, scope=None, expires_in=None, include_refresh_token=True ): @@ -74,7 +123,9 @@ class JWTBearerGrant(_JWTBearerGrant): def resolve_client_key(self, client, headers, payload): """Resolve client key to decode assertion data.""" - return client.jwks().find_by_kid(headers["kid"]) + keyset = client.jwks() + __pk__("THE KEYSET =======>", keyset.keys) + return keyset.find_by_kid(headers["kid"]) def authenticate_user(self, subject): diff --git a/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py b/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py index fd6804d..f897d89 100644 --- a/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py +++ b/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py @@ -34,18 +34,18 @@ class RefreshTokenGrant(grants.RefreshTokenGrant): else Nothing) ).maybe(None, lambda _tok: _tok) - def authenticate_user(self, credential): + def authenticate_user(self, refresh_token): """Check that user is valid for given token.""" with connection(app.config["AUTH_DB"]) as conn: try: - return user_by_id(conn, credential.user.user_id) + return user_by_id(conn, refresh_token.user.user_id) except NotFoundError as _nfe: return None return None - def revoke_old_credential(self, credential): + def revoke_old_credential(self, refresh_token): """Revoke any old refresh token after issuing new refresh token.""" with connection(app.config["AUTH_DB"]) as conn: - if credential.parent_of is not None: - revoke_refresh_token(conn, credential) + if refresh_token.parent_of is not None: + revoke_refresh_token(conn, refresh_token) diff --git a/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py index 2606ac6..71769e1 100644 --- a/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py +++ b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py @@ -1,15 +1,50 @@ """Implement model for JWTBearerToken""" import uuid +import time +from typing import Optional from authlib.oauth2.rfc7523 import JWTBearerToken as _JWTBearerToken from gn_auth.auth.db.sqlite3 import with_db_connection from gn_auth.auth.authentication.users import user_by_id +from gn_auth.auth.authentication.oauth2.models.oauth2client import ( + client as fetch_client) class JWTBearerToken(_JWTBearerToken): """Overrides default JWTBearerToken class.""" def __init__(self, payload, header, options=None, params=None): + """Initialise the bearer token.""" + # TOD0: Maybe remove this init and make this a dataclass like the way + # OAuth2Client is a dataclass super().__init__(payload, header, options, params) self.user = with_db_connection( lambda conn:user_by_id(conn, uuid.UUID(payload["sub"]))) + self.client = with_db_connection( + lambda conn: fetch_client( + conn, uuid.UUID(payload["oauth2_client_id"]) + ) + ).maybe(None, lambda _client: _client) + + + def check_client(self, client): + """Check that the client is right.""" + return self.client.get_client_id() == client.get_client_id() + + + def get_expires_in(self) -> Optional[int]: + """Return the number of seconds the token is valid for since issue. + + If `None`, the token never expires.""" + if "exp" in self: + return self['exp'] - self['iat'] + return None + + + def is_expired(self): + """Check whether the token is expired. + + If there is no 'exp' member, assume this token will never expire.""" + if "exp" in self: + return self["exp"] < time.time() + return False diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py index 8fac648..1639e2e 100644 --- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py +++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py @@ -1,18 +1,19 @@ """OAuth2 Client model.""" import json -import logging import datetime from uuid import UUID -from dataclasses import dataclass from functools import cached_property -from typing import Sequence, Optional +from dataclasses import asdict, dataclass +from typing import Any, Sequence, Optional import requests +from flask import current_app as app from requests.exceptions import JSONDecodeError from authlib.jose import KeySet, JsonWebKey from authlib.oauth2.rfc6749 import ClientMixin from pymonad.maybe import Just, Maybe, Nothing +from gn_auth.debug import __pk__ from gn_auth.auth.db import sqlite3 as db from gn_auth.auth.errors import NotFoundError from gn_auth.auth.authentication.users import (User, @@ -62,23 +63,27 @@ class OAuth2Client(ClientMixin): def jwks(self) -> KeySet: """Return this client's KeySet.""" jwksuri = self.client_metadata.get("public-jwks-uri") + __pk__(f"PUBLIC JWKs link for client {self.client_id}", jwksuri) if not bool(jwksuri): - logging.debug("No Public JWKs URI set for client!") + app.logger.debug("No Public JWKs URI set for client!") return KeySet([]) try: ## IMPORTANT: This can cause a deadlock if the client is working in ## single-threaded mode, i.e. can only serve one request ## at a time. return KeySet([JsonWebKey.import_key(key) - for key in requests.get(jwksuri).json()["jwks"]]) + for key in requests.get( + jwksuri, + timeout=300, + allow_redirects=True).json()["jwks"]]) except requests.ConnectionError as _connerr: - logging.debug( + app.logger.debug( "Could not connect to provided URI: %s", jwksuri, exc_info=True) except JSONDecodeError as _jsonerr: - logging.debug( + app.logger.debug( "Could not convert response to JSON", exc_info=True) except Exception as _exc:# pylint: disable=[broad-except] - logging.debug( + app.logger.debug( "Error retrieving the JWKs for the client.", exc_info=True) return KeySet([]) @@ -289,3 +294,25 @@ def delete_client( cursor.execute("DELETE FROM oauth2_tokens WHERE client_id=?", params) cursor.execute("DELETE FROM oauth2_clients WHERE client_id=?", params) return the_client + + +def update_client_attribute( + client: OAuth2Client,# pylint: disable=[redefined-outer-name] + attribute: str, + value: Any +) -> OAuth2Client: + """Return a new OAuth2Client with the given attribute updated/changed.""" + attrs = { + attr: type(value) + for attr, value in asdict(client).items() + if attr != "client_id" + } + assert ( + attribute in attrs.keys() and isinstance(value, attrs[attribute])), ( + "Invalid attribute/value provided!") + return OAuth2Client( + client_id=client.client_id, + **{ + attr: (value if attr==attribute else getattr(client, attr)) + for attr in attrs + }) diff --git a/gn_auth/auth/authentication/oauth2/resource_server.py b/gn_auth/auth/authentication/oauth2/resource_server.py index 9c885e2..8ecf923 100644 --- a/gn_auth/auth/authentication/oauth2/resource_server.py +++ b/gn_auth/auth/authentication/oauth2/resource_server.py @@ -43,6 +43,11 @@ class JWTBearerTokenValidator(_JWTBearerTokenValidator): self._last_jwks_update = datetime.now(tz=timezone.utc) self._refresh_frequency = timedelta(hours=int( extra_attributes.get("jwt_refresh_frequency_hours", 6))) + self.claims_options = { + 'exp': {'essential': False}, + 'client_id': {'essential': True}, + 'grant_type': {'essential': True}, + } def __refresh_jwks__(self): now = datetime.now(tz=timezone.utc) diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py index a8109b7..8ac5106 100644 --- a/gn_auth/auth/authentication/oauth2/server.py +++ b/gn_auth/auth/authentication/oauth2/server.py @@ -3,12 +3,12 @@ import uuid from typing import Callable from datetime import datetime -from flask import Flask, current_app -from authlib.jose import jwt, KeySet +from flask import Flask, current_app, request as flask_request +from authlib.jose import KeySet +from authlib.oauth2.rfc6749 import OAuth2Request from authlib.oauth2.rfc6749.errors import InvalidClientError from authlib.integrations.flask_oauth2 import AuthorizationServer -from authlib.oauth2.rfc6749 import OAuth2Request -from authlib.integrations.flask_helpers import create_oauth_request +from authlib.integrations.flask_oauth2.requests import FlaskOAuth2Request from gn_auth.auth.db import sqlite3 as db from gn_auth.auth.jwks import ( @@ -16,13 +16,9 @@ from gn_auth.auth.jwks import ( jwks_directory, newest_jwk_with_rotation) +from .models.jwt_bearer_token import JWTBearerToken from .models.oauth2client import client as fetch_client from .models.oauth2token import OAuth2Token, save_token -from .models.jwtrefreshtoken import ( - JWTRefreshToken, - link_child_token, - save_refresh_token, - load_refresh_token) from .grants.password_grant import PasswordGrant from .grants.refresh_token_grant import RefreshTokenGrant @@ -34,6 +30,8 @@ from .endpoints.introspection import IntrospectionEndpoint from .resource_server import require_oauth, JWTBearerTokenValidator +_TWO_HOURS_ = 2 * 60 * 60 + def create_query_client_func() -> Callable: """Create the function that loads the client.""" @@ -50,54 +48,32 @@ def create_query_client_func() -> Callable: return __query_client__ -def create_save_token_func(token_model: type, app: Flask) -> Callable: +def create_save_token_func(token_model: type) -> Callable: """Create the function that saves the token.""" + def __ignore_token__(token, request):# pylint: disable=[unused-argument] + """Ignore the token: i.e. Do not save it.""" + def __save_token__(token, request): - _jwt = jwt.decode( - token["access_token"], - newest_jwk_with_rotation( - jwks_directory(app), - int(app.config["JWKS_ROTATION_AGE_DAYS"]))) - _token = token_model( - token_id=uuid.UUID(_jwt["jti"]), - client=request.client, - user=request.user, - **{ - "refresh_token": None, - "revoked": False, - "issued_at": datetime.now(), - **token - }) with db.connection(current_app.config["AUTH_DB"]) as conn: - save_token(conn, _token) - old_refresh_token = load_refresh_token( + save_token( conn, - request.form.get("refresh_token", "nosuchtoken") - ) - new_refresh_token = JWTRefreshToken( - token=_token.refresh_token, + token_model( + **token, + token_id=uuid.uuid4(), client=request.client, user=request.user, - issued_with=uuid.UUID(_jwt["jti"]), - issued_at=datetime.fromtimestamp(_jwt["iat"]), - expires=datetime.fromtimestamp( - old_refresh_token.then( - lambda _tok: _tok.expires.timestamp() - ).maybe((int(_jwt["iat"]) + - RefreshTokenGrant.DEFAULT_EXPIRES_IN), - lambda _expires: _expires)), - scope=_token.get_scope(), + issued_at=datetime.now(), revoked=False, - parent_of=None) - save_refresh_token(conn, new_refresh_token) - old_refresh_token.then(lambda _tok: link_child_token( - conn, _tok.token, new_refresh_token.token)) + expires_in=_TWO_HOURS_)) - return __save_token__ + return { + OAuth2Token: __save_token__, + JWTBearerToken: __ignore_token__ + }[token_model] def make_jwt_token_generator(app): """Make token generator function.""" - def __generator__(# pylint: disable=[too-many-arguments] + def __generator__(# pylint: disable=[too-many-arguments, too-many-positional-arguments] grant_type, client, user=None, @@ -106,15 +82,17 @@ def make_jwt_token_generator(app): include_refresh_token=True ): return JWTBearerTokenGenerator( - newest_jwk_with_rotation( + secret_key=newest_jwk_with_rotation( jwks_directory(app), - int(app.config["JWKS_ROTATION_AGE_DAYS"]))).__call__( - grant_type, - client, - user, - scope, - JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN, - include_refresh_token) + int(app.config["JWKS_ROTATION_AGE_DAYS"])), + issuer=flask_request.host_url, + alg="RS256").__call__( + grant_type=grant_type, + client=client, + user=user, + scope=scope, + expires_in=expires_in, + include_refresh_token=include_refresh_token) return __generator__ @@ -124,8 +102,16 @@ class JsonAuthorizationServer(AuthorizationServer): def create_oauth2_request(self, request): """Create an OAuth2 Request from the flask request.""" - res = create_oauth_request(request, OAuth2Request, True) - return res + match flask_request.headers.get("Content-Type"): + case "application/json": + req = OAuth2Request(flask_request.method, + flask_request.url, + flask_request.get_json(), + flask_request.headers) + case _: + req = FlaskOAuth2Request(flask_request) + + return req def setup_oauth2_server(app: Flask) -> None: @@ -153,7 +139,7 @@ def setup_oauth2_server(app: Flask) -> None: server.init_app( app, query_client=create_query_client_func(), - save_token=create_save_token_func(OAuth2Token, app)) + save_token=create_save_token_func(JWTBearerToken)) app.config["OAUTH2_SERVER"] = server ## Set up the token validators diff --git a/gn_auth/auth/authentication/oauth2/views.py b/gn_auth/auth/authentication/oauth2/views.py index d0b55b4..0e2c4eb 100644 --- a/gn_auth/auth/authentication/oauth2/views.py +++ b/gn_auth/auth/authentication/oauth2/views.py @@ -77,7 +77,7 @@ def authorise(): try: email = validate_email( form.get("user:email"), check_deliverability=False) - user = user_by_email(conn, email["email"]) + user = user_by_email(conn, email["email"]) # type: ignore if valid_login(conn, user, form.get("user:password", "")): if not user.verified: return redirect( |