diff options
author | Frederick Muriuki Muriithi | 2023-08-04 10:10:28 +0300 |
---|---|---|
committer | Frederick Muriuki Muriithi | 2023-08-04 10:20:09 +0300 |
commit | 8b7c598407a5fea9a3d78473e72df87606998cd4 (patch) | |
tree | 8526433a17eca6b511feb082a0574f9b15cb9469 /gn_auth/auth/authentication/users.py | |
parent | f7fcbbcc014686ac597b783a8dcb38b43024b9d6 (diff) | |
download | gn-auth-8b7c598407a5fea9a3d78473e72df87606998cd4.tar.gz |
Copy over files from GN3 repository.
Diffstat (limited to 'gn_auth/auth/authentication/users.py')
-rw-r--r-- | gn_auth/auth/authentication/users.py | 128 |
1 files changed, 128 insertions, 0 deletions
diff --git a/gn_auth/auth/authentication/users.py b/gn_auth/auth/authentication/users.py new file mode 100644 index 0000000..0e72ed2 --- /dev/null +++ b/gn_auth/auth/authentication/users.py @@ -0,0 +1,128 @@ +"""User-specific code and data structures.""" +from uuid import UUID, uuid4 +from typing import Any, Tuple, NamedTuple + +from argon2 import PasswordHasher +from argon2.exceptions import VerifyMismatchError + +from gn3.auth import db +from gn3.auth.authorisation.errors import NotFoundError + +class User(NamedTuple): + """Class representing a user.""" + user_id: UUID + email: str + name: str + + def get_user_id(self): + """Return the user's UUID. Mostly for use with Authlib.""" + return self.user_id + + def dictify(self) -> dict[str, Any]: + """Return a dict representation of `User` objects.""" + return {"user_id": self.user_id, "email": self.email, "name": self.name} + +DUMMY_USER = User(user_id=UUID("a391cf60-e8b7-4294-bd22-ddbbda4b3530"), + email="gn3@dummy.user", + name="Dummy user to use as placeholder") + +def user_by_email(conn: db.DbConnection, email: str) -> User: + """Retrieve user from database by their email address""" + with db.cursor(conn) as cursor: + cursor.execute("SELECT * FROM users WHERE email=?", (email,)) + row = cursor.fetchone() + + if row: + return User(UUID(row["user_id"]), row["email"], row["name"]) + + raise NotFoundError(f"Could not find user with email {email}") + +def user_by_id(conn: db.DbConnection, user_id: UUID) -> User: + """Retrieve user from database by their user id""" + with db.cursor(conn) as cursor: + cursor.execute("SELECT * FROM users WHERE user_id=?", (str(user_id),)) + row = cursor.fetchone() + + if row: + return User(UUID(row["user_id"]), row["email"], row["name"]) + + raise NotFoundError(f"Could not find user with ID {user_id}") + +def same_password(password: str, hashed: str) -> bool: + """Check that `raw_password` is hashed to `hash`""" + try: + return hasher().verify(hashed, password) + except VerifyMismatchError as _vme: + return False + +def valid_login(conn: db.DbConnection, user: User, password: str) -> bool: + """Check the validity of the provided credentials for login.""" + with db.cursor(conn) as cursor: + cursor.execute( + ("SELECT * FROM users LEFT JOIN user_credentials " + "ON users.user_id=user_credentials.user_id " + "WHERE users.user_id=?"), + (str(user.user_id),)) + row = cursor.fetchone() + + if row is None: + return False + + return same_password(password, row["password"]) + +def save_user(cursor: db.DbCursor, email: str, name: str) -> User: + """ + Create and persist a user. + + The user creation could be done during a transaction, therefore the function + takes a cursor object rather than a connection. + + The newly created and persisted user is then returned. + """ + user_id = uuid4() + cursor.execute("INSERT INTO users VALUES (?, ?, ?)", + (str(user_id), email, name)) + return User(user_id, email, name) + +def hasher(): + """Retrieve PasswordHasher object""" + # TODO: Maybe tune the parameters here... + # Tuneable Parameters: + # - time_cost (default: 2) + # - memory_cost (default: 102400) + # - parallelism (default: 8) + # - hash_len (default: 16) + # - salt_len (default: 16) + # - encoding (default: 'utf-8') + # - type (default: <Type.ID: 2>) + return PasswordHasher() + +def hash_password(password): + """Hash the password.""" + return hasher().hash(password) + +def set_user_password( + cursor: db.DbCursor, user: User, password: str) -> Tuple[User, bytes]: + """Set the given user's password in the database.""" + hashed_password = hash_password(password) + cursor.execute( + ("INSERT INTO user_credentials VALUES (:user_id, :hash) " + "ON CONFLICT (user_id) DO UPDATE SET password=:hash"), + {"user_id": str(user.user_id), "hash": hashed_password}) + return user, hashed_password + +def users(conn: db.DbConnection, + ids: tuple[UUID, ...] = tuple()) -> tuple[User, ...]: + """ + Fetch all users with the given `ids`. If `ids` is empty, return ALL users. + """ + params = ", ".join(["?"] * len(ids)) + with db.cursor(conn) as cursor: + query = "SELECT * FROM users" + ( + f" WHERE user_id IN ({params})" + if len(ids) > 0 else "") + print(query) + cursor.execute(query, tuple(str(the_id) for the_id in ids)) + return tuple(User(UUID(row["user_id"]), row["email"], row["name"]) + for row in cursor.fetchall()) + return tuple() |