about summary refs log tree commit diff
path: root/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py')
-rw-r--r--gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py65
1 files changed, 58 insertions, 7 deletions
diff --git a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
index b0f2cc7..c802091 100644
--- a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
+++ b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
@@ -1,15 +1,21 @@
 """JWT as Authorisation Grant"""
 import uuid
+import time
 
+from typing import Optional
 from flask import current_app as app
 
+from authlib.jose import jwt
+from authlib.common.encoding import to_native
 from authlib.common.security import generate_token
 from authlib.oauth2.rfc7523.jwt_bearer import JWTBearerGrant as _JWTBearerGrant
 from authlib.oauth2.rfc7523.token import (
     JWTBearerTokenGenerator as _JWTBearerTokenGenerator)
 
+from gn_auth.debug import __pk__
 from gn_auth.auth.db.sqlite3 import with_db_connection
-from gn_auth.auth.authentication.users import user_by_id
+from gn_auth.auth.authentication.users import User, user_by_id
+from gn_auth.auth.authentication.oauth2.models.oauth2client import OAuth2Client
 
 
 class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
@@ -19,23 +25,66 @@ class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
 
     DEFAULT_EXPIRES_IN = 300
 
-    def get_token_data(#pylint: disable=[too-many-arguments]
+    def get_token_data(#pylint: disable=[too-many-arguments, too-many-positional-arguments]
             self, grant_type, client, expires_in=None, user=None, scope=None
     ):
         """Post process data to prevent JSON serialization problems."""
-        tokendata = super().get_token_data(
-            grant_type, client, expires_in, user, scope)
+        issued_at = int(time.time())
+        tokendata = {
+            "scope": self.get_allowed_scope(client, scope),
+            "grant_type": grant_type,
+            "iat": issued_at,
+            "client_id": client.get_client_id()
+        }
+        if isinstance(expires_in, int) and expires_in > 0:
+            tokendata["exp"] = issued_at + expires_in
+        if self.issuer:
+            tokendata["iss"] = self.issuer
+        if user:
+            tokendata["sub"] = self.get_sub_value(user)
+
         return {
             **{
                 key: str(value) if key.endswith("_id") else value
                 for key, value in tokendata.items()
             },
             "sub": str(tokendata["sub"]),
-            "jti": str(uuid.uuid4())
+            "jti": str(uuid.uuid4()),
+            "oauth2_client_id": str(client.client_id)
         }
 
+    def generate(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
+            self,
+            grant_type: str,
+            client: OAuth2Client,
+            user: Optional[User] = None,
+            scope: Optional[str] = None,
+            expires_in: Optional[int] = None
+    ) -> dict:
+        """Generate a bearer token for OAuth 2.0 authorization token endpoint.
+
+        :param client: the client that making the request.
+        :param grant_type: current requested grant_type.
+        :param user: current authorized user.
+        :param expires_in: if provided, use this value as expires_in.
+        :param scope: current requested scope.
+        :return: Token dict
+        """
+
+        token_data = self.get_token_data(grant_type, client, expires_in, user, scope)
+        access_token = jwt.encode({"alg": self.alg}, token_data, key=self.secret_key, check=False)
+        token = {
+            "token_type": "Bearer",
+            "access_token": to_native(access_token)
+        }
+        if expires_in:
+            token["expires_in"] = expires_in
+        if scope:
+            token["scope"] = scope
+        return token
+
 
-    def __call__(# pylint: disable=[too-many-arguments]
+    def __call__(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
             self, grant_type, client, user=None, scope=None, expires_in=None,
             include_refresh_token=True
     ):
@@ -74,7 +123,9 @@ class JWTBearerGrant(_JWTBearerGrant):
 
     def resolve_client_key(self, client, headers, payload):
         """Resolve client key to decode assertion data."""
-        return app.config["SSL_PUBLIC_KEYS"].get(headers["kid"])
+        keyset = client.jwks()
+        __pk__("THE KEYSET =======>", keyset.keys)
+        return keyset.find_by_kid(headers["kid"])
 
 
     def authenticate_user(self, subject):