about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/decorators.py27
-rw-r--r--wqflask/wqflask/group_manager.py10
-rw-r--r--wqflask/wqflask/metadata_edits.py11
-rw-r--r--wqflask/wqflask/resource_manager.py10
4 files changed, 30 insertions, 28 deletions
diff --git a/wqflask/wqflask/decorators.py b/wqflask/wqflask/decorators.py
index e33e6bb6..b32c8fc8 100644
--- a/wqflask/wqflask/decorators.py
+++ b/wqflask/wqflask/decorators.py
@@ -13,22 +13,25 @@ from gn3.authentication import DataRole
 
 from wqflask.oauth2 import client
 from wqflask.oauth2.session import session_info
+from wqflask.oauth2.checks import user_logged_in
 from wqflask.oauth2.request_utils import process_error
 
 
-def login_required(f):
+def login_required(pagename: str = ""):
     """Use this for endpoints where login is required"""
-    @wraps(f)
-    def wrap(*args, **kwargs):
-        user_id = ((g.user_session.record.get(b"user_id") or
-                    b"").decode("utf-8")
-                   or g.user_session.record.get("user_id") or "")
-        redis_conn = redis.from_url(current_app.config["REDIS_URL"],
-                                    decode_responses=True)
-        if not redis_conn.hget("users", user_id):
-            return "You need to be logged in!", 401
-        return f(*args, **kwargs)
-    return wrap
+    def __build_wrap__(func):
+        @wraps(func)
+        def wrap(*args, **kwargs):
+            if not user_logged_in():
+                msg = ("You need to be logged in to access that page."
+                       if not bool(pagename) else
+                       ("You need to be logged in to access the "
+                        f"'{pagename.title()}' page."))
+                flash(msg, "alert-warning")
+                return redirect("/")
+            return func(*args, **kwargs)
+        return wrap
+    return __build_wrap__
 
 
 def edit_access_required(f):
diff --git a/wqflask/wqflask/group_manager.py b/wqflask/wqflask/group_manager.py
index 3936e36e..71ced4dd 100644
--- a/wqflask/wqflask/group_manager.py
+++ b/wqflask/wqflask/group_manager.py
@@ -18,7 +18,7 @@ group_management = Blueprint("group_management", __name__)
 
 
 @group_management.route("/groups")
-@login_required
+@login_required()
 def display_groups():
     groups = get_groups_by_user_uid(
         user_uid=(g.user_session.record.get(b"user_id",
@@ -33,13 +33,13 @@ def display_groups():
 
 
 @group_management.route("/groups/create", methods=("GET",))
-@login_required
+@login_required()
 def view_create_group_page():
     return render_template("admin/create_group.html")
 
 
 @group_management.route("/groups/create", methods=("POST",))
-@login_required
+@login_required()
 def create_new_group():
     conn = redis.from_url(current_app.config["REDIS_URL"],
                           decode_responses=True)
@@ -75,7 +75,7 @@ def create_new_group():
 
 
 @group_management.route("/groups/delete", methods=("POST",))
-@login_required
+@login_required()
 def delete_groups():
     conn = redis.from_url(current_app.config["REDIS_URL"],
                           decode_responses=True)
@@ -92,7 +92,7 @@ def delete_groups():
 
 
 @group_management.route("/groups/<group_id>")
-@login_required
+@login_required()
 def view_group(group_id: str):
     conn = redis.from_url(current_app.config["REDIS_URL"],
                           decode_responses=True)
diff --git a/wqflask/wqflask/metadata_edits.py b/wqflask/wqflask/metadata_edits.py
index 8f0166e5..1ff1e5c5 100644
--- a/wqflask/wqflask/metadata_edits.py
+++ b/wqflask/wqflask/metadata_edits.py
@@ -486,7 +486,7 @@ def update_probeset(name: str):
 
 
 @metadata_edit.route("/<dataset_id>/traits/<phenotype_id>/csv")
-@login_required
+@login_required()
 def get_sample_data_as_csv(dataset_id: str, phenotype_id: int):
     from utility.tools import get_setting
     with database_connection(get_setting("SQL_URI")) as conn:
@@ -505,7 +505,7 @@ filename=sample-data-{dataset_id}.csv"
 
 
 @metadata_edit.route("/diffs")
-# @login_required
+@login_required(pagename="Sample Data Diffs")
 def list_diffs():
     files = _get_diffs(
         diff_dir=f"{current_app.config.get('TMPDIR')}/sample-data/diffs",
@@ -573,6 +573,7 @@ def list_diffs():
 
 
 @metadata_edit.route("/diffs/<name>")
+@login_required(pagename="diff display")
 def show_diff(name):
     TMPDIR = current_app.config.get("TMPDIR")
     with open(
@@ -651,8 +652,7 @@ def show_history(dataset_id: str = "", name: str = ""):
 
 
 @metadata_edit.route("<resource_id>/diffs/<file_name>/reject")
-@edit_admins_access_required
-@login_required
+@login_required(pagename="sample data rejection")
 def reject_data(resource_id: str, file_name: str):
     TMPDIR = current_app.config.get("TMPDIR")
     os.rename(
@@ -664,8 +664,7 @@ def reject_data(resource_id: str, file_name: str):
 
 
 @metadata_edit.route("<resource_id>/diffs/<file_name>/approve")
-@edit_admins_access_required
-@login_required
+@login_required(pagename="Sample Data Approval")
 def approve_data(resource_id: str, file_name: str):
     from utility.tools import get_setting
     sample_data = {file_name: str}
diff --git a/wqflask/wqflask/resource_manager.py b/wqflask/wqflask/resource_manager.py
index 91731a4a..375e3d64 100644
--- a/wqflask/wqflask/resource_manager.py
+++ b/wqflask/wqflask/resource_manager.py
@@ -72,7 +72,7 @@ unique identifiers so they aren't human readable names.
 
 
 @resource_management.route("/resources/<resource_id>")
-@login_required
+@login_required()
 def view_resource(resource_id: str):
     user_id = (g.user_session.record.get(b"user_id",
                                          b"").decode("utf-8") or
@@ -99,7 +99,7 @@ def view_resource(resource_id: str):
 @resource_management.route("/resources/<resource_id>/make-public",
                            methods=('POST',))
 @edit_access_required
-@login_required
+@login_required()
 def update_resource_publicity(resource_id: str):
     redis_conn = redis.from_url(
         current_app.config["REDIS_URL"],
@@ -128,7 +128,7 @@ def update_resource_publicity(resource_id: str):
 
 @resource_management.route("/resources/<resource_id>/change-owner")
 @edit_admins_access_required
-@login_required
+@login_required()
 def view_resource_owner(resource_id: str):
     return render_template(
         "admin/change_resource_owner.html",
@@ -138,7 +138,7 @@ def view_resource_owner(resource_id: str):
 @resource_management.route("/resources/<resource_id>/change-owner",
                            methods=('POST',))
 @edit_admins_access_required
-@login_required
+@login_required()
 def change_owner(resource_id: str):
     if user_id := request.form.get("new_owner"):
         redis_conn = redis.from_url(
@@ -154,7 +154,7 @@ def change_owner(resource_id: str):
 
 @resource_management.route("<resource_id>/users/search", methods=('POST',))
 @edit_admins_access_required
-@login_required
+@login_required()
 def search_user(resource_id: str):
     results = {}
     for user in (users := redis.from_url(