about summary refs log tree commit diff
path: root/uploader/oauth2
diff options
context:
space:
mode:
Diffstat (limited to 'uploader/oauth2')
-rw-r--r--uploader/oauth2/__init__.py1
-rw-r--r--uploader/oauth2/client.py248
-rw-r--r--uploader/oauth2/jwks.py86
-rw-r--r--uploader/oauth2/tokens.py47
-rw-r--r--uploader/oauth2/views.py108
5 files changed, 490 insertions, 0 deletions
diff --git a/uploader/oauth2/__init__.py b/uploader/oauth2/__init__.py
new file mode 100644
index 0000000..aaea638
--- /dev/null
+++ b/uploader/oauth2/__init__.py
@@ -0,0 +1 @@
+"""Package to handle OAuth2 authentication/authorisation issues."""
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py
new file mode 100644
index 0000000..b94a044
--- /dev/null
+++ b/uploader/oauth2/client.py
@@ -0,0 +1,248 @@
+"""OAuth2 client utilities."""
+import json
+import time
+import uuid
+import random
+from datetime import datetime, timedelta
+from urllib.parse import urljoin, urlparse
+
+import requests
+from flask import request, current_app as app
+
+from pymonad.either import Left, Right, Either
+
+from authlib.common.urls import url_decode
+from authlib.jose.errors import BadSignatureError
+from authlib.jose import KeySet, JsonWebKey, JsonWebToken
+from authlib.integrations.requests_client import OAuth2Session
+
+from uploader import session
+import uploader.monadic_requests as mrequests
+
+SCOPE = ("profile group role resource register-client user masquerade "
+         "introspect migrate-data")
+
+
+def authserver_uri():
+    """Return URI to authorisation server."""
+    return app.config["AUTH_SERVER_URL"]
+
+
+def oauth2_clientid():
+    """Return the client id."""
+    return app.config["OAUTH2_CLIENT_ID"]
+
+
+def oauth2_clientsecret():
+    """Return the client secret."""
+    return app.config["OAUTH2_CLIENT_SECRET"]
+
+
+def __fetch_auth_server_jwks__() -> KeySet:
+    """Fetch the JWKs from the auth server."""
+    return KeySet([
+        JsonWebKey.import_key(key)
+        for key in requests.get(
+                urljoin(authserver_uri(), "auth/public-jwks"),
+                timeout=(9.13, 20)
+        ).json()["jwks"]])
+
+
+def __update_auth_server_jwks__(jwks) -> KeySet:
+    """Update the JWKs from the servers if necessary."""
+    last_updated = jwks["last-updated"]
+    now = datetime.now().timestamp()
+    # Maybe the `two_hours` variable below can be made into a configuration
+    # variable and passed in to this function
+    two_hours = timedelta(hours=2).seconds
+    if bool(last_updated) and (now - last_updated) < two_hours:
+        return jwks["jwks"]
+
+    return session.set_auth_server_jwks(__fetch_auth_server_jwks__())
+
+
+def auth_server_jwks() -> KeySet:
+    """Fetch the auth-server JSON Web Keys information."""
+    _jwks = session.session_info().get("auth_server_jwks") or {}
+    if bool(_jwks):
+        return __update_auth_server_jwks__({
+            "last-updated": _jwks["last-updated"],
+            "jwks": KeySet([
+                JsonWebKey.import_key(key) for key in _jwks.get(
+                        "jwks", {"keys": []})["keys"]])
+        })
+
+    return __update_auth_server_jwks__({
+        "last-updated": (datetime.now() - timedelta(hours=3)).timestamp()
+    })
+
+
+def oauth2_client():
+    """Build the OAuth2 client for use fetching data."""
+    def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument]
+        """Update the token when refreshed."""
+        session.set_user_token(token)
+
+    def __json_auth__(client, _method, uri, headers, body):
+        return (
+            uri,
+            {**headers, "Content-Type": "application/json"},
+            json.dumps({
+                **dict(url_decode(body)),
+                "client_id": client.client_id,
+                "client_secret": client.client_secret
+            }))
+
+    def __client__(token) -> OAuth2Session:
+        client = OAuth2Session(
+            oauth2_clientid(),
+            oauth2_clientsecret(),
+            scope=SCOPE,
+            token_endpoint=urljoin(authserver_uri(), "/auth/token"),
+            token_endpoint_auth_method="client_secret_post",
+            token=token,
+            update_token=__update_token__)
+        client.register_client_auth_method(
+            ("client_secret_post", __json_auth__))
+        return client
+
+    def __token_expired__(token):
+        """Check whether the token has expired."""
+        jwks = auth_server_jwks()
+        if bool(jwks):
+            for jwk in jwks.keys:
+                try:
+                    jwt = JsonWebToken(["RS256"]).decode(
+                        token["access_token"], key=jwk)
+                    if bool(jwt.get("exp")):
+                        return datetime.now().timestamp() > jwt["exp"]
+                except BadSignatureError as _bse:
+                    pass
+
+        return False
+
+    def __delay__():
+        """Do a tiny delay."""
+        time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100))))
+
+    def __refresh_token__(token):
+        """Refresh the token if necessary — synchronise amongst threads."""
+        if __token_expired__(token):
+            __delay__()
+            if session.is_token_refreshing():
+                while session.is_token_refreshing():
+                    __delay__()
+
+                return session.user_token().either(None, lambda _tok: _tok)
+
+            session.toggle_token_refreshing()
+            _client = __client__(token)
+            _client.get(urljoin(authserver_uri(), "auth/user/"))
+            session.toggle_token_refreshing()
+            return _client.token
+
+        return token
+
+    return session.user_token().then(__refresh_token__).either(
+        lambda _notok: __client__(None),
+        __client__)
+
+
+def fetch_user_details() -> Either:
+    """Retrieve user details from the auth server"""
+    suser = session.session_info()["user"]
+    if suser["email"] == "anon@ymous.user":
+        udets = oauth2_get("auth/user/").then(
+            lambda usrdets: session.set_user_details({
+                "user_id": uuid.UUID(usrdets["user_id"]),
+                "name": usrdets["name"],
+                "email": usrdets["email"],
+                "token": session.user_token()}))
+        return udets
+    return Right(suser)
+
+
+def user_logged_in():
+    """Check whether the user has logged in."""
+    suser = session.session_info()["user"]
+    fetch_user_details()
+    return suser["logged_in"] and suser["token"].is_right()
+
+
+def authserver_authorise_uri():
+    """Build up the authorisation URI."""
+    req_baseurl = urlparse(request.base_url, scheme=request.scheme)
+    host_uri = f"{req_baseurl.scheme}://{req_baseurl.netloc}/"
+    return urljoin(
+        authserver_uri(),
+        "auth/authorise?response_type=code"
+        f"&client_id={oauth2_clientid()}"
+        f"&redirect_uri={urljoin(host_uri, 'oauth2/code')}")
+
+
+def __no_token__(_err) -> Left:
+    """Handle situation where request is attempted with no token."""
+    resp = requests.models.Response()
+    resp._content = json.dumps({#pylint: disable=[protected-access]
+        "error": "AuthenticationError",
+        "error-description": ("You need to authenticate to access requested "
+                              "information.")}).encode("utf-8")
+    resp.status_code = 400
+    return Left(resp)
+
+
+def oauth2_get(url, **kwargs) -> Either:
+    """Do a get request to the authentication/authorisation server."""
+    def __get__(_token) -> Either:
+        _uri = urljoin(authserver_uri(), url)
+        try:
+            resp = oauth2_client().get(
+                _uri,
+                **{
+                    **kwargs,
+                    "headers": {
+                        **kwargs.get("headers", {}),
+                        "Content-Type": "application/json"
+                    }
+                })
+            if resp.status_code in mrequests.SUCCESS_CODES:
+                return Right(resp.json())
+            return Left(resp)
+        except Exception as exc:#pylint: disable=[broad-except]
+            app.logger.error("Error retrieving data from auth server: (GET %s)",
+                             _uri,
+                             exc_info=True)
+            return Left(exc)
+    return session.user_token().either(__no_token__, __get__)
+
+
+def oauth2_post(url, data=None, json=None, **kwargs):#pylint: disable=[redefined-outer-name]
+    """Do a POST request to the authentication/authorisation server."""
+    def __post__(_token) -> Either:
+        _uri = urljoin(authserver_uri(), url)
+        _headers = ({
+                        **kwargs.get("headers", {}),
+                        "Content-Type": "application/json"
+                    }
+                    if bool(json) else kwargs.get("headers", {}))
+        try:
+            request_data = {
+                **(data or {}),
+                **(json or {}),
+                "client_id": oauth2_clientid(),
+                "client_secret": oauth2_clientsecret()
+            }
+            resp = oauth2_client().post(
+                _uri,
+                data=(request_data if bool(data) else None),
+                json=(request_data if bool(json) else None),
+                **{**kwargs, "headers": _headers})
+            if resp.status_code in mrequests.SUCCESS_CODES:
+                return Right(resp.json())
+            return Left(resp)
+        except Exception as exc:#pylint: disable=[broad-except]
+            app.logger.error("Error retrieving data from auth server: (POST %s)",
+                             _uri,
+                             exc_info=True)
+            return Left(exc)
+    return session.user_token().either(__no_token__, __post__)
diff --git a/uploader/oauth2/jwks.py b/uploader/oauth2/jwks.py
new file mode 100644
index 0000000..efd0499
--- /dev/null
+++ b/uploader/oauth2/jwks.py
@@ -0,0 +1,86 @@
+"""Utilities dealing with JSON Web Keys (JWK)"""
+import os
+from pathlib import Path
+from typing import Any, Union
+from datetime import datetime, timedelta
+
+from flask import Flask
+from authlib.jose import JsonWebKey
+from pymonad.either import Left, Right, Either
+
+def jwks_directory(app: Flask, configname: str) -> Path:
+    """Compute the directory where the JWKs are stored."""
+    appsecretsdir = Path(app.config[configname]).parent
+    if appsecretsdir.exists() and appsecretsdir.is_dir():
+        jwksdir = Path(appsecretsdir, "jwks/")
+        if not jwksdir.exists():
+            jwksdir.mkdir()
+        return jwksdir
+    raise ValueError(
+        "The `appsecretsdir` value should be a directory that actually exists.")
+
+
+def generate_and_save_private_key(
+        storagedir: Path,
+        kty: str = "RSA",
+        crv_or_size: Union[str, int] = 2048,
+        options: tuple[tuple[str, Any]] = (("iat", datetime.now().timestamp()),)
+) -> JsonWebKey:
+    """Generate a private key and save to `storagedir`."""
+    privatejwk = JsonWebKey.generate_key(
+        kty, crv_or_size, dict(options), is_private=True)
+    keyname = f"{privatejwk.thumbprint()}.private.pem"
+    with open(Path(storagedir, keyname), "wb") as pemfile:
+        pemfile.write(privatejwk.as_pem(is_private=True))
+
+    return privatejwk
+
+
+def pem_to_jwk(filepath: Path) -> JsonWebKey:
+    """Parse a PEM file into a JWK object."""
+    with open(filepath, "rb") as pemfile:
+        return JsonWebKey.import_key(pemfile.read())
+
+
+def __sorted_jwks_paths__(storagedir: Path) -> tuple[tuple[float, Path], ...]:
+    """A sorted list of the JWK file paths with their creation timestamps."""
+    return tuple(sorted(((os.stat(keypath).st_ctime, keypath)
+                         for keypath in (Path(storagedir, keyfile)
+                                         for keyfile in os.listdir(storagedir)
+                                         if keyfile.endswith(".pem"))),
+                        key=lambda tpl: tpl[0]))
+
+
+def list_jwks(storagedir: Path) -> tuple[JsonWebKey, ...]:
+    """
+    List all the JWKs in a particular directory in the order they were created.
+    """
+    return tuple(pem_to_jwk(keypath) for ctime,keypath in
+                 __sorted_jwks_paths__(storagedir))
+
+
+def newest_jwk(storagedir: Path) -> Either:
+    """
+    Return an Either monad with the newest JWK or a message if none exists.
+    """
+    existingkeys = __sorted_jwks_paths__(storagedir)
+    if len(existingkeys) > 0:
+        return Right(pem_to_jwk(existingkeys[-1][1]))
+    return Left("No JWKs exist")
+
+
+def newest_jwk_with_rotation(jwksdir: Path, keyage: int) -> JsonWebKey:
+    """
+    Retrieve the latests JWK, creating a new one if older than `keyage` days.
+    """
+    def newer_than_days(jwkey):
+        filestat = os.stat(Path(
+            jwksdir, f"{jwkey.as_dict()['kid']}.private.pem"))
+        oldesttimeallowed = (datetime.now() - timedelta(days=keyage))
+        if filestat.st_ctime < (oldesttimeallowed.timestamp()):
+            return Left("JWK is too old!")
+        return jwkey
+
+    return newest_jwk(jwksdir).then(newer_than_days).either(
+        lambda _errmsg: generate_and_save_private_key(jwksdir),
+        lambda key: key)
diff --git a/uploader/oauth2/tokens.py b/uploader/oauth2/tokens.py
new file mode 100644
index 0000000..eb650f6
--- /dev/null
+++ b/uploader/oauth2/tokens.py
@@ -0,0 +1,47 @@
+"""Utilities for dealing with tokens."""
+import uuid
+from typing import Union
+from urllib.parse import urljoin
+from datetime import datetime, timedelta
+
+from authlib.jose import jwt
+from flask import current_app as app
+
+from uploader import monadic_requests as mrequests
+
+from . import jwks
+from .client import (SCOPE, authserver_uri, oauth2_clientid)
+
+
+def request_token(token_uri: str, user_id: Union[uuid.UUID, str], **kwargs):
+    """Request token from the auth server."""
+    issued = datetime.now()
+    jwtkey = jwks.newest_jwk_with_rotation(
+        jwks.jwks_directory(app, "UPLOADER_SECRETS"),
+        int(app.config["JWKS_ROTATION_AGE_DAYS"]))
+    _mins2expiry = kwargs.get("minutes_to_expiry", 5)
+    return mrequests.post(
+        token_uri,
+        json={
+            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+            "scope": kwargs.get("scope", SCOPE),
+            "assertion": jwt.encode(
+                header={
+                    "alg": "RS256",
+                    "typ": "JWT",
+                    "kid": jwtkey.as_dict()["kid"]
+                },
+                payload={
+                    "iss": str(oauth2_clientid()),
+                    "sub": str(user_id),
+                    "aud": urljoin(authserver_uri(), "auth/token"),
+                    "exp": (issued + timedelta(minutes=_mins2expiry)).timestamp(),
+                    "nbf": int(issued.timestamp()),
+                    "iat": int(issued.timestamp()),
+                    "jti": str(uuid.uuid4())
+                },
+                key=jwtkey).decode("utf8"),
+            "client_id": oauth2_clientid(),
+            **kwargs.get("extra_params", {})
+        }
+    )
diff --git a/uploader/oauth2/views.py b/uploader/oauth2/views.py
new file mode 100644
index 0000000..1ee4257
--- /dev/null
+++ b/uploader/oauth2/views.py
@@ -0,0 +1,108 @@
+"""Views for OAuth2 related functionality."""
+from urllib.parse import urljoin, urlparse, urlunparse
+
+from flask import (
+    flash,
+    jsonify,
+    url_for,
+    request,
+    redirect,
+    Blueprint,
+    current_app as app)
+
+from uploader import session
+from uploader import monadic_requests as mrequests
+from uploader.monadic_requests import make_error_handler
+
+from . import jwks
+from .tokens import request_token
+from .client import (
+    user_logged_in,
+    authserver_uri,
+    oauth2_clientid,
+    fetch_user_details,
+    oauth2_clientsecret)
+
+oauth2 = Blueprint("oauth2", __name__)
+
+
+@oauth2.route("/code")
+def authorisation_code():
+    """Receive authorisation code from auth server and use it to get token."""
+    def __process_error__(error_response):
+        app.logger.debug("ERROR: (%s)", error_response.content)
+        flash("There was an error retrieving the authorisation token.",
+              "alert alert-danger")
+        return redirect("/")
+
+    def __fail_set_user_details__(_failure):
+        app.logger.debug("Fetching user details fails: %s", _failure)
+        flash("Could not retrieve the user details", "alert alert-danger")
+        return redirect("/")
+
+    def __success_set_user_details__(_success):
+        app.logger.debug("Session info: %s", _success)
+        return redirect("/")
+
+    def __success__(token):
+        session.set_user_token(token)
+        return fetch_user_details().either(
+                    __fail_set_user_details__,
+                    __success_set_user_details__)
+
+    code = request.args.get("code", "").strip()
+    if not bool(code):
+        flash("AuthorisationError: No code was provided.", "alert alert-danger")
+        return redirect("/")
+
+    baseurl = urlparse(request.base_url, scheme=request.scheme)
+    return request_token(
+        token_uri=urljoin(authserver_uri(), "auth/token"),
+        user_id=request.args["user_id"],
+        extra_params={
+            "code": code,
+            "redirect_uri": urljoin(
+                urlunparse(baseurl),
+                url_for("oauth2.authorisation_code")),
+        }).either(__process_error__, __success__)
+
+@oauth2.route("/public-jwks")
+def public_jwks():
+    """List the available JWKs"""
+    return jsonify({
+        "documentation": (
+            "The keys are listed in order of creation, from the oldest (first) "
+            "to the newest (last)."),
+        "jwks": tuple(key.as_dict() for key
+                      in jwks.list_jwks(jwks.jwks_directory(
+                          app, "UPLOADER_SECRETS")))
+    })
+
+
+@oauth2.route("/logout", methods=["GET"])
+def logout():
+    """Log out of any active sessions."""
+    def __unset_session__(session_info):
+        _user = session_info["user"]
+        _user_str = f"{_user['name']} ({_user['email']})"
+        session.clear_session_info()
+        flash("Successfully signed out.", "alert alert-success")
+        return redirect("/")
+
+    if user_logged_in():
+        return session.user_token().then(
+            lambda _tok: mrequests.post(
+                urljoin(authserver_uri(), "auth/revoke"),
+                json={
+                    "token": _tok["refresh_token"],
+                    "token_type_hint": "refresh_token",
+                    "client_id": oauth2_clientid(),
+                    "client_secret": oauth2_clientsecret()
+                })).either(
+                    make_error_handler(
+                        redirect_to=redirect("/"),
+                        cleanup_thunk=lambda: __unset_session__(
+                            session.session_info())),
+                    lambda res: __unset_session__(session.session_info()))
+    flash("There is no user that is currently logged in.", "alert alert-info")
+    return redirect("/")