"""OAuth2 client utilities."""
import json
from datetime import datetime, timedelta
from urllib.parse import urljoin, urlparse, urlencode

import requests
from flask import request, current_app as app
from pymonad.either import Left, Right, Either

from authlib.jose import jwt
from authlib.jose import KeySet, JsonWebKey
from authlib.jose.errors import BadSignatureError
from authlib.integrations.requests_client import OAuth2Session

from uploader import session
import uploader.monadic_requests as mrequests

SCOPE = ("profile group role resource register-client user masquerade "
         "introspect migrate-data")


def authserver_uri():
    """Return URI to authorisation server."""
    return app.config["AUTH_SERVER_URL"]


def oauth2_clientid():
    """Return the client id."""
    return app.config["OAUTH2_CLIENT_ID"]


def oauth2_clientsecret():
    """Return the client secret."""
    return app.config["OAUTH2_CLIENT_SECRET"]


def __make_token_validator__(keys: KeySet):
    """Make a token validator function.

    The returned function takes a token dict and tries to decode its
    "access_token" entry with each key in `keys`, in order. It returns
    `Right(token)` for the first key that decodes it, and
    `Left("INVALID-TOKEN")` if every key fails with a bad signature (or if
    `keys` is empty). Errors other than `BadSignatureError` propagate.

    NOTE(review): only decoding/signature is exercised here; the claims
    returned by `jwt.decode` (e.g. expiry) are not validated — confirm
    whether that is intended.
    """
    def __validator__(token: dict):
        for key in keys.keys:
            try:
                jwt.decode(token["access_token"], key)
                return Right(token)
            except BadSignatureError:
                # Not signed with this key; try the next one.
                pass
        return Left("INVALID-TOKEN")

    return __validator__


def __validate_token__(sess_info):
    """Validate that the token is really from the auth server.

    Builds a validator from the JWKs stored under
    `sess_info["auth_server_jwks"]["jwks"]["keys"]` (no keys if absent),
    chains it onto the user's token (an Either) with `.then`, and saves the
    resulting session info, returning whatever `save_session_info` returns.
    """
    jwks = sess_info.get("auth_server_jwks", {}).get("jwks", {"keys": []})
    validator = __make_token_validator__(
        KeySet([JsonWebKey.import_key(key) for key in jwks["keys"]]))
    # Build fresh dicts instead of assigning into `sess_info["user"]`: the
    # callers pass shallow copies, so in-place assignment would mutate the
    # caller's nested session data.
    user = sess_info["user"]
    info = {
        **sess_info,
        "user": {**user, "token": user["token"].then(validator)}
    }
    return session.save_session_info(info)


def __update_auth_server_jwks__(sess_info):
    """Updates the JWKs every 2 hours or so.

    If the session already holds JWKs whose "last-updated" timestamp is less
    than two hours old, they are reused. Otherwise the keys are fetched from
    the auth server's `auth/public-jwks` endpoint and stored with a fresh
    timestamp. In both cases the user's token is then validated against the
    keys and the session info saved (via `__validate_token__`).
    """
    jwks = sess_info.get("auth_server_jwks")
    if bool(jwks):
        last_updated = jwks.get("last-updated")
        now = datetime.now().timestamp()
        # `total_seconds()` rather than `.seconds`: the latter is only the
        # sub-day component and would be wrong for intervals >= 1 day.
        if (bool(last_updated)
                and (now - last_updated) < timedelta(hours=2).total_seconds()):
            return __validate_token__({**sess_info, "auth_server_jwks": jwks})

    # NOTE(review): this path is relative ("auth/...") while the token
    # endpoint in `oauth2_client` is absolute ("/auth/token"); they resolve
    # differently if AUTH_SERVER_URL has a path component — confirm intent.
    jwksuri = urljoin(authserver_uri(), "auth/public-jwks")
    fetched_keys = requests.get(jwksuri).json()["jwks"]
    return __validate_token__({
        **sess_info,
        "auth_server_jwks": {
            "jwks": KeySet([
                JsonWebKey.import_key(key) for key in fetched_keys
            ]).as_dict(),
            "last-updated": datetime.now().timestamp()
        }
    })


def oauth2_client():
    """Build the OAuth2 client for use fetching data.

    First refreshes (if stale) and applies the auth server's JWKs to the
    session's user token. Then returns an `OAuth2Session` carrying the
    session's user token, or no token when `session.user_token()` is a Left.
    Refreshed tokens are written back via `session.set_user_token`.
    """
    def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument]
        """Update the token when refreshed."""
        session.set_user_token(token)

    def __client__(token) -> OAuth2Session:
        return OAuth2Session(
            oauth2_clientid(),
            oauth2_clientsecret(),
            scope=SCOPE,
            token_endpoint=urljoin(authserver_uri(), "/auth/token"),
            token_endpoint_auth_method="client_secret_post",
            token=token,
            update_token=__update_token__)

    __update_auth_server_jwks__(session.session_info())
    return session.user_token().either(
        lambda _notok: __client__(None), __client__)


