"""OAuth2 client utilities."""
import json
import time
import uuid
import random
from datetime import datetime, timedelta
from urllib.parse import urljoin, urlparse
import requests
from flask import request, current_app as app
from pymonad.either import Left, Right, Either
from authlib.common.urls import url_decode
from authlib.jose.errors import BadSignatureError
from authlib.jose import KeySet, JsonWebKey, JsonWebToken
from authlib.integrations.requests_client import OAuth2Session
from uploader import session
import uploader.monadic_requests as mrequests
# Space-separated list of the OAuth2 scopes this client asks for; passed as
# `scope=` when the OAuth2 session is created.
SCOPE = " ".join((
    "profile", "group", "role", "resource", "register-client", "user",
    "masquerade", "introspect", "migrate-data"))
def authserver_uri():
    """Look up the base URI of the authorisation server.

    Read from the ``AUTH_SERVER_URL`` key of the current app's config; a
    missing key raises KeyError.
    """
    config = app.config
    return config["AUTH_SERVER_URL"]
def oauth2_clientid():
    """Look up this application's OAuth2 client id.

    Read from the ``OAUTH2_CLIENT_ID`` key of the current app's config; a
    missing key raises KeyError.
    """
    config = app.config
    return config["OAUTH2_CLIENT_ID"]
def oauth2_clientsecret():
    """Look up this application's OAuth2 client secret.

    Read from the ``OAUTH2_CLIENT_SECRET`` key of the current app's config;
    a missing key raises KeyError.
    """
    config = app.config
    return config["OAUTH2_CLIENT_SECRET"]
def __fetch_auth_server_jwks__() -> KeySet:
    """Download the auth server's public JSON Web Keys.

    GETs ``auth/public-jwks`` relative to the auth server URI, imports every
    entry of the response's ``"jwks"`` member as a key, and wraps them in a
    KeySet.
    """
    response = requests.get(
        urljoin(authserver_uri(), "auth/public-jwks"),
        timeout=(9.13, 20))  # (connect, read) timeouts in seconds
    raw_keys = response.json()["jwks"]
    return KeySet([JsonWebKey.import_key(raw_key) for raw_key in raw_keys])
def __update_auth_server_jwks__(jwks) -> KeySet:
"""Update the JWKs from the servers if necessary."""
last_updated = jwks["last-updated"]
now = datetime.now().timestamp()
# Maybe the `two_hours` variable below can be made into a configuration
# variable and passed in to this function
two_hours = timedelta(hours=2).seconds
if bool(last_updated) and (now - last_updated) < two_hours:
return jwks["jwks"]
return session.set_auth_server_jwks(__fetch_auth_server_jwks__())
def auth_server_jwks() -> KeySet:
    """Return the auth server's JSON Web Keys, refreshing them when stale.

    Keys cached in the session under ``"auth_server_jwks"`` are rebuilt into
    a KeySet and handed to the staleness check together with their
    ``"last-updated"`` stamp. With nothing cached, a stamp three hours in the
    past is passed instead, so that the check fetches fresh keys.
    """
    stored = session.session_info().get("auth_server_jwks") or {}
    if not bool(stored):
        stale_stamp = (datetime.now() - timedelta(hours=3)).timestamp()
        return __update_auth_server_jwks__({"last-updated": stale_stamp})
    last_updated = stored["last-updated"]
    raw_keys = stored.get("jwks", {"keys": []})["keys"]
    return __update_auth_server_jwks__({
        "last-updated": last_updated,
        "jwks": KeySet([JsonWebKey.import_key(raw) for raw in raw_keys])
    })
