diff options
Diffstat (limited to 'gn_auth/auth/authentication/oauth2/models')
3 files changed, 111 insertions, 20 deletions
diff --git a/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py new file mode 100644 index 0000000..71769e1 --- /dev/null +++ b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py @@ -0,0 +1,50 @@ +"""Implement model for JWTBearerToken""" +import uuid +import time +from typing import Optional + +from authlib.oauth2.rfc7523 import JWTBearerToken as _JWTBearerToken + +from gn_auth.auth.db.sqlite3 import with_db_connection +from gn_auth.auth.authentication.users import user_by_id +from gn_auth.auth.authentication.oauth2.models.oauth2client import ( + client as fetch_client) + +class JWTBearerToken(_JWTBearerToken): + """Overrides default JWTBearerToken class.""" + + def __init__(self, payload, header, options=None, params=None): + """Initialise the bearer token.""" + # TOD0: Maybe remove this init and make this a dataclass like the way + # OAuth2Client is a dataclass + super().__init__(payload, header, options, params) + self.user = with_db_connection( + lambda conn:user_by_id(conn, uuid.UUID(payload["sub"]))) + self.client = with_db_connection( + lambda conn: fetch_client( + conn, uuid.UUID(payload["oauth2_client_id"]) + ) + ).maybe(None, lambda _client: _client) + + + def check_client(self, client): + """Check that the client is right.""" + return self.client.get_client_id() == client.get_client_id() + + + def get_expires_in(self) -> Optional[int]: + """Return the number of seconds the token is valid for since issue. + + If `None`, the token never expires.""" + if "exp" in self: + return self['exp'] - self['iat'] + return None + + + def is_expired(self): + """Check whether the token is expired. + + If there is no 'exp' member, assume this token will never expire.""" + if "exp" in self: + return self["exp"] < time.time() + return False diff --git a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py index 31c9147..46515c8 100644 --- a/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py +++ b/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py @@ -142,7 +142,7 @@ def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str): "WHERE token=:parenttoken"), {"parenttoken": parent.token, "childtoken": childtoken}) - def __check_child__(parent): + def __check_child__(parent):#pylint: disable=[unused-variable] with db.cursor(conn) as cursor: cursor.execute( ("SELECT * FROM jwt_refresh_tokens WHERE token=:parenttoken"), @@ -154,15 +154,17 @@ def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str): "activity detected.") return Right(parent) - def __revoke_and_raise_error__(_error_msg_): + def __revoke_and_raise_error__(_error_msg_):#pylint: disable=[unused-variable] load_refresh_token(conn, parenttoken).then( lambda _tok: revoke_refresh_token(conn, _tok)) raise InvalidGrantError(_error_msg_) + def __handle_not_found__(_error_msg_): + raise InvalidGrantError(_error_msg_) + load_refresh_token(conn, parenttoken).maybe( - Left("Token not found"), Right).then( - __check_child__).either(__revoke_and_raise_error__, - __link_to_child__) + Left("Token not found"), Right).either( + __handle_not_found__, __link_to_child__) def is_refresh_token_valid(token: JWTRefreshToken, client: OAuth2Client) -> bool: diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py index d31faf6..1639e2e 100644 --- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py +++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py @@ -1,17 +1,19 @@ """OAuth2 Client model.""" import json import datetime -from pathlib import Path - from uuid import UUID -from dataclasses import dataclass from functools import cached_property -from typing import Sequence, Optional +from dataclasses import asdict, dataclass +from typing import Any, Sequence, Optional +import requests +from flask import current_app as app +from requests.exceptions import JSONDecodeError from authlib.jose import KeySet, JsonWebKey from authlib.oauth2.rfc6749 import ClientMixin from pymonad.maybe import Just, Maybe, Nothing +from gn_auth.debug import __pk__ from gn_auth.auth.db import sqlite3 as db from gn_auth.auth.errors import NotFoundError from gn_auth.auth.authentication.users import (User, @@ -57,16 +59,34 @@ class OAuth2Client(ClientMixin): """ return self.client_metadata.get("client_type", "public") - @cached_property + def jwks(self) -> KeySet: """Return this client's KeySet.""" - def __parse_key__(keypath: Path) -> JsonWebKey: - with open(keypath) as _key:# pylint: disable=[unspecified-encoding] - return JsonWebKey.import_key(_key.read()) + jwksuri = self.client_metadata.get("public-jwks-uri") + __pk__(f"PUBLIC JWKs link for client {self.client_id}", jwksuri) + if not bool(jwksuri): + app.logger.debug("No Public JWKs URI set for client!") + return KeySet([]) + try: + ## IMPORTANT: This can cause a deadlock if the client is working in + ## single-threaded mode, i.e. can only serve one request + ## at a time. + return KeySet([JsonWebKey.import_key(key) + for key in requests.get( + jwksuri, + timeout=300, + allow_redirects=True).json()["jwks"]]) + except requests.ConnectionError as _connerr: + app.logger.debug( + "Could not connect to provided URI: %s", jwksuri, exc_info=True) + except JSONDecodeError as _jsonerr: + app.logger.debug( + "Could not convert response to JSON", exc_info=True) + except Exception as _exc:# pylint: disable=[broad-except] + app.logger.debug( + "Error retrieving the JWKs for the client.", exc_info=True) + return KeySet([]) - return KeySet([ - __parse_key__(Path(pth)) - for pth in self.client_metadata.get("public_keys", [])]) def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool: """ @@ -77,12 +97,9 @@ class OAuth2Client(ClientMixin): * client_secret_post: Client uses the HTTP POST parameters * client_secret_basic: Client uses HTTP Basic """ - if endpoint == "token": + if endpoint in ("token", "revoke", "introspection"): return (method in self.token_endpoint_auth_method and method == "client_secret_post") - if endpoint in ("introspection", "revoke"): - return (method in self.token_endpoint_auth_method - and method == "client_secret_basic") return False @cached_property @@ -277,3 +294,25 @@ def delete_client( cursor.execute("DELETE FROM oauth2_tokens WHERE client_id=?", params) cursor.execute("DELETE FROM oauth2_clients WHERE client_id=?", params) return the_client + + +def update_client_attribute( + client: OAuth2Client,# pylint: disable=[redefined-outer-name] + attribute: str, + value: Any +) -> OAuth2Client: + """Return a new OAuth2Client with the given attribute updated/changed.""" + attrs = { + attr: type(value) + for attr, value in asdict(client).items() + if attr != "client_id" + } + assert ( + attribute in attrs.keys() and isinstance(value, attrs[attribute])), ( + "Invalid attribute/value provided!") + return OAuth2Client( + client_id=client.client_id, + **{ + attr: (value if attr==attribute else getattr(client, attr)) + for attr in attrs + }) |