about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-07-19 12:24:04 -0500
committerFrederick Muriuki Muriithi2024-07-31 09:30:24 -0500
commit12453be99f2bf21842a9e488ecf72c06a06625f0 (patch)
tree039dc3b5b7ca80f73a98b8cee1a8f35b05cc609a
parent6b18e1f0b05222d84fd0b06a8e5c2780df6958d5 (diff)
downloadgn-auth-12453be99f2bf21842a9e488ecf72c06a06625f0.tar.gz
Fetch a client's JWKs from a URI
-rw-r--r--gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py2
-rw-r--r--gn_auth/auth/authentication/oauth2/models/oauth2client.py33
2 files changed, 25 insertions, 10 deletions
diff --git a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
index b0f2cc7..1f53186 100644
--- a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
+++ b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
@@ -74,7 +74,7 @@ class JWTBearerGrant(_JWTBearerGrant):
 
     def resolve_client_key(self, client, headers, payload):
         """Resolve client key to decode assertion data."""
-        return app.config["SSL_PUBLIC_KEYS"].get(headers["kid"])
+        return client.jwks().find_by_kid(headers["kid"])
 
 
     def authenticate_user(self, subject):
diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
index d31faf6..1413722 100644
--- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py
+++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
@@ -1,13 +1,14 @@
 """OAuth2 Client model."""
 import json
+import logging
 import datetime
-from pathlib import Path
-
 from uuid import UUID
 from dataclasses import dataclass
 from functools import cached_property
 from typing import Sequence, Optional
 
+import requests
+from requests.exceptions import JSONDecodeError
 from authlib.jose import KeySet, JsonWebKey
 from authlib.oauth2.rfc6749 import ClientMixin
 from pymonad.maybe import Just, Maybe, Nothing
@@ -57,16 +58,30 @@ class OAuth2Client(ClientMixin):
         """
         return self.client_metadata.get("client_type", "public")
 
-    @cached_property
+
     def jwks(self) -> KeySet:
         """Return this client's KeySet."""
-        def __parse_key__(keypath: Path) -> JsonWebKey:
-            with open(keypath) as _key:# pylint: disable=[unspecified-encoding]
-                return JsonWebKey.import_key(_key.read())
+        jwksuri = self.client_metadata.get("public-jwks-uri")
+        if not bool(jwksuri):
+            logging.debug("No Public JWKs URI set for client!")
+            return KeySet([])
+        try:
+            ## IMPORTANT: This can cause a deadlock if the client is working in
+            ##            single-threaded mode, i.e. can only serve one request
+            ##            at a time.
+            return KeySet([JsonWebKey.import_key(key)
+                           for key in requests.get(jwksuri).json()["jwks"]])
+        except requests.ConnectionError as _connerr:
+            logging.debug(
+                "Could not connect to provided URI: %s", jwksuri, exc_info=True)
+        except JSONDecodeError as _jsonerr:
+            logging.debug(
+                "Could not convert response to JSON", exc_info=True)
+        except Exception as _exc:# pylint: disable=[broad-except]
+            logging.debug(
+                "Error retrieving the JWKs for the client.", exc_info=True)
+        return KeySet([])
 
-        return KeySet([
-            __parse_key__(Path(pth))
-            for pth in self.client_metadata.get("public_keys", [])])
 
     def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool:
         """