diff options
author | Frederick Muriuki Muriithi | 2022-12-12 13:29:45 +0300 |
---|---|---|
committer | Frederick Muriuki Muriithi | 2022-12-12 13:29:45 +0300 |
commit | 56b54da6fc6e97d5d6dac70f2393dcc98d93991c (patch) | |
tree | e58fc122b169fa56898ac4e352407060d8c64df4 | |
parent | 6c077671e9afdb4921e72d0a3018e3d8dedada8b (diff) | |
download | genenetwork3-56b54da6fc6e97d5d6dac70f2393dcc98d93991c.tar.gz |
auth: pass cursor object to `user_group` function
-rw-r--r-- | gn3/auth/authorisation/groups.py | 15 | ||||
-rw-r--r-- | tests/unit/auth/test_groups.py | 5 |
2 files changed, 11 insertions, 9 deletions
diff --git a/gn3/auth/authorisation/groups.py b/gn3/auth/authorisation/groups.py index cda11b3..dbc9f7d 100644 --- a/gn3/auth/authorisation/groups.py +++ b/gn3/auth/authorisation/groups.py @@ -108,15 +108,14 @@ def authenticated_user_group(conn) -> Maybe: return Nothing -def user_group(conn: db.DbConnection, user: User) -> Maybe: +def user_group(cursor: db.DbCursor, user: User) -> Maybe: """Returns the given user's group""" - with db.cursor(conn) as cursor: - cursor.execute( - ("SELECT groups.group_id, groups.group_name FROM group_users " - "INNER JOIN groups ON group_users.group_id=groups.group_id " - "WHERE group_users.user_id = ?"), - (str(user.user_id),)) - groups = tuple(Group(UUID(row[0]), row[1]) for row in cursor.fetchall()) + cursor.execute( + ("SELECT groups.group_id, groups.group_name FROM group_users " + "INNER JOIN groups ON group_users.group_id=groups.group_id " + "WHERE group_users.user_id = ?"), + (str(user.user_id),)) + groups = tuple(Group(UUID(row[0]), row[1]) for row in cursor.fetchall()) if len(groups) > 1: raise MembershipError(user, groups) diff --git a/tests/unit/auth/test_groups.py b/tests/unit/auth/test_groups.py index e1b44cc..5a27c71 100644 --- a/tests/unit/auth/test_groups.py +++ b/tests/unit/auth/test_groups.py @@ -117,4 +117,7 @@ def test_user_group(test_users_in_group, user, expected): Nothing """ conn, _group, _users = test_users_in_group - assert user_group(conn, user).maybe(Nothing, lambda val: val) == expected + with db.cursor(conn) as cursor: + assert ( + user_group(cursor, user).maybe(Nothing, lambda val: val) + == expected) |