def user_logged_in():
    """Check whether the user has logged in."""
    suser = session.session_info()["user"]
    return suser["logged_in"] and suser["token"].is_right()


def authserver_authorise_uri():
    """Return the auth server's authorisation URI for this client.

    The URI requests an authorisation code (`response_type=code`) for this
    client, with a redirect back to the `oauth2/code` path on the host that
    served the current request.
    """
    req_baseurl = urlparse(request.base_url, scheme=request.scheme)
    host_uri = f"{req_baseurl.scheme}://{req_baseurl.netloc}/"
    # RFC 6749 §4.1.1: parameters are added using the
    # application/x-www-form-urlencoded format, so encode them rather than
    # interpolating raw values (the redirect URI in particular).
    query = urlencode({
        "response_type": "code",
        "client_id": oauth2_clientid(),
        "redirect_uri": urljoin(host_uri, "oauth2/code")
    })
    return urljoin(authserver_uri(), f"auth/authorise?{query}")


def __no_token__(_err) -> Left:
    """Handle situation where request is attempted with no token.

    Returns a `Left` wrapping a synthetic `requests` Response with status
    400 and a JSON body describing the authentication error, so callers can
    handle it like any other failed response.
    """
    resp = requests.models.Response()
    resp._content = json.dumps({# pylint: disable=[protected-access]
        "error": "AuthenticationError",
        "error-description": ("You need to authenticate to access requested "
                              "information.")}).encode("utf-8")
    resp.status_code = 400
    return Left(resp)


def oauth2_get(url, **kwargs) -> Either:
    """Do a get request to the authentication/authorisation server.

    `url` is resolved against the auth server URI; extra `kwargs` are passed
    to the client's `get`, with a JSON Content-Type header merged in.
    Returns `Right(parsed JSON)` on a success status, `Left(response)` on any
    other status, `Left(exception)` if the request raised, or the
    `__no_token__` Left when the session has no user token.
    """
    def __get__(token) -> Either:# pylint: disable=[unused-argument]
        _uri = urljoin(authserver_uri(), url)
        try:
            resp = oauth2_client().get(
                _uri,
                **{
                    **kwargs,
                    "headers": {
                        **kwargs.get("headers", {}),
                        "Content-Type": "application/json"
                    }
                })
            if resp.status_code in mrequests.SUCCESS_CODES:
                return Right(resp.json())
            return Left(resp)
        except Exception as exc:# pylint: disable=[broad-except]
            # Boundary handler: log with traceback and hand the error back
            # to the caller as a Left instead of raising.
            app.logger.error("Error retriving data from auth server: (GET %s)",
                             _uri,
                             exc_info=True)
            return Left(exc)

    return session.user_token().either(__no_token__, __get__)


def oauth2_post(url, data=None, json=None, **kwargs):# pylint: disable=[redefined-outer-name]
    """Do a POST request to the authentication/authorisation server.

    The client id and secret are merged into the payload. The payload is
    sent as form data if `data` is truthy and/or as JSON if `json` is truthy
    (a JSON Content-Type header is added in the latter case). Returns
    `Right(parsed JSON)` on a success status, `Left(response)` on any other
    status, `Left(exception)` if the request raised, or the `__no_token__`
    Left when the session has no user token.

    Note: the `json` parameter shadows the `json` module inside this
    function; the module is not used here.
    """
    def __post__(token) -> Either:# pylint: disable=[unused-argument]
        _uri = urljoin(authserver_uri(), url)
        _headers = ({
            **kwargs.get("headers", {}),
            "Content-Type": "application/json"
        } if bool(json) else kwargs.get("headers", {}))
        try:
            request_data = {
                **(data or {}),
                **(json or {}),
                "client_id": oauth2_clientid(),
                "client_secret": oauth2_clientsecret()
            }
            resp = oauth2_client().post(
                _uri,
                data=(request_data if bool(data) else None),
                json=(request_data if bool(json) else None),
                **{**kwargs, "headers": _headers})
            if resp.status_code in mrequests.SUCCESS_CODES:
                return Right(resp.json())
            return Left(resp)
        except Exception as exc:# pylint: disable=[broad-except]
            # Boundary handler: log with traceback and return the error as a
            # Left instead of raising.
            app.logger.error("Error retriving data from auth server: (POST %s)",
                             _uri,
                             exc_info=True)
            return Left(exc)

    return session.user_token().either(__no_token__, __post__)