From ef60b19dcb338ad80707ecffc5a959f3c6f66209 Mon Sep 17 00:00:00 2001 From: John Nduli Date: Wed, 25 Sep 2024 18:35:48 +0300 Subject: feat: add base implementation for hooks system --- gn_auth/__init__.py | 2 ++ gn_auth/hooks.py | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 gn_auth/hooks.py (limited to 'gn_auth') diff --git a/gn_auth/__init__.py b/gn_auth/__init__.py index 973110a..6e2a884 100644 --- a/gn_auth/__init__.py +++ b/gn_auth/__init__.py @@ -8,6 +8,7 @@ from flask import Flask from flask_cors import CORS from authlib.jose import JsonWebKey +from gn_auth import hooks from gn_auth.misc_views import misc from gn_auth.auth.views import oauth2 @@ -87,5 +88,6 @@ def create_app( app.register_blueprint(oauth2, url_prefix="/auth") register_error_handlers(app) + hooks.register_hooks(app) return app diff --git a/gn_auth/hooks.py b/gn_auth/hooks.py new file mode 100644 index 0000000..a4240f4 --- /dev/null +++ b/gn_auth/hooks.py @@ -0,0 +1,61 @@ +from typing import List +from flask import request_finished +from flask import request, current_app +from gn_auth.auth.db import sqlite3 as db +import functools + +def register_hooks(app): + request_finished.connect(edu_domain_hook, app) + + +def handle_register_request(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if request.method == "POST" and request.endpoint == "oauth2.users.register_user": + return func(*args, **kwargs) + else: + return lambda *args, **kwargs: None + return wrapper + + +@handle_register_request +def edu_domain_hook(sender, response, **extra): + if response.status_code >= 400: + return + data = request.get_json() + if data is None or "email" not in data or not data["email"].endswith("edu"): + return + registered_email = data["email"] + apply_edu_role(registered_email) + + +def apply_edu_role(email): + with db.connection(current_app.config["AUTH_DB"]) as conn: + with db.cursor(conn) as cursor: + cursor.execute("SELECT user_id FROM users WHERE email= ?", (email,) ) + user_result = cursor.fetchone() + cursor.execute("SELECT role_id FROM roles WHERE role_name='hook-role-from-edu-domain'") + role_result = cursor.fetchone() + resource_ids = get_resources_for_edu_domain(cursor) + if user_result is None or role_result is None: + return + user_id = user_result[0] + role_id = role_result[0] + cursor.executemany( + "INSERT INTO user_roles(user_id, role_id, resource_id) " + "VALUES(:user_id, :role_id, :resource_id)", + tuple({ + "user_id": user_id, + "role_id": role_id, + "resource_id": resource_id + } for resource_id in resource_ids)) + + +def get_resources_for_edu_domain(cursor) -> List[int]: + """FIXME: I still haven't figured out how to get resources to be assigned to edu domain""" + resources_query = """ + SELECT resource_id FROM resources INNER JOIN resource_categories USING(resource_category_id) WHERE resource_categories.resource_category_key IN ('genotype', 'phenotype', 'mrna') + """ + cursor.execute(resources_query) + resource_ids = [x[0] for x in cursor.fetchall()] + return resource_ids -- cgit v1.2.3