diff options
Diffstat (limited to 'gn_auth/hooks.py')
-rw-r--r-- | gn_auth/hooks.py | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/gn_auth/hooks.py b/gn_auth/hooks.py new file mode 100644 index 0000000..bd7380b --- /dev/null +++ b/gn_auth/hooks.py @@ -0,0 +1,68 @@ +"""Authorisation hooks implementation""" +import functools +from typing import List + +from flask import request_finished +from flask import request, current_app + +from gn_auth.auth.db import sqlite3 as db + +def register_hooks(app): + """Initialise hooks system on the application.""" + request_finished.connect(edu_domain_hook, app) + + +def handle_register_request(func): + """Decorator for handling user registration hooks.""" + @functools.wraps(func) + def wrapper(*args, **kwargs): + if request.method == "POST" and request.endpoint == "oauth2.users.register_user": + return func(*args, **kwargs) + else: + return lambda *args, **kwargs: None + return wrapper + + +@handle_register_request +def edu_domain_hook(_sender, response, **_extra): + """Hook to run whenever a user with a `.edu` domain registers.""" + if response.status_code >= 400: + return + data = request.get_json() + if data is None or "email" not in data or not data["email"].endswith("edu"): + return + registered_email = data["email"] + apply_edu_role(registered_email) + + +def apply_edu_role(email): + """Assign 'hook-role-from-edu-domain' to user.""" + with db.connection(current_app.config["AUTH_DB"]) as conn: + with db.cursor(conn) as cursor: + cursor.execute("SELECT user_id FROM users WHERE email= ?", (email,) ) + user_result = cursor.fetchone() + cursor.execute("SELECT role_id FROM roles WHERE role_name='hook-role-from-edu-domain'") + role_result = cursor.fetchone() + resource_ids = get_resources_for_edu_domain(cursor) + if user_result is None or role_result is None: + return + user_id = user_result[0] + role_id = role_result[0] + cursor.executemany( + "INSERT INTO user_roles(user_id, role_id, resource_id) " + "VALUES(:user_id, :role_id, :resource_id)", + tuple({ + "user_id": user_id, + "role_id": role_id, + "resource_id": resource_id + } for resource_id in resource_ids)) + + +def get_resources_for_edu_domain(cursor) -> List[int]: + """FIXME: I still haven't figured out how to get resources to be assigned to edu domain""" + resources_query = """ + SELECT resource_id FROM resources INNER JOIN resource_categories USING(resource_category_id) WHERE resource_categories.resource_category_key IN ('genotype', 'phenotype', 'mrna') + """ + cursor.execute(resources_query) + resource_ids = [x[0] for x in cursor.fetchall()] + return resource_ids |