aboutsummaryrefslogtreecommitdiff
path: root/gn2/wqflask/oauth2
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-08-01 12:19:34 -0500
committerFrederick Muriuki Muriithi2024-08-01 15:02:17 -0500
commit20933e55de7927063dd159d116b468b4724a19a8 (patch)
treef8df653cf1ad4ebe590f96d66d4da1e14fe8f3f3 /gn2/wqflask/oauth2
parent6095b4556bf47a29074be76a72d99681e263d3db (diff)
downloadgenenetwork2-20933e55de7927063dd159d116b468b4724a19a8.tar.gz
Use JWKs from auth server public endpoint
* Fetch keys from auth server * Validate token is signed with one of the keys from server * Ensure refreshing of token is still synchronised
Diffstat (limited to 'gn2/wqflask/oauth2')
-rw-r--r--gn2/wqflask/oauth2/client.py95
-rw-r--r--gn2/wqflask/oauth2/session.py8
2 files changed, 83 insertions, 20 deletions
diff --git a/gn2/wqflask/oauth2/client.py b/gn2/wqflask/oauth2/client.py
index 0d4615e8..6f137f52 100644
--- a/gn2/wqflask/oauth2/client.py
+++ b/gn2/wqflask/oauth2/client.py
@@ -5,10 +5,12 @@ import random
import requests
from typing import Optional
from urllib.parse import urljoin
+from datetime import datetime, timedelta
from flask import current_app as app
from pymonad.either import Left, Right, Either
-from authlib.jose import jwt
+from authlib.jose import KeySet, JsonWebKey, JsonWebToken
+from authlib.jose.errors import BadSignatureError
from authlib.integrations.requests_client import OAuth2Session
from gn2.wqflask.oauth2 import session
@@ -36,30 +38,96 @@ def user_logged_in():
return suser["logged_in"] and suser["token"].is_right()
+def __make_token_validator__(keys: KeySet):
+ """Make a token validator function."""
+ def __validator__(token: dict):
+ for key in keys.keys:
+ try:
+ # Fixes CVE-2016-10555. See
+ # https://docs.authlib.org/en/latest/jose/jwt.html
+ jwt = JsonWebToken(["RS256"])
+ jwt.decode(token["access_token"], key)
+ return Right(token)
+ except BadSignatureError:
+ pass
+
+ return Left("INVALID-TOKEN")
+
+ return __validator__
+
+
+def auth_server_jwks() -> Optional[KeySet]:
+ """Fetch the auth-server JSON Web Keys information."""
+ _jwks = session.session_info().get("auth_server_jwks")
+ if bool(_jwks):
+ return {
+ "last-updated": _jwks["last-updated"],
+ "jwks": KeySet([
+ JsonWebKey.import_key(key) for key in _jwks.get(
+ "auth_server_jwks", {}).get(
+ "jwks", {"keys": []})["keys"]])}
+
+
+def __validate_token__(keys):
+ """Validate that the token is really from the auth server."""
+ def __index__(_sess):
+ return _sess
+ return session.user_token().then(__make_token_validator__(keys)).then(
+ session.set_user_token).either(__index__, __index__)
+
+
+def __update_auth_server_jwks__():
+ """Updates the JWKs every 2 hours or so."""
+ jwks = auth_server_jwks()
+ if bool(jwks):
+ last_updated = jwks.get("last-updated")
+ now = datetime.now().timestamp()
+ if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds:
+ return __validate_token__(jwks["jwks"])
+
+ jwksuri = urljoin(authserver_uri(), "auth/public-jwks")
+ jwks = KeySet([
+ JsonWebKey.import_key(key)
+ for key in requests.get(jwksuri).json()["jwks"]])
+ return __validate_token__(jwks)
+
+
+def is_token_expired(token):
+ """Check whether the token has expired."""
+ __update_auth_server_jwks__()
+ jwks = auth_server_jwks()
+ if bool(jwks):
+ for jwk in jwks["jwks"].keys:
+ try:
+ jwt = JsonWebToken(["RS256"]).decode(
+ token["access_token"], key=jwk)
+ return datetime.now().timestamp() > jwt["exp"]
+ except BadSignatureError as _bse:
+ pass
+
+ return False
+
+
def oauth2_client():
def __update_token__(token, refresh_token=None, access_token=None):
"""Update the token when refreshed."""
session.set_user_token(token)
return token
- def __validate_token__(token):
- _jwt = jwt.decode(token["access_token"],
- app.config["AUTH_SERVER_SSL_PUBLIC_KEY"])
- return token
-
def __delay__():
"""Do a tiny delay."""
time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100))))
def __refresh_token__(token):
"""Synchronise token refresh."""
- if session.is_token_expired():
+ if is_token_expired(token):
__delay__()
if session.is_token_refreshing():
while session.is_token_refreshing():
__delay__()
- _token = session.user_token().either(None, lambda _tok: _tok)
- return _token
+
+ _token = session.user_token().either(None, lambda _tok: _tok)
+ return _token
session.toggle_token_refreshing()
_client = __client__(token)
@@ -79,10 +147,11 @@ def oauth2_client():
token=token,
update_token=__update_token__)
return client
- return session.user_token().then(__validate_token__).then(
- __refresh_token__).either(
- lambda _notok: __client__(None),
- lambda token: __client__(token))
+
+ __update_auth_server_jwks__()
+ return session.user_token().then(__refresh_token__).either(
+ lambda _notok: __client__(None),
+ lambda token: __client__(token))
def __no_token__(_err) -> Left:
"""Handle situation where request is attempted with no token."""
diff --git a/gn2/wqflask/oauth2/session.py b/gn2/wqflask/oauth2/session.py
index 92181ccf..b91534b0 100644
--- a/gn2/wqflask/oauth2/session.py
+++ b/gn2/wqflask/oauth2/session.py
@@ -23,6 +23,7 @@ class SessionInfo(TypedDict):
ip_addr: str
masquerade: Optional[UserDetails]
refreshing_token: bool
+ auth_server_jwks: Optional[dict[str, Any]]
__SESSION_KEY__ = "GN::2::session_info" # Do not use this outside this module!!
@@ -114,13 +115,6 @@ def toggle_token_refreshing():
"token_refreshing": not _session.get("token_refreshing", False)})
-def is_token_expired():
- """Check whether the token is expired."""
- return user_token().either(
- lambda _no_token: False,
- lambda token: datetime.now().timestamp() > token["expires_at"])
-
-
def is_token_refreshing():
"""Returns whether the token is being refreshed or not."""
return session_info().get("token_refreshing", False)