about summary refs log tree commit diff
path: root/uploader/oauth2
diff options
context:
space:
mode:
Diffstat (limited to 'uploader/oauth2')
-rw-r--r--uploader/oauth2/client.py142
-rw-r--r--uploader/oauth2/tokens.py47
-rw-r--r--uploader/oauth2/views.py62
3 files changed, 157 insertions, 94 deletions
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py
index a3e4ba3..b94a044 100644
--- a/uploader/oauth2/client.py
+++ b/uploader/oauth2/client.py
@@ -1,5 +1,8 @@
 """OAuth2 client utilities."""
 import json
+import time
+import uuid
+import random
 from datetime import datetime, timedelta
 from urllib.parse import urljoin, urlparse
 
@@ -8,10 +11,9 @@ from flask import request, current_app as app
 
 from pymonad.either import Left, Right, Either
 
-from authlib.jose import jwt
 from authlib.common.urls import url_decode
-from authlib.jose import KeySet, JsonWebKey
 from authlib.jose.errors import BadSignatureError
+from authlib.jose import KeySet, JsonWebKey, JsonWebToken
 from authlib.integrations.requests_client import OAuth2Session
 
 from uploader import session
@@ -36,49 +38,42 @@ def oauth2_clientsecret():
     return app.config["OAUTH2_CLIENT_SECRET"]
 
 
-def __make_token_validator__(keys: KeySet):
-    """Make a token validator function."""
-    def __validator__(token: dict):
-        for key in keys.keys:
-            try:
-                jwt.decode(token["access_token"], key)
-                return Right(token)
-            except BadSignatureError:
-                pass
-
-        return Left("INVALID-TOKEN")
-
-    return __validator__
-
-
-def __validate_token__(sess_info):
-    """Validate that the token is really from the auth server."""
-    info = sess_info
-    info["user"]["token"] = info["user"]["token"].then(__make_token_validator__(
-        KeySet([JsonWebKey.import_key(key) for key in info.get(
-            "auth_server_jwks", {}).get(
-                "jwks", {"keys": []})["keys"]])))
-    return session.save_session_info(info)
-
-
-def __update_auth_server_jwks__(sess_info):
-    """Updates the JWKs every 2 hours or so."""
-    jwks = sess_info.get("auth_server_jwks")
-    if bool(jwks):
-        last_updated = jwks.get("last-updated")
-        now = datetime.now().timestamp()
-        if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds:
-            return __validate_token__({**sess_info, "auth_server_jwks": jwks})
-
-    jwksuri = urljoin(authserver_uri(), "auth/public-jwks")
-    return __validate_token__({
-        **sess_info,
-        "auth_server_jwks": {
+def __fetch_auth_server_jwks__() -> KeySet:
+    """Fetch the JWKs from the auth server."""
+    return KeySet([
+        JsonWebKey.import_key(key)
+        for key in requests.get(
+                urljoin(authserver_uri(), "auth/public-jwks"),
+                timeout=(9.13, 20)
+        ).json()["jwks"]])
+
+
+def __update_auth_server_jwks__(jwks) -> KeySet:
+    """Update the JWKs from the servers if necessary."""
+    last_updated = jwks["last-updated"]
+    now = datetime.now().timestamp()
+    # Maybe the `two_hours` variable below can be made into a configuration
+    # variable and passed in to this function
+    two_hours = timedelta(hours=2).seconds
+    if bool(last_updated) and (now - last_updated) < two_hours:
+        return jwks["jwks"]
+
+    return session.set_auth_server_jwks(__fetch_auth_server_jwks__())
+
+
+def auth_server_jwks() -> KeySet:
+    """Fetch the auth-server JSON Web Keys information."""
+    _jwks = session.session_info().get("auth_server_jwks") or {}
+    if bool(_jwks):
+        return __update_auth_server_jwks__({
+            "last-updated": _jwks["last-updated"],
             "jwks": KeySet([
-                JsonWebKey.import_key(key)
-                for key in requests.get(jwksuri).json()["jwks"]]).as_dict(),
-            "last-updated": datetime.now().timestamp()
-        }
+                JsonWebKey.import_key(key) for key in _jwks.get(
+                        "jwks", {"keys": []})["keys"]])
+        })
+
+    return __update_auth_server_jwks__({
+        "last-updated": (datetime.now() - timedelta(hours=3)).timestamp()
     })
 
 
@@ -111,15 +106,66 @@ def oauth2_client():
             ("client_secret_post", __json_auth__))
         return client
 
