diff options
Diffstat (limited to 'uploader/oauth2')
| -rw-r--r-- | uploader/oauth2/client.py | 142 | ||||
| -rw-r--r-- | uploader/oauth2/tokens.py | 47 | ||||
| -rw-r--r-- | uploader/oauth2/views.py | 62 |
3 files changed, 157 insertions, 94 deletions
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py index a3e4ba3..b94a044 100644 --- a/uploader/oauth2/client.py +++ b/uploader/oauth2/client.py @@ -1,5 +1,8 @@ """OAuth2 client utilities.""" import json +import time +import uuid +import random from datetime import datetime, timedelta from urllib.parse import urljoin, urlparse @@ -8,10 +11,9 @@ from flask import request, current_app as app from pymonad.either import Left, Right, Either -from authlib.jose import jwt from authlib.common.urls import url_decode -from authlib.jose import KeySet, JsonWebKey from authlib.jose.errors import BadSignatureError +from authlib.jose import KeySet, JsonWebKey, JsonWebToken from authlib.integrations.requests_client import OAuth2Session from uploader import session @@ -36,49 +38,42 @@ def oauth2_clientsecret(): return app.config["OAUTH2_CLIENT_SECRET"] -def __make_token_validator__(keys: KeySet): - """Make a token validator function.""" - def __validator__(token: dict): - for key in keys.keys: - try: - jwt.decode(token["access_token"], key) - return Right(token) - except BadSignatureError: - pass - - return Left("INVALID-TOKEN") - - return __validator__ - - -def __validate_token__(sess_info): - """Validate that the token is really from the auth server.""" - info = sess_info - info["user"]["token"] = info["user"]["token"].then(__make_token_validator__( - KeySet([JsonWebKey.import_key(key) for key in info.get( - "auth_server_jwks", {}).get( - "jwks", {"keys": []})["keys"]]))) - return session.save_session_info(info) - - -def __update_auth_server_jwks__(sess_info): - """Updates the JWKs every 2 hours or so.""" - jwks = sess_info.get("auth_server_jwks") - if bool(jwks): - last_updated = jwks.get("last-updated") - now = datetime.now().timestamp() - if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds: - return __validate_token__({**sess_info, "auth_server_jwks": jwks}) - - jwksuri = urljoin(authserver_uri(), "auth/public-jwks") - return __validate_token__({ - **sess_info, - "auth_server_jwks": { +def __fetch_auth_server_jwks__() -> KeySet: + """Fetch the JWKs from the auth server.""" + return KeySet([ + JsonWebKey.import_key(key) + for key in requests.get( + urljoin(authserver_uri(), "auth/public-jwks"), + timeout=(9.13, 20) + ).json()["jwks"]]) + + +def __update_auth_server_jwks__(jwks) -> KeySet: + """Update the JWKs from the servers if necessary.""" + last_updated = jwks["last-updated"] + now = datetime.now().timestamp() + # Maybe the `two_hours` variable below can be made into a configuration + # variable and passed in to this function + two_hours = timedelta(hours=2).seconds + if bool(last_updated) and (now - last_updated) < two_hours: + return jwks["jwks"] + + return session.set_auth_server_jwks(__fetch_auth_server_jwks__()) + + +def auth_server_jwks() -> KeySet: + """Fetch the auth-server JSON Web Keys information.""" + _jwks = session.session_info().get("auth_server_jwks") or {} + if bool(_jwks): + return __update_auth_server_jwks__({ + "last-updated": _jwks["last-updated"], "jwks": KeySet([ - JsonWebKey.import_key(key) - for key in requests.get(jwksuri).json()["jwks"]]).as_dict(), - "last-updated": datetime.now().timestamp() - } + JsonWebKey.import_key(key) for key in _jwks.get( + "jwks", {"keys": []})["keys"]]) + }) + + return __update_auth_server_jwks__({ + "last-updated": (datetime.now() - timedelta(hours=3)).timestamp() }) @@ -111,15 +106,66 @@ def oauth2_client(): ("client_secret_post", __json_auth__)) return client - __update_auth_server_jwks__(session.session_info()) - return session.user_token().either( + def __token_expired__(token): + """Check whether the token has expired.""" + jwks = auth_server_jwks() + if bool(jwks): + for jwk in jwks.keys: + try: + jwt = JsonWebToken(["RS256"]).decode( + token["access_token"], key=jwk) + if bool(jwt.get("exp")): + return datetime.now().timestamp() > jwt["exp"] + except BadSignatureError as _bse: + pass + + return False + + def __delay__(): + """Do a tiny delay.""" + time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100)))) + + def __refresh_token__(token): + """Refresh the token if necessary — synchronise amongst threads.""" + if __token_expired__(token): + __delay__() + if session.is_token_refreshing(): + while session.is_token_refreshing(): + __delay__() + + return session.user_token().either(None, lambda _tok: _tok) + + session.toggle_token_refreshing() + _client = __client__(token) + _client.get(urljoin(authserver_uri(), "auth/user/")) + session.toggle_token_refreshing() + return _client.token + + return token + + return session.user_token().then(__refresh_token__).either( lambda _notok: __client__(None), __client__) +def fetch_user_details() -> Either: + """Retrieve user details from the auth server""" + suser = session.session_info()["user"] + if suser["email"] == "anon@ymous.user": + udets = oauth2_get("auth/user/").then( + lambda usrdets: session.set_user_details({ + "user_id": uuid.UUID(usrdets["user_id"]), + "name": usrdets["name"], + "email": usrdets["email"], + "token": session.user_token()})) + return udets + return Right(suser) + + def user_logged_in(): """Check whether the user has logged in.""" suser = session.session_info()["user"] + fetch_user_details() return suser["logged_in"] and suser["token"].is_right() @@ -163,7 +209,7 @@ def oauth2_get(url, **kwargs) -> Either: return Right(resp.json()) return Left(resp) except Exception as exc:#pylint: disable=[broad-except] - app.logger.error("Error retriving data from auth server: (GET %s)", + app.logger.error("Error retrieving data from auth server: (GET %s)", _uri, exc_info=True) return Left(exc) @@ -195,7 +241,7 @@ def oauth2_post(url, data=None, json=None, **kwargs):#pylint: disable=[redefined return Right(resp.json()) return Left(resp) except Exception as exc:#pylint: disable=[broad-except] - app.logger.error("Error retriving data from auth server: (POST %s)", + app.logger.error("Error retrieving data from auth server: (POST %s)", _uri, exc_info=True) return Left(exc) diff --git a/uploader/oauth2/tokens.py b/uploader/oauth2/tokens.py new file mode 100644 index 0000000..eb650f6 --- /dev/null +++ b/uploader/oauth2/tokens.py @@ -0,0 +1,47 @@ +"""Utilities for dealing with tokens.""" +import uuid +from typing import Union +from urllib.parse import urljoin +from datetime import datetime, timedelta + +from authlib.jose import jwt +from flask import current_app as app + +from uploader import monadic_requests as mrequests + +from . import jwks +from .client import (SCOPE, authserver_uri, oauth2_clientid) + + +def request_token(token_uri: str, user_id: Union[uuid.UUID, str], **kwargs): + """Request token from the auth server.""" + issued = datetime.now() + jwtkey = jwks.newest_jwk_with_rotation( + jwks.jwks_directory(app, "UPLOADER_SECRETS"), + int(app.config["JWKS_ROTATION_AGE_DAYS"])) + _mins2expiry = kwargs.get("minutes_to_expiry", 5) + return mrequests.post( + token_uri, + json={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": kwargs.get("scope", SCOPE), + "assertion": jwt.encode( + header={ + "alg": "RS256", + "typ": "JWT", + "kid": jwtkey.as_dict()["kid"] + }, + payload={ + "iss": str(oauth2_clientid()), + "sub": str(user_id), + "aud": urljoin(authserver_uri(), "auth/token"), + "exp": (issued + timedelta(minutes=_mins2expiry)).timestamp(), + "nbf": int(issued.timestamp()), + "iat": int(issued.timestamp()), + "jti": str(uuid.uuid4()) + }, + key=jwtkey).decode("utf8"), + "client_id": oauth2_clientid(), + **kwargs.get("extra_params", {}) + } + ) diff --git a/uploader/oauth2/views.py b/uploader/oauth2/views.py index 61037f3..05f8542 100644 --- a/uploader/oauth2/views.py +++ b/uploader/oauth2/views.py @@ -1,45 +1,43 @@ """Views for OAuth2 related functionality.""" -import uuid -from datetime import datetime, timedelta from urllib.parse import urljoin, urlparse, urlunparse -from authlib.jose import jwt from flask import ( flash, jsonify, - url_for, request, redirect, Blueprint, current_app as app) from uploader import session +from uploader.flask_extensions import url_for from uploader import monadic_requests as mrequests from uploader.monadic_requests import make_error_handler from . import jwks +from .tokens import request_token from .client import ( - SCOPE, - oauth2_get, user_logged_in, authserver_uri, oauth2_clientid, + fetch_user_details, oauth2_clientsecret) oauth2 = Blueprint("oauth2", __name__) + @oauth2.route("/code") def authorisation_code(): """Receive authorisation code from auth server and use it to get token.""" - def __process_error__(resp_or_exception): - app.logger.debug("ERROR: (%s)", resp_or_exception) + def __process_error__(error_response): + app.logger.debug("ERROR: (%s)", error_response.content) flash("There was an error retrieving the authorisation token.", - "alert-danger") + "alert alert-danger") return redirect("/") def __fail_set_user_details__(_failure): app.logger.debug("Fetching user details fails: %s", _failure) - flash("Could not retrieve the user details", "alert-danger") + flash("Could not retrieve the user details", "alert alert-danger") return redirect("/") def __success_set_user_details__(_success): @@ -48,52 +46,24 @@ def authorisation_code(): def __success__(token): session.set_user_token(token) - return oauth2_get("auth/user/").then( - lambda usrdets: session.set_user_details({ - "user_id": uuid.UUID(usrdets["user_id"]), - "name": usrdets["name"], - "email": usrdets["email"], - "token": session.user_token(), - "logged_in": True})).either( + return fetch_user_details().either( __fail_set_user_details__, __success_set_user_details__) code = request.args.get("code", "").strip() if not bool(code): - flash("AuthorisationError: No code was provided.", "alert-danger") + flash("AuthorisationError: No code was provided.", "alert alert-danger") return redirect("/") baseurl = urlparse(request.base_url, scheme=request.scheme) - issued = datetime.now() - jwtkey = jwks.newest_jwk_with_rotation( - jwks.jwks_directory(app, "UPLOADER_SECRETS"), - int(app.config["JWKS_ROTATION_AGE_DAYS"])) - return mrequests.post( - urljoin(authserver_uri(), "auth/token"), - json={ - "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + return request_token( + token_uri=urljoin(authserver_uri(), "auth/token"), + user_id=request.args["user_id"], + extra_params={ "code": code, - "scope": SCOPE, "redirect_uri": urljoin( urlunparse(baseurl), url_for("oauth2.authorisation_code")), - "assertion": jwt.encode( - header={ - "alg": "RS256", - "typ": "JWT", - "kid": jwtkey.as_dict()["kid"] - }, - payload={ - "iss": str(oauth2_clientid()), - "sub": request.args["user_id"], - "aud": urljoin(authserver_uri(),"auth/token"), - "exp": (issued + timedelta(minutes=5)).timestamp(), - "nbf": int(issued.timestamp()), - "iat": int(issued.timestamp()), - "jti": str(uuid.uuid4()) - }, - key=jwtkey).decode("utf8"), - "client_id": oauth2_clientid() }).either(__process_error__, __success__) @oauth2.route("/public-jwks") @@ -116,7 +86,7 @@ def logout(): _user = session_info["user"] _user_str = f"{_user['name']} ({_user['email']})" session.clear_session_info() - flash("Successfully logged out.", "alert-success") + flash("Successfully signed out.", "alert alert-success") return redirect("/") if user_logged_in(): @@ -134,5 +104,5 @@ def logout(): cleanup_thunk=lambda: __unset_session__( session.session_info())), lambda res: __unset_session__(session.session_info())) - flash("There is no user that is currently logged in.", "alert-info") + flash("There is no user that is currently logged in.", "alert alert-info") return redirect("/") |