def oauth2_client():
    """Build the OAuth2 client for use fetching data.

    The client is built from the user token held in the session. If that
    token's access token has expired (judged by its ``exp`` claim), it is
    refreshed first. Only one caller performs the refresh; concurrent callers
    wait for it to finish and then re-read the session's user token. When the
    session has no usable token, a client without a token is returned.
    """
    def __update_token__(token, refresh_token=None, access_token=None):# pylint: disable=[unused-argument]
        """Update the token when refreshed."""
        # Registered as the session's `update_token` hook: persist the new
        # token so later requests pick it up.
        session.set_user_token(token)

    def __json_auth__(client, _method, uri, headers, body):
        """Send token-endpoint client credentials as a JSON body.

        Registered under the ``client_secret_post`` name. The form fields in
        `body` are decoded, the client id/secret are added, and the result
        is re-encoded as JSON with a matching Content-Type.
        """
        return (
            uri,
            {**headers, "Content-Type": "application/json"},
            json.dumps({
                **dict(url_decode(body)),
                "client_id": client.client_id,
                "client_secret": client.client_secret
            }))

    def __client__(token) -> OAuth2Session:
        """Create an OAuth2 session carrying `token` (may be None)."""
        client = OAuth2Session(
            oauth2_clientid(),
            oauth2_clientsecret(),
            scope=SCOPE,
            token_endpoint=urljoin(authserver_uri(), "/auth/token"),
            token_endpoint_auth_method="client_secret_post",
            token=token,
            update_token=__update_token__)
        client.register_client_auth_method(
            ("client_secret_post", __json_auth__))
        return client

    def __token_expired__(token):
        """Check whether the token has expired.

        The access token is decoded with each of the auth server's keys in
        turn. The first key that verifies the signature and yields an
        ``exp`` claim decides the answer. Returns False when no key verifies
        the token or no ``exp`` claim is found.
        """
        jwks = auth_server_jwks()
        if not bool(jwks):
            return False
        decoder = JsonWebToken(["RS256"])
        for jwk in jwks.keys:
            try:
                jwt = decoder.decode(token["access_token"], key=jwk)
            except BadSignatureError:
                # Signed with a different key: try the next one.
                continue
            if bool(jwt.get("exp")):
                return datetime.now().timestamp() > jwt["exp"]
        return False

    def __delay__():
        """Sleep for a random 0-99 ms to stagger concurrent callers."""
        time.sleep(random.randrange(100) / 1000.0)

    def __refresh_token__(token):
        """Refresh the token if necessary — synchronise amongst threads."""
        if not __token_expired__(token):
            return token
        __delay__()
        if session.is_token_refreshing():
            # Someone else is refreshing: wait for them, then use whatever
            # token is in the session now (None if there is none).
            while session.is_token_refreshing():
                __delay__()
            return session.user_token().either(
                lambda _err: None, lambda _tok: _tok)
        session.toggle_token_refreshing()
        try:
            _client = __client__(token)
            # Any authenticated request lets the session refresh the expired
            # token; the new one is stored through `__update_token__`.
            _client.get(urljoin(authserver_uri(), "auth/user/"))
        finally:
            # Always clear the flag, or waiters above would spin forever
            # after a failed refresh.
            session.toggle_token_refreshing()
        return _client.token

    return session.user_token().then(__refresh_token__).either(
        lambda _notok: __client__(None),
        __client__)
def fetch_user_details() -> Either:
    """Retrieve user details from the auth server.

    Only the placeholder anonymous user (email ``anon@ymous.user``) is looked
    up. Its details are fetched from ``auth/user/`` and stored in the session
    together with the current user token and a logged-in flag. Any other
    session user is returned unchanged, wrapped in Right.
    """
    current_user = session.session_info()["user"]
    if current_user["email"] != "anon@ymous.user":
        return Right(current_user)

    def __store_details__(details):
        return session.set_user_details({
            "user_id": uuid.UUID(details["user_id"]),
            "name": details["name"],
            "email": details["email"],
            "token": session.user_token(),
            "logged_in": session.user_token().either(
                lambda _e: False, lambda _t: True)
        })

    return oauth2_get("auth/user/").then(__store_details__)
def user_logged_in():
    """Check whether the user has logged in.

    User details are refreshed first (see `fetch_user_details`), and the
    session's user is read only afterwards. Reading before the refresh would
    answer from stale details and miss anything the refresh just stored.
    """
    fetch_user_details()
    suser = session.session_info()["user"]
    return suser["logged_in"] and suser["token"].is_right()
