"""OAuth2 client utilities.""" import json import time import random from datetime import datetime, timedelta from urllib.parse import urljoin, urlparse import requests from flask import request, current_app as app from pymonad.either import Left, Right, Either from authlib.common.urls import url_decode from authlib.jose.errors import BadSignatureError from authlib.jose import KeySet, JsonWebKey, JsonWebToken from authlib.integrations.requests_client import OAuth2Session from uploader import session import uploader.monadic_requests as mrequests SCOPE = ("profile group role resource register-client user masquerade " "introspect migrate-data") def authserver_uri(): """Return URI to authorisation server.""" return app.config["AUTH_SERVER_URL"] def oauth2_clientid(): """Return the client id.""" return app.config["OAUTH2_CLIENT_ID"] def oauth2_clientsecret(): """Return the client secret.""" return app.config["OAUTH2_CLIENT_SECRET"] def __fetch_auth_server_jwks__() -> KeySet: """Fetch the JWKs from the auth server.""" return KeySet([ JsonWebKey.import_key(key) for key in requests.get( urljoin(authserver_uri(), "auth/public-jwks") ).json()["jwks"]]) def __update_auth_server_jwks__(jwks) -> KeySet: """Update the JWKs from the servers if necessary.""" last_updated = jwks["last-updated"] now = datetime.now().timestamp() # Maybe the `two_hours` variable below can be made into a configuration # variable and passed in to this function two_hours = timedelta(hours=2).seconds if bool(last_updated) and (now - last_updated) < two_hours: return jwks["jwks"] return session.set_auth_server_jwks(__fetch_auth_server_jwks__()) def auth_server_jwks() -> KeySet: """Fetch the auth-server JSON Web Keys information.""" _jwks = session.session_info().get("auth_server_jwks") if bool(_jwks): return __update_auth_server_jwks__({ "last-updated": _jwks["last-updated"], "jwks": KeySet([ JsonWebKey.import_key(key) for key in _jwks.get( "jwks", {"keys": []})["keys"]]) }) return __update_auth_server_jwks__({ "last-updated": (datetime.now() - timedelta(hours=3)).timestamp() }) def oauth2_client(): """Build the OAuth2 client for use fetching data.""" def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument] """Update the token when refreshed.""" session.set_user_token(token) def __json_auth__(client, _method, uri, headers, body): return ( uri, {**headers, "Content-Type": "application/json"}, json.dumps({ **dict(url_decode(body)), "client_id": client.client_id, "client_secret": client.client_secret })) def __client__(token) -> OAuth2Session: client = OAuth2Session( oauth2_clientid(), oauth2_clientsecret(), scope=SCOPE, token_endpoint=urljoin(authserver_uri(), "/auth/token"), token_endpoint_auth_method="client_secret_post", token=token, update_token=__update_token__) client.register_client_auth_method( ("client_secret_post", __json_auth__)) return client def __token_expired__(token): """Check whether the token has expired.""" jwks = auth_server_jwks() if bool(jwks): for jwk in jwks.keys: try: jwt = JsonWebToken(["RS256"]).decode( token["access_token"], key=jwk) return datetime.now().timestamp() > jwt["exp"] except BadSignatureError as _bse: pass return False def __delay__(): """Do a tiny delay.""" time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100)))) def __refresh_token__(token): """Refresh the token if necessary — synchronise amongst threads.""" if __token_expired__(token): __delay__() if session.is_token_refreshing(): while session.is_token_refreshing(): __delay__() return session.user_token().either(None, lambda _tok: _tok) session.toggle_token_refreshing() _client = __client__(token) _client.get(urljoin(authserver_uri(), "auth/user/")) session.toggle_token_refreshing() return _client.token return token return session.user_token().then(__refresh_token__).either( lambda _notok: __client__(None), __client__) def user_logged_in(): """Check whether the user has logged in.""" suser = session.session_info()["user"] return suser["logged_in"] and suser["token"].is_right() def authserver_authorise_uri(): """Build up the authorisation URI.""" req_baseurl = urlparse(request.base_url, scheme=request.scheme) host_uri = f"{req_baseurl.scheme}://{req_baseurl.netloc}/" return urljoin( authserver_uri(), "auth/authorise?response_type=code" f"&client_id={oauth2_clientid()}" f"&redirect_uri={urljoin(host_uri, 'oauth2/code')}") def __no_token__(_err) -> Left: """Handle situation where request is attempted with no token.""" resp = requests.models.Response() resp._content = json.dumps({#pylint: disable=[protected-access] "error": "AuthenticationError", "error-description": ("You need to authenticate to access requested " "information.")}).encode("utf-8") resp.status_code = 400 return Left(resp) def oauth2_get(url, **kwargs) -> Either: """Do a get request to the authentication/authorisation server.""" def __get__(_token) -> Either: _uri = urljoin(authserver_uri(), url) try: resp = oauth2_client().get( _uri, **{ **kwargs, "headers": { **kwargs.get("headers", {}), "Content-Type": "application/json" } }) if resp.status_code in mrequests.SUCCESS_CODES: return Right(resp.json()) return Left(resp) except Exception as exc:#pylint: disable=[broad-except] app.logger.error("Error retriving data from auth server: (GET %s)", _uri, exc_info=True) return Left(exc) return session.user_token().either(__no_token__, __get__) def oauth2_post(url, data=None, json=None, **kwargs):#pylint: disable=[redefined-outer-name] """Do a POST request to the authentication/authorisation server.""" def __post__(_token) -> Either: _uri = urljoin(authserver_uri(), url) _headers = ({ **kwargs.get("headers", {}), "Content-Type": "application/json" } if bool(json) else kwargs.get("headers", {})) try: request_data = { **(data or {}), **(json or {}), "client_id": oauth2_clientid(), "client_secret": oauth2_clientsecret() } resp = oauth2_client().post( _uri, data=(request_data if bool(data) else None), json=(request_data if bool(json) else None), **{**kwargs, "headers": _headers}) if resp.status_code in mrequests.SUCCESS_CODES: return Right(resp.json()) return Left(resp) except Exception as exc:#pylint: disable=[broad-except] app.logger.error("Error retriving data from auth server: (POST %s)", _uri, exc_info=True) return Left(exc) return session.user_token().either(__no_token__, __post__)