diff options
Diffstat (limited to 'gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py')
-rw-r--r-- | gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py | 65 |
1 files changed, 58 insertions, 7 deletions
diff --git a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py index 1f53186..c802091 100644 --- a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py +++ b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py @@ -1,15 +1,21 @@ """JWT as Authorisation Grant""" import uuid +import time +from typing import Optional from flask import current_app as app +from authlib.jose import jwt +from authlib.common.encoding import to_native from authlib.common.security import generate_token from authlib.oauth2.rfc7523.jwt_bearer import JWTBearerGrant as _JWTBearerGrant from authlib.oauth2.rfc7523.token import ( JWTBearerTokenGenerator as _JWTBearerTokenGenerator) +from gn_auth.debug import __pk__ from gn_auth.auth.db.sqlite3 import with_db_connection -from gn_auth.auth.authentication.users import user_by_id +from gn_auth.auth.authentication.users import User, user_by_id +from gn_auth.auth.authentication.oauth2.models.oauth2client import OAuth2Client class JWTBearerTokenGenerator(_JWTBearerTokenGenerator): @@ -19,23 +25,66 @@ class JWTBearerTokenGenerator(_JWTBearerTokenGenerator): DEFAULT_EXPIRES_IN = 300 - def get_token_data(#pylint: disable=[too-many-arguments] + def get_token_data(#pylint: disable=[too-many-arguments, too-many-positional-arguments] self, grant_type, client, expires_in=None, user=None, scope=None ): """Post process data to prevent JSON serialization problems.""" - tokendata = super().get_token_data( - grant_type, client, expires_in, user, scope) + issued_at = int(time.time()) + tokendata = { + "scope": self.get_allowed_scope(client, scope), + "grant_type": grant_type, + "iat": issued_at, + "client_id": client.get_client_id() + } + if isinstance(expires_in, int) and expires_in > 0: + tokendata["exp"] = issued_at + expires_in + if self.issuer: + tokendata["iss"] = self.issuer + if user: + tokendata["sub"] = self.get_sub_value(user) + return { **{ key: str(value) if key.endswith("_id") else value for key, value in tokendata.items() }, "sub": str(tokendata["sub"]), - "jti": str(uuid.uuid4()) + "jti": str(uuid.uuid4()), + "oauth2_client_id": str(client.client_id) } + def generate(# pylint: disable=[too-many-arguments, too-many-positional-arguments] + self, + grant_type: str, + client: OAuth2Client, + user: Optional[User] = None, + scope: Optional[str] = None, + expires_in: Optional[int] = None + ) -> dict: + """Generate a bearer token for OAuth 2.0 authorization token endpoint. + + :param client: the client that making the request. + :param grant_type: current requested grant_type. + :param user: current authorized user. + :param expires_in: if provided, use this value as expires_in. + :param scope: current requested scope. + :return: Token dict + """ + + token_data = self.get_token_data(grant_type, client, expires_in, user, scope) + access_token = jwt.encode({"alg": self.alg}, token_data, key=self.secret_key, check=False) + token = { + "token_type": "Bearer", + "access_token": to_native(access_token) + } + if expires_in: + token["expires_in"] = expires_in + if scope: + token["scope"] = scope + return token + - def __call__(# pylint: disable=[too-many-arguments] + def __call__(# pylint: disable=[too-many-arguments, too-many-positional-arguments] self, grant_type, client, user=None, scope=None, expires_in=None, include_refresh_token=True ): @@ -74,7 +123,9 @@ class JWTBearerGrant(_JWTBearerGrant): def resolve_client_key(self, client, headers, payload): """Resolve client key to decode assertion data.""" - return client.jwks().find_by_kid(headers["kid"]) + keyset = client.jwks() + __pk__("THE KEYSET =======>", keyset.keys) + return keyset.find_by_kid(headers["kid"]) def authenticate_user(self, subject): |