diff options
Diffstat (limited to 'uploader/oauth2/client.py')
-rw-r--r-- | uploader/oauth2/client.py | 230 |
1 files changed, 230 insertions, 0 deletions
diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py new file mode 100644 index 0000000..e7128de --- /dev/null +++ b/uploader/oauth2/client.py @@ -0,0 +1,230 @@ +"""OAuth2 client utilities.""" +import json +import time +import random +from datetime import datetime, timedelta +from urllib.parse import urljoin, urlparse + +import requests +from flask import request, current_app as app + +from pymonad.either import Left, Right, Either + +from authlib.common.urls import url_decode +from authlib.jose.errors import BadSignatureError +from authlib.jose import KeySet, JsonWebKey, JsonWebToken +from authlib.integrations.requests_client import OAuth2Session + +from uploader import session +import uploader.monadic_requests as mrequests + +SCOPE = ("profile group role resource register-client user masquerade " + "introspect migrate-data") + + +def authserver_uri(): + """Return URI to authorisation server.""" + return app.config["AUTH_SERVER_URL"] + + +def oauth2_clientid(): + """Return the client id.""" + return app.config["OAUTH2_CLIENT_ID"] + + +def oauth2_clientsecret(): + """Return the client secret.""" + return app.config["OAUTH2_CLIENT_SECRET"] + + +def __fetch_auth_server_jwks__() -> KeySet: + """Fetch the JWKs from the auth server.""" + return KeySet([ + JsonWebKey.import_key(key) + for key in requests.get( + urljoin(authserver_uri(), "auth/public-jwks") + ).json()["jwks"]]) + + +def __update_auth_server_jwks__(jwks) -> KeySet: + """Update the JWKs from the servers if necessary.""" + last_updated = jwks["last-updated"] + now = datetime.now().timestamp() + # Maybe the `two_hours` variable below can be made into a configuration + # variable and passed in to this function + two_hours = timedelta(hours=2).seconds + if bool(last_updated) and (now - last_updated) < two_hours: + return jwks["jwks"] + + return session.set_auth_server_jwks(__fetch_auth_server_jwks__()) + + +def auth_server_jwks() -> KeySet: + """Fetch the auth-server JSON Web Keys information.""" + _jwks = session.session_info().get("auth_server_jwks") or {} + if bool(_jwks): + return __update_auth_server_jwks__({ + "last-updated": _jwks["last-updated"], + "jwks": KeySet([ + JsonWebKey.import_key(key) for key in _jwks.get( + "jwks", {"keys": []})["keys"]]) + }) + + return __update_auth_server_jwks__({ + "last-updated": (datetime.now() - timedelta(hours=3)).timestamp() + }) + + +def oauth2_client(): + """Build the OAuth2 client for use fetching data.""" + def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument] + """Update the token when refreshed.""" + session.set_user_token(token) + + def __json_auth__(client, _method, uri, headers, body): + return ( + uri, + {**headers, "Content-Type": "application/json"}, + json.dumps({ + **dict(url_decode(body)), + "client_id": client.client_id, + "client_secret": client.client_secret + })) + + def __client__(token) -> OAuth2Session: + client = OAuth2Session( + oauth2_clientid(), + oauth2_clientsecret(), + scope=SCOPE, + token_endpoint=urljoin(authserver_uri(), "/auth/token"), + token_endpoint_auth_method="client_secret_post", + token=token, + update_token=__update_token__) + client.register_client_auth_method( + ("client_secret_post", __json_auth__)) + return client + + def __token_expired__(token): + """Check whether the token has expired.""" + jwks = auth_server_jwks() + if bool(jwks): + for jwk in jwks.keys: + try: + jwt = JsonWebToken(["RS256"]).decode( + token["access_token"], key=jwk) + return datetime.now().timestamp() > jwt["exp"] + except BadSignatureError as _bse: + pass + + return False + + def __delay__(): + """Do a tiny delay.""" + time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100)))) + + def __refresh_token__(token): + """Refresh the token if necessary — synchronise amongst threads.""" + if __token_expired__(token): + __delay__() + if session.is_token_refreshing(): + while session.is_token_refreshing(): + __delay__() + + return session.user_token().either(None, lambda _tok: _tok) + + session.toggle_token_refreshing() + _client = __client__(token) + _client.get(urljoin(authserver_uri(), "auth/user/")) + session.toggle_token_refreshing() + return _client.token + + return token + + return session.user_token().then(__refresh_token__).either( + lambda _notok: __client__(None), + __client__) + + +def user_logged_in(): + """Check whether the user has logged in.""" + suser = session.session_info()["user"] + return suser["logged_in"] and suser["token"].is_right() + + +def authserver_authorise_uri(): + """Build up the authorisation URI.""" + req_baseurl = urlparse(request.base_url, scheme=request.scheme) + host_uri = f"{req_baseurl.scheme}://{req_baseurl.netloc}/" + return urljoin( + authserver_uri(), + "auth/authorise?response_type=code" + f"&client_id={oauth2_clientid()}" + f"&redirect_uri={urljoin(host_uri, 'oauth2/code')}") + + +def __no_token__(_err) -> Left: + """Handle situation where request is attempted with no token.""" + resp = requests.models.Response() + resp._content = json.dumps({#pylint: disable=[protected-access] + "error": "AuthenticationError", + "error-description": ("You need to authenticate to access requested " + "information.")}).encode("utf-8") + resp.status_code = 400 + return Left(resp) + + +def oauth2_get(url, **kwargs) -> Either: + """Do a get request to the authentication/authorisation server.""" + def __get__(_token) -> Either: + _uri = urljoin(authserver_uri(), url) + try: + resp = oauth2_client().get( + _uri, + **{ + **kwargs, + "headers": { + **kwargs.get("headers", {}), + "Content-Type": "application/json" + } + }) + if resp.status_code in mrequests.SUCCESS_CODES: + return Right(resp.json()) + return Left(resp) + except Exception as exc:#pylint: disable=[broad-except] + app.logger.error("Error retrieving data from auth server: (GET %s)", + _uri, + exc_info=True) + return Left(exc) + return session.user_token().either(__no_token__, __get__) + + +def oauth2_post(url, data=None, json=None, **kwargs):#pylint: disable=[redefined-outer-name] + """Do a POST request to the authentication/authorisation server.""" + def __post__(_token) -> Either: + _uri = urljoin(authserver_uri(), url) + _headers = ({ + **kwargs.get("headers", {}), + "Content-Type": "application/json" + } + if bool(json) else kwargs.get("headers", {})) + try: + request_data = { + **(data or {}), + **(json or {}), + "client_id": oauth2_clientid(), + "client_secret": oauth2_clientsecret() + } + resp = oauth2_client().post( + _uri, + data=(request_data if bool(data) else None), + json=(request_data if bool(json) else None), + **{**kwargs, "headers": _headers}) + if resp.status_code in mrequests.SUCCESS_CODES: + return Right(resp.json()) + return Left(resp) + except Exception as exc:#pylint: disable=[broad-except] + app.logger.error("Error retrieving data from auth server: (POST %s)", + _uri, + exc_info=True) + return Left(exc) + return session.user_token().either(__no_token__, __post__) |