From 56b54da6fc6e97d5d6dac70f2393dcc98d93991c Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Mon, 12 Dec 2022 13:29:45 +0300 Subject: auth: pass cursor object to `user_group` function --- gn3/auth/authorisation/groups.py | 15 +++++++-------- tests/unit/auth/test_groups.py | 5 ++++- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/gn3/auth/authorisation/groups.py b/gn3/auth/authorisation/groups.py index cda11b3..dbc9f7d 100644 --- a/gn3/auth/authorisation/groups.py +++ b/gn3/auth/authorisation/groups.py @@ -108,15 +108,14 @@ def authenticated_user_group(conn) -> Maybe: return Nothing -def user_group(conn: db.DbConnection, user: User) -> Maybe: +def user_group(cursor: db.DbCursor, user: User) -> Maybe: """Returns the given user's group""" - with db.cursor(conn) as cursor: - cursor.execute( - ("SELECT groups.group_id, groups.group_name FROM group_users " - "INNER JOIN groups ON group_users.group_id=groups.group_id " - "WHERE group_users.user_id = ?"), - (str(user.user_id),)) - groups = tuple(Group(UUID(row[0]), row[1]) for row in cursor.fetchall()) + cursor.execute( + ("SELECT groups.group_id, groups.group_name FROM group_users " + "INNER JOIN groups ON group_users.group_id=groups.group_id " + "WHERE group_users.user_id = ?"), + (str(user.user_id),)) + groups = tuple(Group(UUID(row[0]), row[1]) for row in cursor.fetchall()) if len(groups) > 1: raise MembershipError(user, groups) diff --git a/tests/unit/auth/test_groups.py b/tests/unit/auth/test_groups.py index e1b44cc..5a27c71 100644 --- a/tests/unit/auth/test_groups.py +++ b/tests/unit/auth/test_groups.py @@ -117,4 +117,7 @@ def test_user_group(test_users_in_group, user, expected): Nothing """ conn, _group, _users = test_users_in_group - assert user_group(conn, user).maybe(Nothing, lambda val: val) == expected + with db.cursor(conn) as cursor: + assert ( + user_group(cursor, user).maybe(Nothing, lambda val: val) + == expected) -- cgit v1.2.3