aboutsummaryrefslogtreecommitdiff
path: root/gn3/auth/authentication/users.py
blob: 54838a3fdc25752279fcbbca45127fca036ae73f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""User-specific code and data structures."""
from uuid import UUID, uuid4
from typing import Any, Tuple, NamedTuple

import bcrypt

from gn3.auth import db
from gn3.auth.authorisation.errors import NotFoundError

class User(NamedTuple):
    """Immutable record describing a single user."""
    user_id: UUID  # the user's unique identifier
    email: str     # the user's email address
    name: str      # the user's display name

    def get_user_id(self):
        """Give back this user's UUID; the accessor exists mainly for Authlib."""
        return self.user_id

    def dictify(self) -> dict[str, Any]:
        """Convert this `User` into a plain dictionary keyed by field name."""
        return dict(user_id=self.user_id, email=self.email, name=self.name)

# Fixed, well-known `User` instance; its name field marks it as a placeholder.
DUMMY_USER = User(
    UUID("a391cf60-e8b7-4294-bd22-ddbbda4b3530"),
    "gn3@dummy.user",
    "Dummy user to use as placeholder")

def user_by_email(conn: db.DbConnection, email: str) -> User:
    """Look up the user whose email address equals `email`.

    Raises `NotFoundError` when no row in `users` matches.
    """
    with db.cursor(conn) as cursor:
        cursor.execute("SELECT * FROM users WHERE email=?", (email,))
        found = cursor.fetchone()

    if not found:
        raise NotFoundError(f"Could not find user with email {email}")

    return User(UUID(found["user_id"]), found["email"], found["name"])

def user_by_id(conn: db.DbConnection, user_id: UUID) -> User:
    """Look up the user identified by `user_id`.

    The UUID is matched against its string form in the `users` table.
    Raises `NotFoundError` when no row matches.
    """
    with db.cursor(conn) as cursor:
        cursor.execute("SELECT * FROM users WHERE user_id=?", (str(user_id),))
        found = cursor.fetchone()

    if not found:
        raise NotFoundError(f"Could not find user with ID {user_id}")

    return User(UUID(found["user_id"]), found["email"], found["name"])

def valid_login(conn: db.DbConnection, user: User, password: str) -> bool:
    """Check the validity of the provided credentials for login.

    Returns False when the user does not exist or has no stored password hash.
    Otherwise, returns whether `password` matches the stored bcrypt hash.
    """
    with db.cursor(conn) as cursor:
        cursor.execute(
            ("SELECT * FROM users LEFT JOIN user_credentials "
             "ON users.user_id=user_credentials.user_id "
             "WHERE users.user_id=?"),
            (str(user.user_id),))
        row = cursor.fetchone()

    # The LEFT JOIN still returns a row for a user without credentials, but
    # with a NULL password; bcrypt.checkpw would raise TypeError on None, so
    # such a user simply cannot log in.
    if row is None or row["password"] is None:
        return False

    return bcrypt.checkpw(password.encode("utf-8"), row["password"])

def save_user(cursor: db.DbCursor, email: str, name: str) -> User:
    """
    Insert a brand-new user row and return the matching `User`.

    A cursor (rather than a connection) is accepted so that the insertion can
    take part in a transaction managed by the caller.

    The returned user carries a freshly generated random UUID.
    """
    new_user = User(uuid4(), email, name)
    cursor.execute("INSERT INTO users VALUES (?, ?, ?)",
                   (str(new_user.user_id), new_user.email, new_user.name))
    return new_user

def set_user_password(
        cursor: db.DbCursor, user: User, password: str) -> Tuple[User, bytes]:
    """Store a salted bcrypt hash of `password` as `user`'s credentials.

    A new `user_credentials` row is inserted, or the existing row's password is
    overwritten if one already exists for the user. Returns the user together
    with the generated hash.
    """
    salt = bcrypt.gensalt()
    hashed = bcrypt.hashpw(password.encode("utf8"), salt)
    params = {"user_id": str(user.user_id), "hash": hashed}
    cursor.execute(
        ("INSERT INTO user_credentials VALUES (:user_id, :hash) "
         "ON CONFLICT (user_id) DO UPDATE SET password=:hash"),
        params)
    return (user, hashed)