aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/auth/authentication/oauth2/grants
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth/auth/authentication/oauth2/grants')
-rw-r--r--gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py26
1 files changed, 25 insertions, 1 deletions
diff --git a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
index b0f2cc7..d75f730 100644
--- a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
+++ b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
@@ -3,6 +3,8 @@ import uuid
from flask import current_app as app
+from dataclasses import asdict
+
from authlib.common.security import generate_token
from authlib.oauth2.rfc7523.jwt_bearer import JWTBearerGrant as _JWTBearerGrant
from authlib.oauth2.rfc7523.token import (
@@ -10,6 +12,22 @@ from authlib.oauth2.rfc7523.token import (
from gn_auth.auth.db.sqlite3 import with_db_connection
from gn_auth.auth.authentication.users import user_by_id
+from gn_auth.auth.authorisation.roles.models import user_roles
+
+
+def convert_uuids_to_string(srcdict: dict) -> dict:
+ """
+ Convert *ALL* UUID objects in a dict to strings.
+
+ `json.dumps` does not encode UUID objects by default.
+ """
+ def uuid2str(key, value):
+ if isinstance(value, dict):
+ return (key, convert_uuids_to_string(value))
+ if isinstance(value, uuid.UUID):
+ return (key, str(value))
+ return (key, value)
+ return dict(tuple(uuid2str(_key, _val) for _key, _val in srcdict.items()))
class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
@@ -31,7 +49,13 @@ class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
for key, value in tokendata.items()
},
"sub": str(tokendata["sub"]),
- "jti": str(uuid.uuid4())
+ "jti": str(uuid.uuid4()),
+ "gn:auth:user:roles": tuple(convert_uuids_to_string({
+ **item,
+ "roles": tuple(convert_uuids_to_string(asdict(role))
+ for role in item["roles"])
+ }) for item in with_db_connection(
+ lambda conn: user_roles(conn, user)))
}