def authserver_authorise_uri():
    """Build up the authorisation URI.

    Points at the auth server's ``auth/authorise`` endpoint and requests an
    authorisation code (``response_type=code``). The code is delivered to
    this app's ``oauth2/code`` endpoint, on the same scheme and host as the
    current request.

    Query parameters are form-urlencoded, so characters such as ``:``,
    ``/``, ``&`` or ``?`` inside the redirect URI or client id cannot be
    misread as part of the authorisation URI's own query string.
    """
    from urllib.parse import urlencode # pylint: disable=[import-outside-toplevel]
    req_baseurl = urlparse(request.base_url, scheme=request.scheme)
    host_uri = f"{req_baseurl.scheme}://{req_baseurl.netloc}/"
    query = urlencode({
        "response_type": "code",
        "client_id": oauth2_clientid(),
        "redirect_uri": urljoin(host_uri, "oauth2/code")
    })
    return urljoin(authserver_uri(), f"auth/authorise?{query}")
def __no_token__(_err) -> Left:
    """Handle situation where request is attempted with no token.

    Builds a synthetic HTTP 400 response whose JSON body names an
    ``AuthenticationError`` and returns it wrapped in Left. The incoming
    error value is ignored.
    """
    payload = {
        "error": "AuthenticationError",
        "error-description": ("You need to authenticate to access requested "
                              "information.")
    }
    response = requests.models.Response()
    response.status_code = 400
    response._content = json.dumps(payload).encode("utf-8")#pylint: disable=[protected-access]
    return Left(response)
def oauth2_get(url, **kwargs) -> Either:
    """Do a get request to the authentication/authorisation server.

    `url` is resolved against the auth server URI. Extra keyword arguments
    are forwarded to the client's ``get``, with a JSON Content-Type merged
    into any caller-supplied headers.

    Returns Right(parsed JSON body) on a success status, Left(response) on
    any other status, and Left(exception) if the request or JSON parsing
    raises. When the session holds no user token, the synthetic
    authentication-error response from `__no_token__` is returned instead.
    """
    def __get__(_token) -> Either:
        _uri = urljoin(authserver_uri(), url)
        try:
            client = oauth2_client()
            merged_headers = {
                **kwargs.get("headers", {}),
                "Content-Type": "application/json"
            }
            resp = client.get(_uri, **{**kwargs, "headers": merged_headers})
            if resp.status_code not in mrequests.SUCCESS_CODES:
                return Left(resp)
            return Right(resp.json())
        except Exception as exc:#pylint: disable=[broad-except]
            app.logger.error("Error retrieving data from auth server: (GET %s)",
                             _uri,
                             exc_info=True)
            return Left(exc)

    return session.user_token().either(__no_token__, __get__)
def oauth2_post(url, data=None, json=None, **kwargs):#pylint: disable=[redefined-outer-name]
    """Do a POST request to the authentication/authorisation server.

    `url` is resolved against the auth server URI. The `data` and/or `json`
    payloads are merged with this app's client id and secret. The merged
    payload is sent as form data if `data` is truthy, and as JSON if `json`
    is truthy. A JSON Content-Type header is added only in the JSON case.
    Note that when neither payload is truthy, no body is sent at all.

    Returns Right(parsed JSON body) on a success status, Left(response) on
    any other status, and Left(exception) if the request or JSON parsing
    raises. When the session holds no user token, the synthetic
    authentication-error response from `__no_token__` is returned instead.
    """
    def __post__(_token) -> Either:
        _uri = urljoin(authserver_uri(), url)
        caller_headers = kwargs.get("headers", {})
        if json:
            _headers = {**caller_headers, "Content-Type": "application/json"}
        else:
            _headers = caller_headers
        try:
            request_data = {
                **(data or {}),
                **(json or {}),
                "client_id": oauth2_clientid(),
                "client_secret": oauth2_clientsecret()
            }
            resp = oauth2_client().post(
                _uri,
                data=request_data if data else None,
                json=request_data if json else None,
                **{**kwargs, "headers": _headers})
            if resp.status_code not in mrequests.SUCCESS_CODES:
                return Left(resp)
            return Right(resp.json())
        except Exception as exc:#pylint: disable=[broad-except]
            app.logger.error("Error retrieving data from auth server: (POST %s)",
                             _uri,
                             exc_info=True)
            return Left(exc)

    return session.user_token().either(__no_token__, __post__)