1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
|
"""User-specific code and data structures."""
from uuid import UUID, uuid4
from typing import Any, Tuple, NamedTuple
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
from gn_auth.auth import db
from gn_auth.auth.authorisation.errors import NotFoundError
class User(NamedTuple):
    """Immutable record describing a single user account."""
    user_id: UUID
    email: str
    name: str

    def get_user_id(self):
        """Give back this user's UUID; mostly needed by Authlib."""
        return self.user_id

    def dictify(self) -> dict[str, Any]:
        """Build a plain `dict` mapping each field name to its value."""
        # `_asdict` preserves field order (user_id, email, name); wrapping in
        # `dict` guarantees a plain dict on every Python version.
        return dict(self._asdict())
# Placeholder `User`; the hard-coded UUID keeps its identity fixed across runs.
DUMMY_USER = User(
    UUID("a391cf60-e8b7-4294-bd22-ddbbda4b3530"),
    "gn3@dummy.user",
    "Dummy user to use as placeholder",
)
def user_by_email(conn: db.DbConnection, email: str) -> User:
    """Look up a single user by email address.

    Raises `NotFoundError` when no row in `users` has a matching email.
    """
    with db.cursor(conn) as cursor:
        cursor.execute("SELECT * FROM users WHERE email=?", (email,))
        found = cursor.fetchone()

    if not found:
        raise NotFoundError(f"Could not find user with email {email}")
    return User(UUID(found["user_id"]), found["email"], found["name"])
def user_by_id(conn: db.DbConnection, user_id: UUID) -> User:
    """Look up a single user by their UUID.

    The UUID is passed to the query as a string. Raises `NotFoundError` when
    no matching row exists.
    """
    with db.cursor(conn) as cursor:
        cursor.execute("SELECT * FROM users WHERE user_id=?", (str(user_id),))
        record = cursor.fetchone()

    if record:
        return User(
            user_id=UUID(record["user_id"]),
            email=record["email"],
            name=record["name"])
    raise NotFoundError(f"Could not find user with ID {user_id}")
def same_password(password: str, hashed: str) -> bool:
    """Check whether the plain-text `password` matches the argon2 hash `hashed`.

    Returns True on a match and False on a mismatch
    (`VerifyMismatchError`). Any other exception raised by `verify`
    propagates to the caller.
    """
    try:
        return hasher().verify(hashed, password)
    except VerifyMismatchError:
        return False
def valid_login(conn: db.DbConnection, user: User, password: str) -> bool:
    """Check the validity of the provided credentials for login.

    Returns False when no row is found for `user`, or when the user has no
    stored password. In that case the LEFT JOIN yields a NULL `password`
    column. Otherwise the result of `same_password` is returned.
    """
    with db.cursor(conn) as cursor:
        cursor.execute(
            ("SELECT * FROM users LEFT JOIN user_credentials "
             "ON users.user_id=user_credentials.user_id "
             "WHERE users.user_id=?"),
            (str(user.user_id),))
        row = cursor.fetchone()

    if row is None:
        return False

    stored_hash = row["password"]
    if stored_hash is None:
        # No credentials row for this user: there is nothing to verify
        # against, so reject the login rather than handing None to argon2.
        return False
    return same_password(password, stored_hash)
def save_user(cursor: db.DbCursor, email: str, name: str) -> User:
    """
    Create and persist a user.

    The user creation could be done during a transaction, therefore the function
    takes a cursor object rather than a connection.

    A new random UUID (`uuid4`) is generated as the user's ID. The newly
    created and persisted user is then returned.
    """
    user_id = uuid4()
    # Name the columns explicitly so the insert does not depend on the
    # physical column order of the `users` table.
    cursor.execute("INSERT INTO users(user_id, email, name) VALUES (?, ?, ?)",
                   (str(user_id), email, name))
    return User(user_id, email, name)
def hasher():
    """Build an argon2 `PasswordHasher` using the library's default settings."""
    # TODO: Consider tuning the hashing parameters instead of relying on the
    # library defaults. `PasswordHasher` accepts: time_cost, memory_cost,
    # parallelism, hash_len, salt_len, encoding and type.
    # NOTE(review): the default values differ between argon2-cffi releases.
    # Older releases used time_cost=2, memory_cost=102400, parallelism=8,
    # hash_len=16 and salt_len=16; newer ones are believed to follow RFC 9106's
    # low-memory profile. Confirm against the installed version before tuning.
    return PasswordHasher()
def hash_password(password: str) -> str:
    """Hash `password` with the argon2 hasher from `hasher()`.

    Returns the encoded hash string produced by `PasswordHasher.hash`.
    """
    return hasher().hash(password)
def set_user_password(
        cursor: db.DbCursor, user: User, password: str) -> Tuple[User, str]:
    """Set the given user's password in the database.

    Hashes `password` and upserts it into `user_credentials`. A new row is
    inserted, or the existing row's `password` is replaced when one already
    exists for `user`. Returns `user` together with the new hash string.
    """
    hashed_password = hash_password(password)
    # Explicit column list: don't rely on the table's physical column order.
    cursor.execute(
        ("INSERT INTO user_credentials(user_id, password) "
         "VALUES (:user_id, :hash) "
         "ON CONFLICT (user_id) DO UPDATE SET password=:hash"),
        {"user_id": str(user.user_id), "hash": hashed_password})
    return user, hashed_password
def users(conn: db.DbConnection,
          ids: tuple[UUID, ...] = tuple()) -> tuple[User, ...]:
    """
    Fetch all users with the given `ids`. If `ids` is empty, return ALL users.
    """
    query = "SELECT * FROM users"
    if len(ids) > 0:
        # One "?" placeholder per id keeps the values parameterised.
        # NOTE(review): very large `ids` tuples may exceed SQLite's host
        # parameter limit; confirm expected caller sizes.
        placeholders = ", ".join(["?"] * len(ids))
        query += f" WHERE user_id IN ({placeholders})"

    with db.cursor(conn) as cursor:
        cursor.execute(query, tuple(str(the_id) for the_id in ids))
        return tuple(User(UUID(row["user_id"]), row["email"], row["name"])
                     for row in cursor.fetchall())

    # NOTE(review): only reachable if `db.cursor` suppresses an exception
    # raised inside the block; kept as the original fallback.
    return tuple()
|