diff options
Diffstat (limited to 'gn_auth/auth/authentication')
3 files changed, 45 insertions, 4 deletions
diff --git a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py index 1f53186..27783ac 100644 --- a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py +++ b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py @@ -8,6 +8,7 @@ from authlib.oauth2.rfc7523.jwt_bearer import JWTBearerGrant as _JWTBearerGrant from authlib.oauth2.rfc7523.token import ( JWTBearerTokenGenerator as _JWTBearerTokenGenerator) +from gn_auth.debug import __pk__ from gn_auth.auth.db.sqlite3 import with_db_connection from gn_auth.auth.authentication.users import user_by_id @@ -31,7 +32,8 @@ class JWTBearerTokenGenerator(_JWTBearerTokenGenerator): for key, value in tokendata.items() }, "sub": str(tokendata["sub"]), - "jti": str(uuid.uuid4()) + "jti": str(uuid.uuid4()), + "oauth2_client_id": str(client.client_id) } @@ -74,7 +76,9 @@ class JWTBearerGrant(_JWTBearerGrant): def resolve_client_key(self, client, headers, payload): """Resolve client key to decode assertion data.""" - return client.jwks().find_by_kid(headers["kid"]) + keyset = client.jwks() + __pk__("THE KEYSET =======>", keyset.keys) + return keyset.find_by_kid(headers["kid"]) def authenticate_user(self, subject): diff --git a/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py index 2606ac6..cca75f4 100644 --- a/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py +++ b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py @@ -5,11 +5,26 @@ from authlib.oauth2.rfc7523 import JWTBearerToken as _JWTBearerToken from gn_auth.auth.db.sqlite3 import with_db_connection from gn_auth.auth.authentication.users import user_by_id +from gn_auth.auth.authentication.oauth2.models.oauth2client import ( + client as fetch_client) class JWTBearerToken(_JWTBearerToken): """Overrides default JWTBearerToken class.""" def __init__(self, payload, header, options=None, params=None): + """Initialise the bearer token.""" + # TOD0: Maybe remove this init and make this a dataclass like the way + # OAuth2Client is a dataclass super().__init__(payload, header, options, params) self.user = with_db_connection( lambda conn:user_by_id(conn, uuid.UUID(payload["sub"]))) + self.client = with_db_connection( + lambda conn: fetch_client( + conn, uuid.UUID(payload["oauth2_client_id"]) + ) + ).maybe(None, lambda _client: _client) + + + def check_client(self, client): + """Check that the client is right.""" + return self.client.get_client_id() == client.get_client_id() diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py index 8fac648..df5d564 100644 --- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py +++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py @@ -3,9 +3,9 @@ import json import logging import datetime from uuid import UUID -from dataclasses import dataclass from functools import cached_property -from typing import Sequence, Optional +from dataclasses import asdict, dataclass +from typing import Any, Sequence, Optional import requests from requests.exceptions import JSONDecodeError @@ -289,3 +289,25 @@ def delete_client( cursor.execute("DELETE FROM oauth2_tokens WHERE client_id=?", params) cursor.execute("DELETE FROM oauth2_clients WHERE client_id=?", params) return the_client + + +def update_client_attribute( + client: OAuth2Client,# pylint: disable=[redefined-outer-name] + attribute: str, + value: Any +) -> OAuth2Client: + """Return a new OAuth2Client with the given attribute updated/changed.""" + attrs = { + attr: type(value) + for attr, value in asdict(client).items() + if attr != "client_id" + } + assert ( + attribute in attrs.keys() and isinstance(value, attrs[attribute])), ( + "Invalid attribute/value provided!") + return OAuth2Client( + client_id=client.client_id, + **{ + attr: (value if attr==attribute else getattr(client, attr)) + for attr in attrs + }) |