aboutsummaryrefslogtreecommitdiff
"""Major function for handling admin users."""
import warnings

from gn_auth.auth.db import sqlite3 as db
from gn_auth.auth.authentication.users import User
from gn_auth.auth.authorisation.roles.models import Role, db_rows_to_roles


def sysadmin_role(conn: db.DbConnection) -> Role:
    """Fetch the `system-administrator` role details, including privileges.

    Args:
        conn: Open database connection to query.

    Returns:
        The single `Role` named `system-administrator`, as built by
        `db_rows_to_roles` from the joined role/privilege rows.

    Raises:
        AssertionError: if the query does not yield exactly one role. This is
            raised explicitly rather than with an `assert` statement, so the
            check still runs under `python -O`.
    """
    with db.cursor(conn) as cursor:
        cursor.execute(
            "SELECT roles.*, privileges.* "
            "FROM roles INNER JOIN role_privileges "
            "ON roles.role_id=role_privileges.role_id "
            "INNER JOIN privileges "
            "ON role_privileges.privilege_id=privileges.privilege_id "
            "WHERE role_name='system-administrator'")
        results = db_rows_to_roles(cursor.fetchall())

    # Because of the INNER JOINs, an empty result means either that the role
    # is missing or that it has no privileges attached.
    if not results:
        raise AssertionError(
            "No 'system-administrator' role with privileges was found.")
    if len(results) > 1:
        raise AssertionError(
            "There should only ever be one 'system-administrator' role.")
    return results[0]


def grant_sysadmin_role(cursor: db.DbCursor, user: User) -> User:
    """Grant the `system-administrator` role to `user` on every resource.

    One `user_roles` row is inserted for each row of the `resources` table,
    pairing `user` with the `system-administrator` role on that resource.
    This function does not commit; it only issues statements on `cursor`.

    Args:
        cursor: Open database cursor whose rows support key access
            (as the `admin_role["role_id"]` lookup requires).
        user: The user to receive the role.

    Returns:
        The same `user` object that was passed in.
    """
    cursor.execute(
            "SELECT * FROM roles WHERE role_name='system-administrator'")
    # NOTE(review): `fetchone()` returns None if the role is missing, which
    # would surface below as a TypeError on subscripting.
    admin_role = cursor.fetchone()
    cursor.execute("SELECT resources.resource_id FROM resources")
    resource_rows = cursor.fetchall()
    cursor.executemany(
        "INSERT INTO user_roles VALUES (:user_id, :role_id, :resource_id)",
        tuple({
            "user_id": str(user.user_id),
            "role_id": admin_role["role_id"],
            # Bind the column value, not the whole row: sqlite3 cannot bind a
            # Row object as a parameter.
            "resource_id": row["resource_id"]
        } for row in resource_rows))
    return user


def make_sys_admin(cursor: db.DbCursor, user: User) -> User:
    """Make a given user into a system admin (deprecated).

    Deprecated wrapper around `grant_sysadmin_role`. It emits a
    `DeprecationWarning` and then delegates.

    Args:
        cursor: Open database cursor, passed through unchanged.
        user: The user to receive the `system-administrator` role.

    Returns:
        Whatever `grant_sysadmin_role` returns for the same arguments.
    """
    warnings.warn(
        DeprecationWarning(
            f"The function `{__name__}.make_sys_admin` will be removed soon"),
        # stacklevel=2 attributes the warning to the caller's line, which is
        # the code that must change. It also lets the default
        # DeprecationWarning filters show the warning when called from __main__.
        stacklevel=2)
    return grant_sysadmin_role(cursor, user)


def revoke_sysadmin_role(conn: db.DbConnection, user: User):
    """Remove the `system-administrator` role from `user`.

    The DELETE matches only on user and role, so the grant is removed for
    every resource it was held on.

    Args:
        conn: Open database connection.
        user: The user whose administrator grants are removed.
    """
    with db.cursor(conn) as cur:
        # Same evaluation order as before: user id, then the role lookup.
        target_user = str(user.user_id)
        admin_role_id = str(sysadmin_role(conn).role_id)
        cur.execute("DELETE FROM user_roles WHERE user_id=? AND role_id=?",
                    (target_user, admin_role_id))