-    __update_auth_server_jwks__(session.session_info())
-    return session.user_token().either(
+    def __token_expired__(token):
+        """Check whether the token has expired."""
+        jwks = auth_server_jwks()
+        if bool(jwks):
+            for jwk in jwks.keys:
+                try:
+                    jwt = JsonWebToken(["RS256"]).decode(
+                        token["access_token"], key=jwk)
+                    if bool(jwt.get("exp")):
+                        return datetime.now().timestamp() > jwt["exp"]
+                except BadSignatureError as _bse:
+                    pass
+
+        return False
+
+    def __delay__():
+        """Do a tiny delay."""
+        time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100))))
+
+    def __refresh_token__(token):
+        """Refresh the token if necessary — synchronise amongst threads."""
+        if __token_expired__(token):
+            __delay__()
+            if session.is_token_refreshing():
+                while session.is_token_refreshing():
+                    __delay__()
+
+                return session.user_token().either(None, lambda _tok: _tok)
+
+            session.toggle_token_refreshing()
+            _client = __client__(token)
+            _client.get(urljoin(authserver_uri(), "auth/user/"))
+            session.toggle_token_refreshing()
+            return _client.token
+
+        return token
+
+    return session.user_token().then(__refresh_token__).either(
         lambda _notok: __client__(None),
         __client__)
 
 
+def fetch_user_details() -> Either:
+    """Retrieve user details from the auth server"""
+    suser = session.session_info()["user"]
+    if suser["email"] == "anon@ymous.user":
+        udets = oauth2_get("auth/user/").then(
+            lambda usrdets: session.set_user_details({
+                "user_id": uuid.UUID(usrdets["user_id"]),
+                "name": usrdets["name"],
+                "email": usrdets["email"],
+                "token": session.user_token()}))
+        return udets
+    return Right(suser)
+
+
 def user_logged_in():
     """Check whether the user has logged in."""
     suser = session.session_info()["user"]
+    fetch_user_details()
     return suser["logged_in"] and suser["token"].is_right()
 
 
@@ -163,7 +209,7 @@ def oauth2_get(url, **kwargs) -> Either:
                 return Right(resp.json())
             return Left(resp)
         except Exception as exc:#pylint: disable=[broad-except]
