aboutsummaryrefslogtreecommitdiff
path: root/gn2/wqflask/oauth2/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn2/wqflask/oauth2/client.py')
-rw-r--r--gn2/wqflask/oauth2/client.py157
1 files changed, 140 insertions, 17 deletions
diff --git a/gn2/wqflask/oauth2/client.py b/gn2/wqflask/oauth2/client.py
index 75e438cc..a7d20f6b 100644
--- a/gn2/wqflask/oauth2/client.py
+++ b/gn2/wqflask/oauth2/client.py
@@ -1,12 +1,17 @@
"""Common oauth2 client utilities."""
import json
+import time
+import random
import requests
from typing import Optional
from urllib.parse import urljoin
+from datetime import datetime, timedelta
from flask import current_app as app
from pymonad.either import Left, Right, Either
-from authlib.jose import jwt
+from authlib.common.urls import url_decode
+from authlib.jose.errors import BadSignatureError
+from authlib.jose import KeySet, JsonWebKey, JsonWebToken
from authlib.integrations.requests_client import OAuth2Session
from gn2.wqflask.oauth2 import session
@@ -31,11 +36,76 @@ def oauth2_clientsecret():
def user_logged_in():
"""Check whether the user has logged in."""
suser = session.session_info()["user"]
- if suser["logged_in"]:
- if session.expired():
- session.clear_session_info()
- return False
- return suser["token"].is_right()
+ return suser["logged_in"] and suser["token"].is_right()
+
+
+def __make_token_validator__(keys: KeySet):
+ """Make a token validator function."""
+ def __validator__(token: dict):
+ for key in keys.keys:
+ try:
+ # Fixes CVE-2016-10555. See
+ # https://docs.authlib.org/en/latest/jose/jwt.html
+ jwt = JsonWebToken(["RS256"])
+ jwt.decode(token["access_token"], key)
+ return Right(token)
+ except BadSignatureError:
+ pass
+
+ return Left("INVALID-TOKEN")
+
+ return __validator__
+
+
+def auth_server_jwks() -> Optional[KeySet]:
+ """Fetch the auth-server JSON Web Keys information."""
+ _jwks = session.session_info().get("auth_server_jwks")
+ if bool(_jwks):
+ return {
+ "last-updated": _jwks["last-updated"],
+ "jwks": KeySet([
+ JsonWebKey.import_key(key) for key in _jwks.get(
+ "auth_server_jwks", {}).get(
+ "jwks", {"keys": []})["keys"]])}
+
+
+def __validate_token__(keys):
+ """Validate that the token is really from the auth server."""
+ def __index__(_sess):
+ return _sess
+ return session.user_token().then(__make_token_validator__(keys)).then(
+ session.set_user_token).either(__index__, __index__)
+
+
+def __update_auth_server_jwks__():
+ """Updates the JWKs every 2 hours or so."""
+ jwks = auth_server_jwks()
+ if bool(jwks):
+ last_updated = jwks.get("last-updated")
+ now = datetime.now().timestamp()
+ if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds:
+ return __validate_token__(jwks["jwks"])
+
+ jwksuri = urljoin(authserver_uri(), "auth/public-jwks")
+ jwks = KeySet([
+ JsonWebKey.import_key(key)
+ for key in requests.get(jwksuri).json()["jwks"]])
+ return __validate_token__(jwks)
+
+
+def is_token_expired(token):
+ """Check whether the token has expired."""
+ __update_auth_server_jwks__()
+ jwks = auth_server_jwks()
+ if bool(jwks):
+ for jwk in jwks["jwks"].keys:
+ try:
+ jwt = JsonWebToken(["RS256"]).decode(
+ token["access_token"], key=jwk)
+ return datetime.now().timestamp() > jwt["exp"]
+ except BadSignatureError as _bse:
+ pass
+
return False
@@ -43,20 +113,56 @@ def oauth2_client():
def __update_token__(token, refresh_token=None, access_token=None):
"""Update the token when refreshed."""
session.set_user_token(token)
+ return token
+
+ def __delay__():
+ """Do a tiny delay."""
+ time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100))))
+
+ def __refresh_token__(token):
+ """Synchronise token refresh."""
+ if is_token_expired(token):
+ __delay__()
+ if session.is_token_refreshing():
+ while session.is_token_refreshing():
+ __delay__()
+
+ _token = session.user_token().either(None, lambda _tok: _tok)
+ return _token
+
+ session.toggle_token_refreshing()
+ _client = __client__(token)
+ _client.get(urljoin(authserver_uri(), "auth/user/"))
+ session.toggle_token_refreshing()
+ return _client.token
+
+ return token
+
+ def __json_auth__(client, method, uri, headers, body):
+ return (
+ uri,
+ {**headers, "Content-Type": "application/json"},
+ json.dumps({
+ **dict(url_decode(body)),
+ "client_id": oauth2_clientid(),
+ "client_secret": oauth2_clientsecret()
+ }))
def __client__(token) -> OAuth2Session:
- _jwt = jwt.decode(token["access_token"],
- app.config["AUTH_SERVER_SSL_PUBLIC_KEY"])
client = OAuth2Session(
oauth2_clientid(),
oauth2_clientsecret(),
scope=SCOPE,
- token_endpoint=urljoin(authserver_uri(), "/auth/token"),
+ token_endpoint=urljoin(authserver_uri(), "auth/token"),
token_endpoint_auth_method="client_secret_post",
token=token,
update_token=__update_token__)
+ client.register_client_auth_method(
+ ("client_secret_post", __json_auth__))
return client
- return session.user_token().either(
+
+ __update_auth_server_jwks__()
+ return session.user_token().then(__refresh_token__).either(
lambda _notok: __client__(None),
lambda token: __client__(token))
@@ -70,12 +176,18 @@ def __no_token__(_err) -> Left:
resp.status_code = 400
return Left(resp)
-def oauth2_get(uri_path: str, data: dict = {},
- jsonify_p: bool = False, **kwargs) -> Either:
+def oauth2_get(
+ uri_path: str,
+ data: dict = {},
+ jsonify_p: bool = False,
+ headers: dict = {"Content-Type": "application/json"},
+ **kwargs
+) -> Either:
def __get__(token) -> Either:
resp = oauth2_client().get(
urljoin(authserver_uri(), uri_path),
data=data,
+ headers=headers,
**kwargs)
if resp.status_code == 200:
if jsonify_p:
@@ -87,11 +199,18 @@ def oauth2_get(uri_path: str, data: dict = {},
return session.user_token().either(__no_token__, __get__)
def oauth2_post(
- uri_path: str, data: Optional[dict] = None, json: Optional[dict] = None,
- **kwargs) -> Either:
+ uri_path: str,
+ data: Optional[dict] = None,
+ json: Optional[dict] = None,
+ headers: dict = {"Content-Type": "application/json"},
+ **kwargs
+) -> Either:
def __post__(token) -> Either:
resp = oauth2_client().post(
- urljoin(authserver_uri(), uri_path), data=data, json=json,
+ urljoin(authserver_uri(), uri_path),
+ data=data,
+ json=json,
+ headers=headers,
**kwargs)
if resp.status_code == 200:
return Right(resp.json())
@@ -100,10 +219,14 @@ def oauth2_post(
return session.user_token().either(__no_token__, __post__)
-def no_token_get(uri_path: str, **kwargs) -> Either:
+def no_token_get(
+ uri_path: str,
+ headers: dict = {"Content-Type": "application/json"},
+ **kwargs
+) -> Either:
uri = urljoin(authserver_uri(), uri_path)
try:
- resp = requests.get(uri, **kwargs)
+ resp = requests.get(uri, headers=headers, **kwargs)
if resp.status_code == 200:
return Right(resp.json())
return Left(resp)