aboutsummaryrefslogtreecommitdiff
path: root/gn2/wqflask/oauth2/checks.py
blob: b8db6dc20cc02fdf1e4770ae624f59ed66fd7e68 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Various checkers for OAuth2"""
from functools import wraps
from urllib.parse import urljoin

from flask import flash, request, redirect
from authlib.integrations.requests_client import OAuth2Session

from . import session
from .session import clear_session_info
from .client import (
    oauth2_get,
    oauth2_client,
    authserver_uri,
    oauth2_clientid,
    oauth2_clientsecret)


def require_oauth2(func):
    """Decorator for ensuring user is logged in.

    The wrapped view is only called when the session holds a token and the
    authorisation server's "auth/user/" endpoint answers without an "error"
    entry. In every other case the session information is cleared, a warning
    is flashed, and the user is redirected to "/".
    """
    @wraps(func)
    def __token_valid__(*args, **kwargs):
        """Check that the user is logged in and their token is valid."""

        def __clear_session__(_no_token):
            # The argument is ignored: this handler serves both the
            # "no token in session" case and the "token rejected" case.
            session.clear_session_info()
            flash("You need to be logged in.", "alert-warning")
            return redirect("/")

        def __with_token__(token):
            # Ask the auth server who the current user is; an "error" key in
            # the JSON payload is treated as an invalid/expired token.
            resp = oauth2_client().get(
                urljoin(authserver_uri(), "auth/user/"))
            user_details = resp.json()
            if not user_details.get("error", False):
                return func(*args, **kwargs)

            # Token was rejected: go through the same path as a missing
            # token so the view always returns a proper redirect response
            # (previously `clear_session_info(token)` was called here, which
            # does not match its zero-argument use above and produced no
            # response for the user).
            return __clear_session__(token)

        # `user_token()` is used as an Either-like value: the first handler
        # runs when there is no token, the second when one is present.
        return session.user_token().either(__clear_session__, __with_token__)

    return __token_valid__


def require_oauth2_edit_resource_access(func):
    """Decorator: run the view only if the user may edit the named resource.

    The resource name is read from the "name" form field on POST requests and
    from the "name" query argument otherwise. If the roles reported for that
    resource do not include "group:resource:edit-resource", the user is
    redirected to the resource's dataset page instead of reaching the view.
    """
    @wraps(func)
    def __check_edit_access__(*args, **kwargs):
        # POST requests carry the name in the form body; everything else
        # (e.g. GET) carries it in the query string.
        params = request.form if request.method == "POST" else request.args
        resource_name = params.get("name", "")

        def __no_roles__(_error):
            # A failed lookup is treated as "user has no roles".
            return {"roles": []}

        def __identity__(val):
            return val

        authorisation = oauth2_get(
            f"auth/resource/authorisation/{resource_name}"
        ).either(__no_roles__, __identity__)

        granted_roles = authorisation.get("roles", [])
        if "group:resource:edit-resource" in granted_roles:
            return func(*args, **kwargs)
        # No edit access: send the user back to the dataset page.
        return redirect(f"/datasets/{resource_name}")
    return __check_edit_access__