From 38df25ac8a1b4385e5987d41bb5447ad267408c5 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Fri, 26 Jul 2024 16:38:45 -0500 Subject: Add wrappers for OAuth2Session's `get` and `post` methods. Fix bugs. --- uploader/oauth2/client.py | 97 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 84 insertions(+), 13 deletions(-) diff --git a/uploader/oauth2/client.py b/uploader/oauth2/client.py index f20af4a..7204b1e 100644 --- a/uploader/oauth2/client.py +++ b/uploader/oauth2/client.py @@ -5,7 +5,7 @@ from urllib.parse import urljoin, urlparse import requests from flask import request, current_app as app -from pymonad.either import Left, Right +from pymonad.either import Left, Right, Either from authlib.jose import jwt from authlib.jose import KeySet, JsonWebKey @@ -13,6 +13,7 @@ from authlib.jose.errors import BadSignatureError from authlib.integrations.requests_client import OAuth2Session from uploader import session +import uploader.monadic_requests as mrequests SCOPE = ("profile group role resource register-client user masquerade " "introspect migrate-data") @@ -36,22 +37,25 @@ def oauth2_clientsecret(): def __make_token_validator__(keys: KeySet): """Make a token validator function.""" def __validator__(token: str): - try: - jwt.decode(token, keys) - return Right(token) - except BadSignatureError: - return Left("INVALID-TOKEN") + for key in keys.keys: + try: + jwt.decode(token["access_token"], key) + return Right(token) + except BadSignatureError: + pass + + return Left("INVALID-TOKEN") return __validator__ def __validate_token__(sess_info): """Validate that the token is really from the auth server.""" - info = __update_auth_server_jwks__(sess_info) + info = sess_info info["user"]["token"] = info["user"]["token"].then(__make_token_validator__( - KeySet(JsonWebKey(key) for key in info.get( + KeySet([JsonWebKey.import_key(key) for key in info.get( "auth_server_jwks", {}).get( - "jwks", {"keys": []})["keys"]))) + "jwks", {"keys": []})["keys"]]))) return session.save_session_info(info) @@ -80,8 +84,6 @@ def oauth2_client(): """Build the OAuth2 client for use fetching data.""" def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument] """Update the token when refreshed.""" - app.logger.debug(f"IN `{__name__}`:\n\trefresh_token: {refresh_token}" - f"\n\taccess_token: {access_token}\n\ttoken: {token}") session.set_user_token(token) def __client__(token) -> OAuth2Session: @@ -104,8 +106,8 @@ def oauth2_client(): def user_logged_in(): """Check whether the user has logged in.""" suser = session.session_info()["user"] - # return suser["logged_in"] and suser["token"].is_right() - return False + return suser["logged_in"] and suser["token"].is_right() + def authserver_authorise_uri(): req_baseurl = urlparse(request.base_url, scheme=request.scheme) @@ -115,3 +117,72 @@ def authserver_authorise_uri(): "auth/authorise?response_type=code" f"&client_id={oauth2_clientid()}" f"&redirect_uri={urljoin(host_uri, 'oauth2/code')}") + + +def __no_token__(_err) -> Left: + """Handle situation where request is attempted with no token.""" + resp = requests.models.Response() + resp._content = json.dumps({ + "error": "AuthenticationError", + "error-description": ("You need to authenticate to access requested " + "information.")}).encode("utf-8") + resp.status_code = 400 + return Left(resp) + + +def oauth2_get(url, **kwargs) -> Either: + """Do a get request to the authentication/authorisation server.""" + def __get__(token) -> Either: + _uri = urljoin(authserver_uri(), url) + try: + resp = oauth2_client().get( + _uri, + **{ + **kwargs, + "headers": { + **kwargs.get("headers", {}), + "Content-Type": "application/json" + } + }) + if resp.status_code in mrequests.SUCCESS_CODES: + return Right(resp.json()) + return Left(resp) + except Exception as exc: + app.logger.error("Error retriving data from auth server: (GET %s)", + _uri, + exc_info=True) + return Left(exc) + return session.user_token().either(__no_token__, __get__) + + +def oauth2_post(url, data=None, json=None, **kwargs): + """Do a POST request to the authentication/authorisation server.""" + def __post__(token) -> Either: + _uri = urljoin(authserver_uri(), url) + try: + request_data = { + **(data or {}), + **(json or {}), + "client_id": oauth2_clientid(), + "client_secret": oauth2_clientsecret() + } + resp = oauth2_client().post( + _uri, + data=(request_data if bool(data) else None), + json=(request_data if json else None), + **{ + **kwargs, + "headers": { + **kwargs.get("headers", {}), + "Content-Type": "application/json" + } + }) + if resp.status_code in requests.SUCCESS_CODES: + return Right(resp.json()) + return Left(resp) + except Exception as exc: + app.logger.error("Error retriving data from auth server: (POST %s)", + _uri, + exc_info=True) + return Left(exc) + return session.user_token().either(__no_token__, __get__) -- cgit v1.2.3