-            app.logger.error("Error retriving data from auth server: (GET %s)",
+            app.logger.error("Error retrieving data from auth server: (GET %s)",
                              _uri,
                              exc_info=True)
             return Left(exc)
@@ -195,7 +241,7 @@ def oauth2_post(url, data=None, json=None, **kwargs):#pylint: disable=[redefined
                 return Right(resp.json())
             return Left(resp)
         except Exception as exc:#pylint: disable=[broad-except]
-            app.logger.error("Error retriving data from auth server: (POST %s)",
+            app.logger.error("Error retrieving data from auth server: (POST %s)",
                              _uri,
                              exc_info=True)
             return Left(exc)
diff --git a/uploader/oauth2/tokens.py b/uploader/oauth2/tokens.py
new file mode 100644
index 0000000..eb650f6
--- /dev/null
+++ b/uploader/oauth2/tokens.py
@@ -0,0 +1,47 @@
+"""Utilities for dealing with tokens."""
+import uuid
+from typing import Union
+from urllib.parse import urljoin
+from datetime import datetime, timedelta
+
+from authlib.jose import jwt
+from flask import current_app as app
+
+from uploader import monadic_requests as mrequests
+
+from . import jwks
+from .client import (SCOPE, authserver_uri, oauth2_clientid)
+
+
+def request_token(token_uri: str, user_id: Union[uuid.UUID, str], **kwargs):
+    """Request token from the auth server."""
+    issued = datetime.now()
+    jwtkey = jwks.newest_jwk_with_rotation(
+        jwks.jwks_directory(app, "UPLOADER_SECRETS"),
+        int(app.config["JWKS_ROTATION_AGE_DAYS"]))
+    _mins2expiry = kwargs.get("minutes_to_expiry", 5)
+    return mrequests.post(
+        token_uri,
+        json={
+            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+            "scope": kwargs.get("scope", SCOPE),
+            "assertion": jwt.encode(
+                header={
+                    "alg": "RS256",
+                    "typ": "JWT",
+                    "kid": jwtkey.as_dict()["kid"]
+                },
+                payload={
+                    "iss": str(oauth2_clientid()),
+                    "sub": str(user_id),
+                    "aud": urljoin(authserver_uri(), "auth/token"),
+                    "exp": (issued + timedelta(minutes=_mins2expiry)).timestamp(),
+                    "nbf": int(issued.timestamp()),
+                    "iat": int(issued.timestamp()),
+                    "jti": str(uuid.uuid4())
+                },
+                key=jwtkey).decode("utf8"),
+            "client_id": oauth2_clientid(),
+            **kwargs.get("extra_params", {})
+        }
+    )
diff --git a/uploader/oauth2/views.py b/uploader/oauth2/views.py
index 61037f3..05f8542 100644
--- a/uploader/oauth2/views.py
+++ b/uploader/oauth2/views.py
@@ -1,45 +1,43 @@
 """Views for OAuth2 related functionality."""
-import uuid
-from datetime import datetime, timedelta
 from urllib.parse import urljoin, urlparse, urlunparse
 
-from authlib.jose import jwt
 from flask import (
     flash,
     jsonify,
-    url_for,
     request,
     redirect,
     Blueprint,
     current_app as app)
 
 from uploader import session
+from uploader.flask_extensions import url_for
 from uploader import monadic_requests as mrequests
 from uploader.monadic_requests import make_error_handler
 
 from . import jwks
+from .tokens import request_token
 from .client import (
-    SCOPE,
-    oauth2_get,
     user_logged_in,
     authserver_uri,
     oauth2_clientid,
+    fetch_user_details,
     oauth2_clientsecret)
 
 oauth2 = Blueprint("oauth2", __name__)
 
+
 @oauth2.route("/code")
 def authorisation_code():
     """Receive authorisation code from auth server and use it to get token."""
-    def __process_error__(resp_or_exception):
-        app.logger.debug("ERROR: (%s)", resp_or_exception)
+    def __process_error__(error_response):
+        app.logger.debug("ERROR: (%s)", error_response.content)
         flash("There was an error retrieving the authorisation token.",
-              "alert-danger")
+              "alert alert-danger")
         return redirect("/")
 
     def __fail_set_user_details__(_failure):
         app.logger.debug("Fetching user details fails: %s", _failure)
-        flash("Could not retrieve the user details", "alert-danger")
+        flash("Could not retrieve the user details", "alert alert-danger")
         return redirect("/")
 
     def __success_set_user_details__(_success):
@@ -48,52 +46,24 @@ def authorisation_code():
 
     def __success__(token):
         session.set_user_token(token)
-        return oauth2_get("auth/user/").then(
-            lambda usrdets: session.set_user_details({
-                "user_id": uuid.UUID(usrdets["user_id"]),
-                "name": usrdets["name"],
-                "email": usrdets["email"],
-                "token": session.user_token(),
-                "logged_in": True})).either(
+        return fetch_user_details().either(
                     __fail_set_user_details__,
                     __success_set_user_details__)
 
     code = request.args.get("code", "").strip()
     if not bool(code):
-        flash("AuthorisationError: No code was provided.", "alert-danger")
+        flash("AuthorisationError: No code was provided.", "alert alert-danger")
         return redirect("/")
 
     baseurl = urlparse(request.base_url, scheme=request.scheme)
-    issued = datetime.now()
-    jwtkey = jwks.newest_jwk_with_rotation(
-        jwks.jwks_directory(app, "UPLOADER_SECRETS"),
-        int(app.config["JWKS_ROTATION_AGE_DAYS"]))
-    return mrequests.post(
-        urljoin(authserver_uri(), "auth/token"),
-        json={
-            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+    return request_token(
+        token_uri=urljoin(authserver_uri(), "auth/token"),
+        user_id=request.args["user_id"],
+        extra_params={
             "code": code,
-            "scope": SCOPE,
             "redirect_uri": urljoin(
                 urlunparse(baseurl),
                 url_for("oauth2.authorisation_code")),
-            "assertion": jwt.encode(
-                header={
-                    "alg": "RS256",
-                    "typ": "JWT",
-                    "kid": jwtkey.as_dict()["kid"]
-                },
-                payload={
-                    "iss": str(oauth2_clientid()),
-                    "sub": request.args["user_id"],
-                    "aud": urljoin(authserver_uri(),"auth/token"),
-                    "exp": (issued + timedelta(minutes=5)).timestamp(),
-                    "nbf": int(issued.timestamp()),
-                    "iat": int(issued.timestamp()),
-                    "jti": str(uuid.uuid4())
-                },
-                key=jwtkey).decode("utf8"),
-            "client_id": oauth2_clientid()
         }).either(__process_error__, __success__)
 
 @oauth2.route("/public-jwks")
@@ -116,7 +86,7 @@ def logout():
         _user = session_info["user"]
         _user_str = f"{_user['name']} ({_user['email']})"
         session.clear_session_info()
-        flash("Successfully logged out.", "alert-success")
+        flash("Successfully signed out.", "alert alert-success")
         return redirect("/")
 
     if user_logged_in():
@@ -134,5 +104,5 @@ def logout():
                         cleanup_thunk=lambda: __unset_session__(
                             session.session_info())),
                     lambda res: __unset_session__(session.session_info()))
-    flash("There is no user that is currently logged in.", "alert-info")
+    flash("There is no user that is currently logged in.", "alert alert-info")
     return redirect("/")