aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-12-12 13:29:45 +0300
committerFrederick Muriuki Muriithi2022-12-12 13:29:45 +0300
commit56b54da6fc6e97d5d6dac70f2393dcc98d93991c (patch)
treee58fc122b169fa56898ac4e352407060d8c64df4
parent6c077671e9afdb4921e72d0a3018e3d8dedada8b (diff)
downloadgenenetwork3-56b54da6fc6e97d5d6dac70f2393dcc98d93991c.tar.gz
auth: pass cursor object to `user_group` function
-rw-r--r--gn3/auth/authorisation/groups.py15
-rw-r--r--tests/unit/auth/test_groups.py5
2 files changed, 11 insertions, 9 deletions
diff --git a/gn3/auth/authorisation/groups.py b/gn3/auth/authorisation/groups.py
index cda11b3..dbc9f7d 100644
--- a/gn3/auth/authorisation/groups.py
+++ b/gn3/auth/authorisation/groups.py
@@ -108,15 +108,14 @@ def authenticated_user_group(conn) -> Maybe:
return Nothing
-def user_group(conn: db.DbConnection, user: User) -> Maybe:
+def user_group(cursor: db.DbCursor, user: User) -> Maybe:
"""Returns the given user's group"""
- with db.cursor(conn) as cursor:
- cursor.execute(
- ("SELECT groups.group_id, groups.group_name FROM group_users "
- "INNER JOIN groups ON group_users.group_id=groups.group_id "
- "WHERE group_users.user_id = ?"),
- (str(user.user_id),))
- groups = tuple(Group(UUID(row[0]), row[1]) for row in cursor.fetchall())
+ cursor.execute(
+ ("SELECT groups.group_id, groups.group_name FROM group_users "
+ "INNER JOIN groups ON group_users.group_id=groups.group_id "
+ "WHERE group_users.user_id = ?"),
+ (str(user.user_id),))
+ groups = tuple(Group(UUID(row[0]), row[1]) for row in cursor.fetchall())
if len(groups) > 1:
raise MembershipError(user, groups)
diff --git a/tests/unit/auth/test_groups.py b/tests/unit/auth/test_groups.py
index e1b44cc..5a27c71 100644
--- a/tests/unit/auth/test_groups.py
+++ b/tests/unit/auth/test_groups.py
@@ -117,4 +117,7 @@ def test_user_group(test_users_in_group, user, expected):
Nothing
"""
conn, _group, _users = test_users_in_group
- assert user_group(conn, user).maybe(Nothing, lambda val: val) == expected
+ with db.cursor(conn) as cursor:
+ assert (
+ user_group(cursor, user).maybe(Nothing, lambda val: val)
+ == expected)