about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn_auth/__init__.py2
-rw-r--r--gn_auth/hooks.py61
-rw-r--r--migrations/auth/20240924_01_thbvh-hooks-for-edu-domains.py24
3 files changed, 87 insertions, 0 deletions
diff --git a/gn_auth/__init__.py b/gn_auth/__init__.py
index 973110a..6e2a884 100644
--- a/gn_auth/__init__.py
+++ b/gn_auth/__init__.py
@@ -8,6 +8,7 @@ from flask import Flask
 from flask_cors import CORS
 from authlib.jose import JsonWebKey
 
+from gn_auth import hooks
 from gn_auth.misc_views import misc
 from gn_auth.auth.views import oauth2
 
@@ -87,5 +88,6 @@ def create_app(
     app.register_blueprint(oauth2, url_prefix="/auth")
 
     register_error_handlers(app)
+    hooks.register_hooks(app)
 
     return app
diff --git a/gn_auth/hooks.py b/gn_auth/hooks.py
new file mode 100644
index 0000000..a4240f4
--- /dev/null
+++ b/gn_auth/hooks.py
@@ -0,0 +1,61 @@
+from typing import List
+from flask import request_finished
+from flask import request, current_app
+from gn_auth.auth.db import sqlite3 as db
+import functools
+
+def register_hooks(app):
+    request_finished.connect(edu_domain_hook, app)
+
+
+def handle_register_request(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        if request.method == "POST" and request.endpoint == "oauth2.users.register_user":
+            return func(*args, **kwargs)
+        else:
+            return lambda *args, **kwargs: None
+    return wrapper
+
+
+@handle_register_request
+def edu_domain_hook(sender, response, **extra):
+    if response.status_code >= 400:
+        return
+    data = request.get_json()
+    if data is None or "email" not in data or not data["email"].endswith("edu"):
+        return
+    registered_email = data["email"]
+    apply_edu_role(registered_email)
+
+
+def apply_edu_role(email):
+    with db.connection(current_app.config["AUTH_DB"]) as conn:
+        with db.cursor(conn) as cursor:
+            cursor.execute("SELECT user_id FROM users WHERE email= ?", (email,) )
+            user_result = cursor.fetchone()
+            cursor.execute("SELECT role_id FROM roles WHERE role_name='hook-role-from-edu-domain'")
+            role_result = cursor.fetchone()
+            resource_ids = get_resources_for_edu_domain(cursor)
+            if user_result is None or role_result is None:
+                return
+            user_id = user_result[0]
+            role_id = role_result[0]
+            cursor.executemany(
+                "INSERT INTO user_roles(user_id, role_id, resource_id) "
+                "VALUES(:user_id, :role_id, :resource_id)",
+                tuple({
+                    "user_id": user_id,
+                    "role_id": role_id,
+                    "resource_id": resource_id
+                } for resource_id in resource_ids))
+
+
+def get_resources_for_edu_domain(cursor) -> List[int]:
+    """FIXME: I still haven't figured out how to get resources to be assigned to edu domain"""
+    resources_query = """
+        SELECT resource_id FROM resources INNER JOIN resource_categories USING(resource_category_id) WHERE resource_categories.resource_category_key IN ('genotype', 'phenotype', 'mrna')
+    """
+    cursor.execute(resources_query)
+    resource_ids = [x[0] for x in cursor.fetchall()]
+    return resource_ids
diff --git a/migrations/auth/20240924_01_thbvh-hooks-for-edu-domains.py b/migrations/auth/20240924_01_thbvh-hooks-for-edu-domains.py
new file mode 100644
index 0000000..5c6e81d
--- /dev/null
+++ b/migrations/auth/20240924_01_thbvh-hooks-for-edu-domains.py
@@ -0,0 +1,24 @@
+"""
+hooks_for_edu_domains
+"""
+
+from yoyo import step
+
+__depends__ = {'20240819_01_p2vXR-create-forgot-password-tokens-table'}
+
+steps = [
+    step(
+        """
+        INSERT INTO roles(role_id, role_name, user_editable) VALUES
+            ('9bb203a2-7897-4fe3-ac4a-75e6a4f96f5d', 'hook-role-from-edu-domain', '0')
+        """,
+        "DELETE FROM roles WHERE role_name='hook-role-from-edu-domain'"),
+    step(
+        """
+        INSERT INTO role_privileges(role_id, privilege_id) VALUES
+            ('9bb203a2-7897-4fe3-ac4a-75e6a4f96f5d', 'group:resource:view-resource'),
+            ('9bb203a2-7897-4fe3-ac4a-75e6a4f96f5d', 'group:resource:edit-resource')
+        """,
+        "DELETE FROM role_privileges WHERE role_id='9bb203a2-7897-4fe3-ac4a-75e6a4f96f5d'"
+        )
+]