about summary refs log tree commit diff
path: root/gn3/auth/authorisation/groups/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/auth/authorisation/groups/models.py')
-rw-r--r--gn3/auth/authorisation/groups/models.py52
1 files changed, 27 insertions, 25 deletions
diff --git a/gn3/auth/authorisation/groups/models.py b/gn3/auth/authorisation/groups/models.py
index 5a58322..bbe4ad6 100644
--- a/gn3/auth/authorisation/groups/models.py
+++ b/gn3/auth/authorisation/groups/models.py
@@ -142,30 +142,31 @@ def authenticated_user_group(conn) -> Maybe:
 
     return Nothing
 
-def user_group(cursor: db.DbCursor, user: User) -> Maybe[Group]:
+def user_group(conn: db.DbConnection, user: User) -> Maybe[Group]:
     """Returns the given user's group"""
-    cursor.execute(
-        ("SELECT groups.group_id, groups.group_name, groups.group_metadata "
-         "FROM group_users "
-         "INNER JOIN groups ON group_users.group_id=groups.group_id "
-         "WHERE group_users.user_id = ?"),
-        (str(user.user_id),))
-    groups = tuple(
-        Group(UUID(row[0]), row[1], json.loads(row[2] or "{}"))
-        for row in cursor.fetchall())
+    with db.cursor(conn) as cursor:
+        cursor.execute(
+            ("SELECT groups.group_id, groups.group_name, groups.group_metadata "
+             "FROM group_users "
+             "INNER JOIN groups ON group_users.group_id=groups.group_id "
+             "WHERE group_users.user_id = ?"),
+            (str(user.user_id),))
+        groups = tuple(
+            Group(UUID(row[0]), row[1], json.loads(row[2] or "{}"))
+            for row in cursor.fetchall())
 
-    if len(groups) > 1:
-        raise MembershipError(user, groups)
+        if len(groups) > 1:
+            raise MembershipError(user, groups)
 
-    if len(groups) == 1:
-        return Just(groups[0])
+        if len(groups) == 1:
+            return Just(groups[0])
 
     return Nothing
 
-def is_group_leader(cursor: db.DbCursor, user: User, group: Group):
+def is_group_leader(conn: db.DbConnection, user: User, group: Group) -> bool:
     """Check whether the given `user` is the leader of `group`."""
 
-    ugroup = user_group(cursor, user).maybe(
+    ugroup = user_group(conn, user).maybe(
         False, lambda val: val) # type: ignore[arg-type, misc]
     if not group:
         # User cannot be a group leader if not a member of ANY group
@@ -175,13 +176,14 @@ def is_group_leader(cursor: db.DbCursor, user: User, group: Group):
         # User cannot be a group leader if not a member of THIS group
         return False
 
-    cursor.execute(
-        ("SELECT roles.role_name FROM user_roles LEFT JOIN roles "
-         "ON user_roles.role_id = roles.role_id WHERE user_id = ?"),
-        (str(user.user_id),))
-    role_names = tuple(row[0] for row in cursor.fetchall())
+    with db.cursor(conn) as cursor:
+        cursor.execute(
+            ("SELECT roles.role_name FROM user_roles LEFT JOIN roles "
+             "ON user_roles.role_id = roles.role_id WHERE user_id = ?"),
+            (str(user.user_id),))
+        role_names = tuple(row[0] for row in cursor.fetchall())
 
-    return "group-leader" in role_names
+        return "group-leader" in role_names
 
 def all_groups(conn: db.DbConnection) -> Maybe[Sequence[Group]]:
     """Retrieve all existing groups"""
@@ -258,8 +260,8 @@ def group_by_id(conn: db.DbConnection, group_id: UUID) -> Group:
 def join_requests(conn: db.DbConnection, user: User):
     """List all the join requests for the user's group."""
     with db.cursor(conn) as cursor:
-        group = user_group(cursor, user).maybe(DUMMY_GROUP, lambda grp: grp)# type: ignore[misc]
-        if group != DUMMY_GROUP and is_group_leader(cursor, user, group):
+        group = user_group(conn, user).maybe(DUMMY_GROUP, lambda grp: grp)# type: ignore[misc]
+        if group != DUMMY_GROUP and is_group_leader(conn, user, group):
             cursor.execute(
                 "SELECT gjr.*, u.email, u.name FROM group_join_requests AS gjr "
                 "INNER JOIN users AS u ON gjr.requester_id=u.user_id "
@@ -280,7 +282,7 @@ def accept_reject_join_request(
     """Accept/Reject a join request."""
     assert status in ("ACCEPTED", "REJECTED"), f"Invalid status '{status}'."
     with db.cursor(conn) as cursor:
-        group = user_group(cursor, user).maybe(DUMMY_GROUP, lambda grp: grp) # type: ignore[misc]
+        group = user_group(conn, user).maybe(DUMMY_GROUP, lambda grp: grp) # type: ignore[misc]
         cursor.execute("SELECT * FROM group_join_requests WHERE request_id=?",
                        (str(request_id),))
         row = cursor.fetchone()