aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/hooks.py
blob: a4240f4cd70aaa82b04ad8e87e94e8fd7d1c9ef2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import List
from flask import request_finished
from flask import request, current_app
from gn_auth.auth.db import sqlite3 as db
import functools

def register_hooks(app):
    """Subscribe this module's signal receivers for *app*.

    Connects :func:`edu_domain_hook` to Flask's ``request_finished`` signal,
    with *app* given as the sender filter so the receiver is only called for
    signals sent by that application.
    """
    request_finished.connect(edu_domain_hook, sender=app)


def handle_register_request(func):
    """Restrict a ``request_finished`` receiver to user-registration requests.

    The returned wrapper calls *func* (forwarding all arguments) only when the
    current request is a ``POST`` to the ``oauth2.users.register_user``
    endpoint. For every other request it does nothing and returns ``None``.

    :param func: the signal receiver to guard.
    :returns: a wrapper carrying *func*'s metadata (via ``functools.wraps``).
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        is_registration = (
            request.method == "POST"
            and request.endpoint == "oauth2.users.register_user")
        if is_registration:
            return func(*args, **kwargs)
        # Not a registration request: skip the receiver entirely.
        return None
    return wrapper


@handle_register_request
def edu_domain_hook(sender, response, **extra):
    """Grant the edu-domain role to a user who registered with a ``.edu`` email.

    Intended as a ``request_finished`` receiver. The ``handle_register_request``
    decorator limits it to ``POST`` requests on the
    ``oauth2.users.register_user`` endpoint.

    :param sender: the signal sender (unused).
    :param response: the response being returned; if its status code is
        400 or above, the hook does nothing.
    :param extra: any further signal keyword arguments (unused).
    """
    if response.status_code >= 400:
        return
    # ``silent=True`` makes a non-JSON or malformed body yield ``None``
    # instead of raising. Flask re-raises exceptions from ``request_finished``
    # receivers, so raising here would turn an already-completed registration
    # into an error response.
    # NOTE(review): form-encoded registrations are not inspected — confirm
    # the registration endpoint only accepts JSON.
    data = request.get_json(silent=True)
    email = data.get("email") if isinstance(data, dict) else None
    # Require a real ".edu" suffix, case-insensitively. A bare "edu" suffix
    # would also accept domains such as "example.schmedu".
    if not isinstance(email, str) or not email.lower().endswith(".edu"):
        return
    # Pass the email exactly as submitted: the user lookup matches it verbatim.
    apply_edu_role(email)


def apply_edu_role(email):
    """Grant the ``hook-role-from-edu-domain`` role to the user with *email*.

    Inserts one ``user_roles`` row per resource ID returned by
    :func:`get_resources_for_edu_domain`. Nothing is inserted if no user has
    exactly this email or if the role does not exist.

    :param email: the registered email address, matched exactly against
        ``users.email``.
    """
    # NOTE(review): assumes ``db.connection`` commits on a clean exit from
    # the ``with`` block — confirm in gn_auth.auth.db.sqlite3.
    with db.connection(current_app.config["AUTH_DB"]) as conn:
        with db.cursor(conn) as cursor:
            cursor.execute("SELECT user_id FROM users WHERE email= ?", (email,))
            user_result = cursor.fetchone()
            if user_result is None:
                return
            cursor.execute("SELECT role_id FROM roles WHERE role_name='hook-role-from-edu-domain'")
            role_result = cursor.fetchone()
            if role_result is None:
                return
            user_id = user_result[0]
            role_id = role_result[0]
            # Only look up resources once we know there is something to grant.
            resource_ids = get_resources_for_edu_domain(cursor)
            cursor.executemany(
                "INSERT INTO user_roles(user_id, role_id, resource_id) "
                "VALUES(:user_id, :role_id, :resource_id)",
                tuple({
                    "user_id": user_id,
                    "role_id": role_id,
                    "resource_id": resource_id
                } for resource_id in resource_ids))


def get_resources_for_edu_domain(cursor) -> List[int]:
    """Return the IDs of all resources in the 'genotype', 'phenotype' and
    'mrna' resource categories.

    FIXME: I still haven't figured out how to get resources to be assigned to edu domain

    :param cursor: an open DB-API cursor on the auth database. This function
        executes a query on it and consumes every result row.
    :returns: the first column (``resource_id``) of each matching row, in the
        order the database returns them.

    NOTE(review): annotated as ``List[int]``; the actual element type is
    whatever the ``resources.resource_id`` column stores — confirm against
    the schema.
    """
    cursor.execute("""
        SELECT resource_id FROM resources INNER JOIN resource_categories USING(resource_category_id) WHERE resource_categories.resource_category_key IN ('genotype', 'phenotype', 'mrna')
    """)
    matching_rows = cursor.fetchall()
    return [row[0] for row in matching_rows]