aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--uploader/oauth2/client.py79
1 files changed, 34 insertions, 45 deletions
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py
index a3e4ba3..8ad8f7c 100644
--- a/uploader/oauth2/client.py
+++ b/uploader/oauth2/client.py
@@ -8,10 +8,8 @@ from flask import request, current_app as app
from pymonad.either import Left, Right, Either
-from authlib.jose import jwt
from authlib.common.urls import url_decode
from authlib.jose import KeySet, JsonWebKey
-from authlib.jose.errors import BadSignatureError
from authlib.integrations.requests_client import OAuth2Session
from uploader import session
@@ -36,49 +34,41 @@ def oauth2_clientsecret():
return app.config["OAUTH2_CLIENT_SECRET"]
-def __make_token_validator__(keys: KeySet):
- """Make a token validator function."""
- def __validator__(token: dict):
- for key in keys.keys:
- try:
- jwt.decode(token["access_token"], key)
- return Right(token)
- except BadSignatureError:
- pass
-
- return Left("INVALID-TOKEN")
-
- return __validator__
-
-
-def __validate_token__(sess_info):
- """Validate that the token is really from the auth server."""
- info = sess_info
- info["user"]["token"] = info["user"]["token"].then(__make_token_validator__(
- KeySet([JsonWebKey.import_key(key) for key in info.get(
- "auth_server_jwks", {}).get(
- "jwks", {"keys": []})["keys"]])))
- return session.save_session_info(info)
-
-
-def __update_auth_server_jwks__(sess_info):
- """Updates the JWKs every 2 hours or so."""
- jwks = sess_info.get("auth_server_jwks")
- if bool(jwks):
- last_updated = jwks.get("last-updated")
- now = datetime.now().timestamp()
- if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds:
- return __validate_token__({**sess_info, "auth_server_jwks": jwks})
-
- jwksuri = urljoin(authserver_uri(), "auth/public-jwks")
- return __validate_token__({
- **sess_info,
- "auth_server_jwks": {
+def __fetch_auth_server_jwks__() -> KeySet:
+ """Fetch the JWKs from the auth server."""
+ return KeySet([
+ JsonWebKey.import_key(key)
+ for key in requests.get(
+ urljoin(authserver_uri(), "auth/public-jwks")
+ ).json()["jwks"]])
+
+
+def __update_auth_server_jwks__(jwks) -> KeySet:
+ """Update the JWKs from the servers if necessary."""
+ last_updated = jwks["last-updated"]
+ now = datetime.now().timestamp()
+ # Maybe the `two_hours` variable below can be made into a configuration
+ # variable and passed in to this function
+ two_hours = timedelta(hours=2).seconds
+ if bool(last_updated) and (now - last_updated) < two_hours:
+ return jwks["jwks"]
+
+ return session.set_auth_server_jwks(__fetch_auth_server_jwks__())
+
+
+def auth_server_jwks() -> KeySet:
+ """Fetch the auth-server JSON Web Keys information."""
+ _jwks = session.session_info().get("auth_server_jwks")
+ if bool(_jwks):
+ return __update_auth_server_jwks__({
+ "last-updated": _jwks["last-updated"],
"jwks": KeySet([
- JsonWebKey.import_key(key)
- for key in requests.get(jwksuri).json()["jwks"]]).as_dict(),
- "last-updated": datetime.now().timestamp()
- }
+ JsonWebKey.import_key(key) for key in _jwks.get(
+ "jwks", {"keys": []})["keys"]])
+ })
+
+ return __update_auth_server_jwks__({
+ "last-updated": (datetime.now() - timedelta(hours=3)).timestamp()
})
@@ -111,7 +101,6 @@ def oauth2_client():
("client_secret_post", __json_auth__))
return client
- __update_auth_server_jwks__(session.session_info())
return session.user_token().either(
lambda _notok: __client__(None),
__client__)