"""Common oauth2 client utilities."""
import json
import time
import random
import requests
from typing import Optional
from urllib.parse import urljoin
from datetime import datetime, timedelta
from flask import current_app as app
from pymonad.either import Left, Right, Either
from authlib.common.urls import url_decode
from authlib.jose.errors import BadSignatureError
from authlib.jose import KeySet, JsonWebKey, JsonWebToken
from authlib.integrations.requests_client import OAuth2Session
from gn2.wqflask.oauth2 import session
from gn2.wqflask.external_errors import ExternalRequestError
SCOPE = ("profile group role resource user masquerade introspect")
def authserver_uri():
"""Return URI to authorisation server."""
return app.config["AUTH_SERVER_URL"]
def oauth2_clientid():
"""Return the client id."""
return app.config["OAUTH2_CLIENT_ID"]
def oauth2_clientsecret():
"""Return the client secret."""
return app.config["OAUTH2_CLIENT_SECRET"]
def user_logged_in():
"""Check whether the user has logged in."""
suser = session.session_info()["user"]
return suser["logged_in"] and suser["token"].is_right()
def __make_token_validator__(keys: KeySet):
"""Make a token validator function."""
def __validator__(token: dict):
for key in keys.keys:
try:
# Fixes CVE-2016-10555. See
# https://docs.authlib.org/en/latest/jose/jwt.html
jwt = JsonWebToken(["RS256"])
jwt.decode(token["access_token"], key)
return Right(token)
except BadSignatureError:
pass
return Left("INVALID-TOKEN")
return __validator__
def auth_server_jwks() -> Optional[KeySet]:
"""Fetch the auth-server JSON Web Keys information."""
_jwks = session.session_info().get("auth_server_jwks")
if bool(_jwks):
return {
"last-updated": _jwks["last-updated"],
"jwks": KeySet([
JsonWebKey.import_key(key) for key in _jwks.get(
"auth_server_jwks", {}).get(
"jwks", {"keys": []})["keys"]])}
def __validate_token__(keys):
"""Validate that the token is really from the auth server."""
def __index__(_sess):
return _sess
return session.user_token().then(__make_token_validator__(keys)).then(
session.set_user_token).either(__index__, __index__)
def __update_auth_server_jwks__():
"""Updates the JWKs every 2 hours or so."""
jwks = auth_server_jwks()
if bool(jwks):
last_updated = jwks.get("last-updated")
now = datetime.now().timestamp()
if bool(last_updated) and (now - last_updated) < timedelta(hours=2).seconds:
return __validate_token__(jwks["jwks"])
jwksuri = urljoin(authserver_uri(), "auth/public-jwks")
jwks = KeySet([
JsonWebKey.import_key(key)
for key in requests.get(jwksuri).json()["jwks"]])
return __validate_token__(jwks)
def is_token_expired(token):
"""Check whether the token has expired."""
__update_auth_server_jwks__()
jwks = auth_server_jwks()
if bool(jwks):
for jwk in jwks["jwks"].keys:
try:
jwt = JsonWebToken(["RS256"]).decode(
token["access_token"], key=jwk)
return datetime.now().timestamp() > jwt["exp"]
except BadSignatureError as _bse:
pass
return False
def oauth2_client():
def __update_token__(token, refresh_token=None, access_token=None):
"""Update the token when refreshed."""
session.set_user_token(token)
return token
def __delay__():
"""Do a tiny delay."""
time.sleep(random.choice(tuple(i/1000.0 for i in range(0,100))))
def __refresh_token__(token):
"""Synchronise token refresh."""
if is_token_expired(token):
__delay__()
if session.is_token_refreshing():
while session.is_token_refreshing():
__delay__()
_token = session.user_token().either(None, lambda _tok: _tok)
return _token
session.toggle_token_refreshing()
_client = __client__(token)
_client.get(urljoin(authserver_uri(), "auth/user/"))
session.toggle_token_refreshing()
return _client.token
return token
def __json_auth__(client, method, uri, headers, body):
return (
uri,
{**headers, "Content-Type": "application/json"},
json.dumps({
**dict(url_decode(body)),
"client_id": oauth2_clientid(),
"client_secret": oauth2_clientsecret()
}))
def __client__(token) -> OAuth2Session:
client = OAuth2Session(
oauth2_clientid(),
oauth2_clientsecret(),
scope=token["scope"],
token_endpoint=urljoin(authserver_uri(), "auth/token"),
token_endpoint_auth_method="client_secret_post",
token=token,
update_token=__update_token__)
client.register_client_auth_method(
("client_secret_post", __json_auth__))
return client
__update_auth_server_jwks__()
return session.user_token().then(__refresh_token__).either(
lambda _notok: __client__(None),
lambda token: __client__(token))
def __no_token__(_err) -> Left:
"""Handle situation where request is attempted with no token."""
resp = requests.models.Response()
resp._content = json.dumps({
"error": "AuthenticationError",
"error-description": ("You need to authenticate to access requested "
"information.")}).encode("utf-8")
resp.status_code = 400
return Left(resp)
def oauth2_get(
uri_path: str,
data: dict = {},
jsonify_p: bool = False,
headers: dict = {"Content-Type": "application/json"},
**kwargs
) -> Either:
def __get__(token) -> Either:
resp = oauth2_client().get(
urljoin(authserver_uri(), uri_path),
data=data,
headers=headers,
**kwargs)
if resp.status_code == 200:
if jsonify_p:
return Right(resp)
return Right(resp.json())
return Left(resp)
return session.user_token().either(__no_token__, __get__)
def oauth2_post(
uri_path: str,
data: Optional[dict] = None,
json: Optional[dict] = None,
headers: dict = {"Content-Type": "application/json"},
**kwargs
) -> Either:
def __post__(token) -> Either:
resp = oauth2_client().post(
urljoin(authserver_uri(), uri_path),
data=data,
json=json,
headers=headers,
**kwargs)
if resp.status_code == 200:
return Right(resp.json())
return Left(resp)
return session.user_token().either(__no_token__, __post__)
def no_token_get(
uri_path: str,
headers: dict = {"Content-Type": "application/json"},
**kwargs
) -> Either:
uri = urljoin(authserver_uri(), uri_path)
try:
resp = requests.get(uri, headers=headers, **kwargs)
if resp.status_code == 200:
return Right(resp.json())
return Left(resp)
except requests.exceptions.RequestException as exc:
raise ExternalRequestError(uri, exc) from exc
def no_token_post(uri_path: str, **kwargs) -> Either:
data = kwargs.get("data", {})
the_json = kwargs.get("json", {})
request_data = {
**data,
**the_json,
"client_id": oauth2_clientid(),
"client_secret": oauth2_clientsecret()
}
new_kwargs = {
**{
key: value for key, value in kwargs.items()
if key not in ("data", "json")
},
("data" if bool(data) else "json"): request_data
}
try:
resp = requests.post(urljoin(authserver_uri(), uri_path),
**new_kwargs)
if resp.status_code == 200:
return Right(resp.json())
return Left(resp)
except requests.exceptions.RequestException as exc:
raise ExternalRequestError(uri_path, exc) from exc
def post(uri_path: str, **kwargs) -> Either:
"""
Generic function to do POST requests, that checks whether or not the user is
logged in and selects the appropriate function/method to run.
"""
if user_logged_in():
return oauth2_post(uri_path, **kwargs)
return no_token_post(uri_path, **kwargs)
def get(uri_path: str, **kwargs) -> Either:
"""
Generic function to do GET requests, that checks whether or not the user is
logged in and selects the appropriate function/method to run.
"""
if user_logged_in():
return oauth2_get(uri_path, **kwargs)
return no_token_get(uri_path, **kwargs)