aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/auth/authentication
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth/auth/authentication')
-rw-r--r--gn_auth/auth/authentication/oauth2/endpoints/introspection.py1
-rw-r--r--gn_auth/auth/authentication/oauth2/endpoints/revocation.py1
-rw-r--r--gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py65
-rw-r--r--gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py10
-rw-r--r--gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py50
-rw-r--r--gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py12
-rw-r--r--gn_auth/auth/authentication/oauth2/models/oauth2client.py69
-rw-r--r--gn_auth/auth/authentication/oauth2/resource_server.py59
-rw-r--r--gn_auth/auth/authentication/oauth2/server.py118
-rw-r--r--gn_auth/auth/authentication/oauth2/views.py29
10 files changed, 321 insertions, 93 deletions
diff --git a/gn_auth/auth/authentication/oauth2/endpoints/introspection.py b/gn_auth/auth/authentication/oauth2/endpoints/introspection.py
index 572324e..200b25d 100644
--- a/gn_auth/auth/authentication/oauth2/endpoints/introspection.py
+++ b/gn_auth/auth/authentication/oauth2/endpoints/introspection.py
@@ -20,6 +20,7 @@ def get_token_user_sub(token: OAuth2Token) -> str:# pylint: disable=[unused-argu
class IntrospectionEndpoint(_IntrospectionEndpoint):
"""Introspect token."""
+ CLIENT_AUTH_METHODS = ['client_secret_post']
def query_token(self, token_string: str, token_type_hint: str):
"""Query the token."""
return _query_token(self, token_string, token_type_hint)
diff --git a/gn_auth/auth/authentication/oauth2/endpoints/revocation.py b/gn_auth/auth/authentication/oauth2/endpoints/revocation.py
index 240ca30..80922f1 100644
--- a/gn_auth/auth/authentication/oauth2/endpoints/revocation.py
+++ b/gn_auth/auth/authentication/oauth2/endpoints/revocation.py
@@ -12,6 +12,7 @@ from .utilities import query_token as _query_token
class RevocationEndpoint(_RevocationEndpoint):
"""Revoke the tokens"""
ENDPOINT_NAME = "revoke"
+ CLIENT_AUTH_METHODS = ['client_secret_post']
def query_token(self, token_string: str, token_type_hint: str):
"""Query the token."""
return _query_token(self, token_string, token_type_hint)
diff --git a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
index b0f2cc7..c802091 100644
--- a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
+++ b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
@@ -1,15 +1,21 @@
"""JWT as Authorisation Grant"""
import uuid
+import time
+from typing import Optional
from flask import current_app as app
+from authlib.jose import jwt
+from authlib.common.encoding import to_native
from authlib.common.security import generate_token
from authlib.oauth2.rfc7523.jwt_bearer import JWTBearerGrant as _JWTBearerGrant
from authlib.oauth2.rfc7523.token import (
JWTBearerTokenGenerator as _JWTBearerTokenGenerator)
+from gn_auth.debug import __pk__
from gn_auth.auth.db.sqlite3 import with_db_connection
-from gn_auth.auth.authentication.users import user_by_id
+from gn_auth.auth.authentication.users import User, user_by_id
+from gn_auth.auth.authentication.oauth2.models.oauth2client import OAuth2Client
class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
@@ -19,23 +25,66 @@ class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
DEFAULT_EXPIRES_IN = 300
- def get_token_data(#pylint: disable=[too-many-arguments]
+ def get_token_data(#pylint: disable=[too-many-arguments, too-many-positional-arguments]
self, grant_type, client, expires_in=None, user=None, scope=None
):
"""Post process data to prevent JSON serialization problems."""
- tokendata = super().get_token_data(
- grant_type, client, expires_in, user, scope)
+ issued_at = int(time.time())
+ tokendata = {
+ "scope": self.get_allowed_scope(client, scope),
+ "grant_type": grant_type,
+ "iat": issued_at,
+ "client_id": client.get_client_id()
+ }
+ if isinstance(expires_in, int) and expires_in > 0:
+ tokendata["exp"] = issued_at + expires_in
+ if self.issuer:
+ tokendata["iss"] = self.issuer
+ if user:
+ tokendata["sub"] = self.get_sub_value(user)
+
return {
**{
key: str(value) if key.endswith("_id") else value
for key, value in tokendata.items()
},
"sub": str(tokendata["sub"]),
- "jti": str(uuid.uuid4())
+ "jti": str(uuid.uuid4()),
+ "oauth2_client_id": str(client.client_id)
}
+ def generate(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
+ self,
+ grant_type: str,
+ client: OAuth2Client,
+ user: Optional[User] = None,
+ scope: Optional[str] = None,
+ expires_in: Optional[int] = None
+ ) -> dict:
+ """Generate a bearer token for OAuth 2.0 authorization token endpoint.
+
+ :param client: the client that making the request.
+ :param grant_type: current requested grant_type.
+ :param user: current authorized user.
+ :param expires_in: if provided, use this value as expires_in.
+ :param scope: current requested scope.
+ :return: Token dict
+ """
+
+ token_data = self.get_token_data(grant_type, client, expires_in, user, scope)
+ access_token = jwt.encode({"alg": self.alg}, token_data, key=self.secret_key, check=False)
+ token = {
+ "token_type": "Bearer",
+ "access_token": to_native(access_token)
+ }
+ if expires_in:
+ token["expires_in"] = expires_in
+ if scope:
+ token["scope"] = scope
+ return token
+
- def __call__(# pylint: disable=[too-many-arguments]
+ def __call__(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
self, grant_type, client, user=None, scope=None, expires_in=None,
include_refresh_token=True
):
@@ -74,7 +123,9 @@ class JWTBearerGrant(_JWTBearerGrant):
def resolve_client_key(self, client, headers, payload):
"""Resolve client key to decode assertion data."""
- return app.config["SSL_PUBLIC_KEYS"].get(headers["kid"])
+ keyset = client.jwks()
+ __pk__("THE KEYSET =======>", keyset.keys)
+ return keyset.find_by_kid(headers["kid"])
def authenticate_user(self, subject):
diff --git a/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py b/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py
index fd6804d..f897d89 100644
--- a/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py
+++ b/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py
@@ -34,18 +34,18 @@ class RefreshTokenGrant(grants.RefreshTokenGrant):
else Nothing)
).maybe(None, lambda _tok: _tok)
- def authenticate_user(self, credential):
+ def authenticate_user(self, refresh_token):
"""Check that user is valid for given token."""
with connection(app.config["AUTH_DB"]) as conn:
try:
- return user_by_id(conn, credential.user.user_id)
+ return user_by_id(conn, refresh_token.user.user_id)
except NotFoundError as _nfe:
return None
return None
- def revoke_old_credential(self, credential):
+ def revoke_old_credential(self, refresh_token):
"""Revoke any old refresh token after issuing new refresh token."""
with connection(app.config["AUTH_DB"]) as conn:
- if credential.parent_of is not None:
- revoke_refresh_token(conn, credential)
+ if refresh_token.parent_of is not None:
+ revoke_refresh_token(conn, refresh_token)
diff --git a/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py
new file mode 100644
index 0000000..71769e1
--- /dev/null
+++ b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py
@@ -0,0 +1,50 @@
+"""Implement model for JWTBearerToken"""
+import uuid
+import time
+from typing import Optional
+
+from authlib.oauth2.rfc7523 import JWTBearerToken as _JWTBearerToken
+
+from gn_auth.auth.db.sqlite3 import with_db_connection
+from gn_auth.auth.authentication.users import user_by_id
+from gn_auth.auth.authentication.oauth2.models.oauth2client import (
+ client as fetch_client)
+
+class JWTBearerToken(_JWTBearerToken):
+ """Overrides default JWTBearerToken class."""
+
+ def __init__(self, payload, header, options=None, params=None):
+ """Initialise the bearer token."""
+ # TOD0: Maybe remove this init and make this a dataclass like the way
+ # OAuth2Client is a dataclass
+ super().__init__(payload, header, options, params)
+ self.user = with_db_connection(
+ lambda conn:user_by_id(conn, uuid.UUID(payload["sub"])))
+ self.client = with_db_connection(
+ lambda conn: fetch_client(
+ conn, uuid.UUID(payload["oauth2_client_id"])
+ )
+ ).maybe(None, lambda _client: _client)
+
+
+ def check_client(self, client):
+ """Check that the client is right."""
+ return self.client.get_client_id() == client.get_client_id()
+
+
+ def get_expires_in(self) -> Optional[int]:
+ """Return the number of seconds the token is valid for since issue.
+
+ If `None`, the token never expires."""
+ if "exp" in self:
+ return self['exp'] - self['iat']
+ return None
+
+
+ def is_expired(self):
+ """Check whether the token is expired.
+
+ If there is no 'exp' member, assume this token will never expire."""
+ if "exp" in self:
+ return self["exp"] < time.time()
+ return False
diff --git a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
index 31c9147..46515c8 100644
--- a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
+++ b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
@@ -142,7 +142,7 @@ def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str):
"WHERE token=:parenttoken"),
{"parenttoken": parent.token, "childtoken": childtoken})
- def __check_child__(parent):
+ def __check_child__(parent):#pylint: disable=[unused-variable]
with db.cursor(conn) as cursor:
cursor.execute(
("SELECT * FROM jwt_refresh_tokens WHERE token=:parenttoken"),
@@ -154,15 +154,17 @@ def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str):
"activity detected.")
return Right(parent)
- def __revoke_and_raise_error__(_error_msg_):
+ def __revoke_and_raise_error__(_error_msg_):#pylint: disable=[unused-variable]
load_refresh_token(conn, parenttoken).then(
lambda _tok: revoke_refresh_token(conn, _tok))
raise InvalidGrantError(_error_msg_)
+ def __handle_not_found__(_error_msg_):
+ raise InvalidGrantError(_error_msg_)
+
load_refresh_token(conn, parenttoken).maybe(
- Left("Token not found"), Right).then(
- __check_child__).either(__revoke_and_raise_error__,
- __link_to_child__)
+ Left("Token not found"), Right).either(
+ __handle_not_found__, __link_to_child__)
def is_refresh_token_valid(token: JWTRefreshToken, client: OAuth2Client) -> bool:
diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
index d31faf6..1639e2e 100644
--- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py
+++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
@@ -1,17 +1,19 @@
"""OAuth2 Client model."""
import json
import datetime
-from pathlib import Path
-
from uuid import UUID
-from dataclasses import dataclass
from functools import cached_property
-from typing import Sequence, Optional
+from dataclasses import asdict, dataclass
+from typing import Any, Sequence, Optional
+import requests
+from flask import current_app as app
+from requests.exceptions import JSONDecodeError
from authlib.jose import KeySet, JsonWebKey
from authlib.oauth2.rfc6749 import ClientMixin
from pymonad.maybe import Just, Maybe, Nothing
+from gn_auth.debug import __pk__
from gn_auth.auth.db import sqlite3 as db
from gn_auth.auth.errors import NotFoundError
from gn_auth.auth.authentication.users import (User,
@@ -57,16 +59,34 @@ class OAuth2Client(ClientMixin):
"""
return self.client_metadata.get("client_type", "public")
- @cached_property
+
def jwks(self) -> KeySet:
"""Return this client's KeySet."""
- def __parse_key__(keypath: Path) -> JsonWebKey:
- with open(keypath) as _key:# pylint: disable=[unspecified-encoding]
- return JsonWebKey.import_key(_key.read())
+ jwksuri = self.client_metadata.get("public-jwks-uri")
+ __pk__(f"PUBLIC JWKs link for client {self.client_id}", jwksuri)
+ if not bool(jwksuri):
+ app.logger.debug("No Public JWKs URI set for client!")
+ return KeySet([])
+ try:
+ ## IMPORTANT: This can cause a deadlock if the client is working in
+ ## single-threaded mode, i.e. can only serve one request
+ ## at a time.
+ return KeySet([JsonWebKey.import_key(key)
+ for key in requests.get(
+ jwksuri,
+ timeout=300,
+ allow_redirects=True).json()["jwks"]])
+ except requests.ConnectionError as _connerr:
+ app.logger.debug(
+ "Could not connect to provided URI: %s", jwksuri, exc_info=True)
+ except JSONDecodeError as _jsonerr:
+ app.logger.debug(
+ "Could not convert response to JSON", exc_info=True)
+ except Exception as _exc:# pylint: disable=[broad-except]
+ app.logger.debug(
+ "Error retrieving the JWKs for the client.", exc_info=True)
+ return KeySet([])
- return KeySet([
- __parse_key__(Path(pth))
- for pth in self.client_metadata.get("public_keys", [])])
def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool:
"""
@@ -77,12 +97,9 @@ class OAuth2Client(ClientMixin):
* client_secret_post: Client uses the HTTP POST parameters
* client_secret_basic: Client uses HTTP Basic
"""
- if endpoint == "token":
+ if endpoint in ("token", "revoke", "introspection"):
return (method in self.token_endpoint_auth_method
and method == "client_secret_post")
- if endpoint in ("introspection", "revoke"):
- return (method in self.token_endpoint_auth_method
- and method == "client_secret_basic")
return False
@cached_property
@@ -277,3 +294,25 @@ def delete_client(
cursor.execute("DELETE FROM oauth2_tokens WHERE client_id=?", params)
cursor.execute("DELETE FROM oauth2_clients WHERE client_id=?", params)
return the_client
+
+
+def update_client_attribute(
+ client: OAuth2Client,# pylint: disable=[redefined-outer-name]
+ attribute: str,
+ value: Any
+) -> OAuth2Client:
+ """Return a new OAuth2Client with the given attribute updated/changed."""
+ attrs = {
+ attr: type(value)
+ for attr, value in asdict(client).items()
+ if attr != "client_id"
+ }
+ assert (
+ attribute in attrs.keys() and isinstance(value, attrs[attribute])), (
+ "Invalid attribute/value provided!")
+ return OAuth2Client(
+ client_id=client.client_id,
+ **{
+ attr: (value if attr==attribute else getattr(client, attr))
+ for attr in attrs
+ })
diff --git a/gn_auth/auth/authentication/oauth2/resource_server.py b/gn_auth/auth/authentication/oauth2/resource_server.py
index 2405ee2..8ecf923 100644
--- a/gn_auth/auth/authentication/oauth2/resource_server.py
+++ b/gn_auth/auth/authentication/oauth2/resource_server.py
@@ -1,11 +1,20 @@
"""Protect the resources endpoints"""
+from datetime import datetime, timezone, timedelta
from flask import current_app as app
+
+from authlib.jose import jwt, KeySet, JoseError
from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator
+from authlib.oauth2.rfc7523 import (
+ JWTBearerTokenValidator as _JWTBearerTokenValidator)
from authlib.integrations.flask_oauth2 import ResourceProtector
from gn_auth.auth.db import sqlite3 as db
-from gn_auth.auth.authentication.oauth2.models.oauth2token import token_by_access_token
+from gn_auth.auth.jwks import list_jwks, jwks_directory
+from gn_auth.auth.authentication.oauth2.models.jwt_bearer_token import (
+ JWTBearerToken)
+from gn_auth.auth.authentication.oauth2.models.oauth2token import (
+ token_by_access_token)
class BearerTokenValidator(_BearerTokenValidator):
"""Extends `authlib.oauth2.rfc6750.BearerTokenValidator`"""
@@ -14,4 +23,52 @@ class BearerTokenValidator(_BearerTokenValidator):
return token_by_access_token(conn, token_string).maybe(# type: ignore[misc]
None, lambda tok: tok)
+class JWTBearerTokenValidator(_JWTBearerTokenValidator):
+ """Validate a token using all the keys"""
+ token_cls = JWTBearerToken
+ _local_attributes = ("jwt_refresh_frequency_hours",)
+
+ def __init__(self, public_key, issuer=None, realm=None, **extra_attributes):
+ """Initialise the validator class."""
+ # https://docs.authlib.org/en/latest/jose/jwt.html#use-dynamic-keys
+ # We can simply use the KeySet rather than a specific key.
+ super().__init__(public_key,
+ issuer,
+ realm,
+ **{
+ key: value for key,value
+ in extra_attributes.items()
+ if key not in self._local_attributes
+ })
+ self._last_jwks_update = datetime.now(tz=timezone.utc)
+ self._refresh_frequency = timedelta(hours=int(
+ extra_attributes.get("jwt_refresh_frequency_hours", 6)))
+ self.claims_options = {
+ 'exp': {'essential': False},
+ 'client_id': {'essential': True},
+ 'grant_type': {'essential': True},
+ }
+
+ def __refresh_jwks__(self):
+ now = datetime.now(tz=timezone.utc)
+ if (now - self._last_jwks_update) >= self._refresh_frequency:
+ self.public_key = KeySet(list_jwks(jwks_directory(app)))
+
+ def authenticate_token(self, token_string: str):
+ self.__refresh_jwks__()
+ for key in self.public_key.keys:
+ try:
+ claims = jwt.decode(
+ token_string, key,
+ claims_options=self.claims_options,
+ claims_cls=self.token_cls,
+ )
+ claims.validate()
+ return claims
+ except JoseError as error:
+ app.logger.debug('Authenticate token failed. %r', error)
+
+ return None
+
+
require_oauth = ResourceProtector()
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py
index d845c60..8ac5106 100644
--- a/gn_auth/auth/authentication/oauth2/server.py
+++ b/gn_auth/auth/authentication/oauth2/server.py
@@ -1,23 +1,24 @@
"""Initialise the OAuth2 Server"""
import uuid
-import datetime
from typing import Callable
+from datetime import datetime
-from flask import Flask, current_app
-from authlib.jose import jwk, jwt
-from authlib.oauth2.rfc7523 import JWTBearerTokenValidator
+from flask import Flask, current_app, request as flask_request
+from authlib.jose import KeySet
+from authlib.oauth2.rfc6749 import OAuth2Request
from authlib.oauth2.rfc6749.errors import InvalidClientError
from authlib.integrations.flask_oauth2 import AuthorizationServer
+from authlib.integrations.flask_oauth2.requests import FlaskOAuth2Request
from gn_auth.auth.db import sqlite3 as db
+from gn_auth.auth.jwks import (
+ list_jwks,
+ jwks_directory,
+ newest_jwk_with_rotation)
+from .models.jwt_bearer_token import JWTBearerToken
from .models.oauth2client import client as fetch_client
from .models.oauth2token import OAuth2Token, save_token
-from .models.jwtrefreshtoken import (
- JWTRefreshToken,
- link_child_token,
- save_refresh_token,
- load_refresh_token)
from .grants.password_grant import PasswordGrant
from .grants.refresh_token_grant import RefreshTokenGrant
@@ -27,7 +28,9 @@ from .grants.jwt_bearer_grant import JWTBearerGrant, JWTBearerTokenGenerator
from .endpoints.revocation import RevocationEndpoint
from .endpoints.introspection import IntrospectionEndpoint
-from .resource_server import require_oauth, BearerTokenValidator
+from .resource_server import require_oauth, JWTBearerTokenValidator
+
+_TWO_HOURS_ = 2 * 60 * 60
def create_query_client_func() -> Callable:
@@ -45,52 +48,32 @@ def create_query_client_func() -> Callable:
return __query_client__
-def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable:
+def create_save_token_func(token_model: type) -> Callable:
"""Create the function that saves the token."""
+ def __ignore_token__(token, request):# pylint: disable=[unused-argument]
+ """Ignore the token: i.e. Do not save it."""
+
def __save_token__(token, request):
- _jwt = jwt.decode(token["access_token"], jwtkey)
- _token = token_model(
- token_id=uuid.UUID(_jwt["jti"]),
- client=request.client,
- user=request.user,
- **{
- "refresh_token": None,
- "revoked": False,
- "issued_at": datetime.datetime.now(),
- **token
- })
with db.connection(current_app.config["AUTH_DB"]) as conn:
- save_token(conn, _token)
- old_refresh_token = load_refresh_token(
+ save_token(
conn,
- request.form.get("refresh_token", "nosuchtoken")
- )
- new_refresh_token = JWTRefreshToken(
- token=_token.refresh_token,
+ token_model(
+ **token,
+ token_id=uuid.uuid4(),
client=request.client,
user=request.user,
- issued_with=uuid.UUID(_jwt["jti"]),
- issued_at=datetime.datetime.fromtimestamp(_jwt["iat"]),
- expires=datetime.datetime.fromtimestamp(
- old_refresh_token.then(
- lambda _tok: _tok.expires.timestamp()
- ).maybe((int(_jwt["iat"]) +
- RefreshTokenGrant.DEFAULT_EXPIRES_IN),
- lambda _expires: _expires)),
- scope=_token.get_scope(),
+ issued_at=datetime.now(),
revoked=False,
- parent_of=None)
- save_refresh_token(conn, new_refresh_token)
- old_refresh_token.then(lambda _tok: link_child_token(
- conn, _tok.token, new_refresh_token.token))
-
- return __save_token__
+ expires_in=_TWO_HOURS_))
+ return {
+ OAuth2Token: __save_token__,
+ JWTBearerToken: __ignore_token__
+ }[token_model]
def make_jwt_token_generator(app):
"""Make token generator function."""
- _gen = JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"])
- def __generator__(# pylint: disable=[too-many-arguments]
+ def __generator__(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
grant_type,
client,
user=None,
@@ -98,19 +81,42 @@ def make_jwt_token_generator(app):
expires_in=None,# pylint: disable=[unused-argument]
include_refresh_token=True
):
- return _gen.__call__(
- grant_type,
- client,
- user,
- scope,
- JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
- include_refresh_token)
+ return JWTBearerTokenGenerator(
+ secret_key=newest_jwk_with_rotation(
+ jwks_directory(app),
+ int(app.config["JWKS_ROTATION_AGE_DAYS"])),
+ issuer=flask_request.host_url,
+ alg="RS256").__call__(
+ grant_type=grant_type,
+ client=client,
+ user=user,
+ scope=scope,
+ expires_in=expires_in,
+ include_refresh_token=include_refresh_token)
return __generator__
+
+class JsonAuthorizationServer(AuthorizationServer):
+ """An authorisation server using JSON rather than FORMDATA."""
+
+ def create_oauth2_request(self, request):
+ """Create an OAuth2 Request from the flask request."""
+ match flask_request.headers.get("Content-Type"):
+ case "application/json":
+ req = OAuth2Request(flask_request.method,
+ flask_request.url,
+ flask_request.get_json(),
+ flask_request.headers)
+ case _:
+ req = FlaskOAuth2Request(flask_request)
+
+ return req
+
+
def setup_oauth2_server(app: Flask) -> None:
"""Set's up the oauth2 server for the flask application."""
- server = AuthorizationServer()
+ server = JsonAuthorizationServer()
server.register_grant(PasswordGrant)
# Figure out a common `code_verifier` for GN2 and GN3 and set
@@ -133,11 +139,9 @@ def setup_oauth2_server(app: Flask) -> None:
server.init_app(
app,
query_client=create_query_client_func(),
- save_token=create_save_token_func(
- OAuth2Token, app.config["SSL_PRIVATE_KEY"]))
+ save_token=create_save_token_func(JWTBearerToken))
app.config["OAUTH2_SERVER"] = server
## Set up the token validators
- require_oauth.register_token_validator(BearerTokenValidator())
require_oauth.register_token_validator(
- JWTBearerTokenValidator(app.config["SSL_PRIVATE_KEY"].get_public_key()))
+ JWTBearerTokenValidator(KeySet(list_jwks(jwks_directory(app)))))
diff --git a/gn_auth/auth/authentication/oauth2/views.py b/gn_auth/auth/authentication/oauth2/views.py
index 22437a2..0e2c4eb 100644
--- a/gn_auth/auth/authentication/oauth2/views.py
+++ b/gn_auth/auth/authentication/oauth2/views.py
@@ -9,6 +9,7 @@ from flask import (
flash,
request,
url_for,
+ jsonify,
redirect,
Response,
Blueprint,
@@ -17,6 +18,7 @@ from flask import (
from gn_auth.auth.db import sqlite3 as db
from gn_auth.auth.db.sqlite3 import with_db_connection
+from gn_auth.auth.jwks import jwks_directory, list_jwks
from gn_auth.auth.errors import NotFoundError, ForbiddenAccess
from gn_auth.auth.authentication.users import valid_login, user_by_email
@@ -45,6 +47,14 @@ def authorise():
flash("Invalid OAuth2 client.", "alert-danger")
if request.method == "GET":
+ def __forgot_password_table_exists__(conn):
+ with db.cursor(conn) as cursor:
+ cursor.execute("SELECT name FROM sqlite_master "
+ "WHERE type='table' "
+ "AND name='forgot_password_tokens'")
+ return bool(cursor.fetchone())
+ return False
+
client = server.query_client(request.args.get("client_id"))
_src = urlparse(request.args["redirect_uri"])
return render_template(
@@ -53,7 +63,9 @@ def authorise():
scope=client.scope,
response_type=request.args["response_type"],
redirect_uri=request.args["redirect_uri"],
- source_uri=f"{_src.scheme}://{_src.netloc}/")
+ source_uri=f"{_src.scheme}://{_src.netloc}/",
+ display_forgot_password=with_db_connection(
+ __forgot_password_table_exists__))
form = request.form
def __authorise__(conn: db.DbConnection):
@@ -65,14 +77,15 @@ def authorise():
try:
email = validate_email(
form.get("user:email"), check_deliverability=False)
- user = user_by_email(conn, email["email"])
+ user = user_by_email(conn, email["email"]) # type: ignore
if valid_login(conn, user, form.get("user:password", "")):
if not user.verified:
return redirect(
url_for("oauth2.users.handle_unverified",
response_type=form["response_type"],
client_id=client_id,
- redirect_uri=form["redirect_uri"]),
+ redirect_uri=form["redirect_uri"],
+ email=email["email"]),
code=307)
return server.create_authorization_response(request=request, grant_user=user)
flash(email_passwd_msg, "alert-danger")
@@ -116,3 +129,13 @@ def introspect_token() -> Response:
IntrospectionEndpoint.ENDPOINT_NAME)
raise ForbiddenAccess("You cannot access this endpoint")
+
+
+@auth.route("/public-jwks", methods=["GET"])
+def public_jwks():
+ """Provide the JWK public keys used by this application."""
+ return jsonify({
+ "documentation": (
+ "The keys are listed in order of creation, from the oldest (first) "
+ "to the newest (last)."),
+ "jwks": tuple(key.as_dict() for key in list_jwks(jwks_directory(app)))})