aboutsummaryrefslogtreecommitdiff
#!/usr/bin/env python3

"""
WARNING
    This script is to be run only one time to update the schema for the
    `CaseAttribute` and the `CaseAttributeXRefNew` tables.
    Running the script more than once has no useful purpose, and will just
    litter your database schema with `BACKUP_CaseAttribute*` tables.

DESCRIPTION
    The script makes the following schema updates

    For the `CaseAttribute` table:
    * Adds `InbredSetId` column: links each case attribute to a population
    * Rename `Id` to `CaseAttributeId`: Makes it explicit what the ID is for,
      and helps simplify queries with joins against this table

    For the `CaseAttributeXRefNew` table:
    * Reorganise order of columns.

    For out of date databases (e.g. small db), the script will also:
    * Update the character set for the `InbredSet` and `Strain` tables to
      utf8mb4.
    * Change the database engine for the `InbredSet` and `Strain` tables to
      InnoDB.

TABLE BACKUPS
    The script will backup the `CaseAttribute` and `CaseAttributeXRefNew` tables.
    The backup table names take the form:

      BACKUP_<tablename>_<timestr>

    Where <tablename> is one of `CaseAttribute` or `CaseAttributeXRefNew` and
    <timestr> is a string indicating the date and time the script was run (in
    the local timezone of the system the script is run on).

USAGE
    python3 update-case-attribute-tables-20230818 SQL_URI

    The `SQL_URI` argument is mandatory and is of the form:

      mysql://<username>:<password>@<host>[:<port>]/<database>
"""
import os
import sys
import time
import random
import traceback
from datetime import datetime
from urllib.parse import urlparse

import click
import MySQLdb as mdb
from MySQLdb.cursors import DictCursor

# from gn3.db_utils import database_connection

def convert_to_innodb(cursor, table):
    """Convert `table` to the utf8mb4 character set and the InnoDB engine.

    The current engine and character set are read from the last line of the
    `SHOW CREATE TABLE` output, which is where the `ENGINE=...` and
    `CHARSET=...` table options are expected to appear.

    Args:
        cursor: A dict-style database cursor; the `SHOW CREATE TABLE` row is
            read through its "Create Table" key.
        table: Name of the table to convert. It is interpolated directly into
            the SQL statements, so it must be a trusted identifier.
    """
    cursor.execute(f"SHOW CREATE TABLE {table};")
    create_statement = cursor.fetchone()["Create Table"]
    # Split on the first "=" only, so an option value that itself contains
    # "=" cannot break the key/value unpacking.
    table_options = dict(
        option.split("=", 1)
        for option in create_statement.split("\n")[-1].split(" ")
        if option.startswith(("ENGINE=", "CHARSET=")))
    if table_options.get("CHARSET") != "utf8mb4":
        cursor.execute(
            f"ALTER TABLE {table} CONVERT TO CHARACTER SET utf8mb4")
    # `.get` mirrors the CHARSET lookup above: a missing ENGINE option means
    # there is nothing to convert, rather than raising a KeyError.
    if table_options.get("ENGINE") == "MyISAM":
        cursor.execute(f"ALTER TABLE {table} ENGINE=InnoDB")


def table_exists(cursor, table: str) -> bool:
    """Tell whether the connected database has a table named `table`.

    Runs `SHOW TABLES` and compares `table` against the first value of every
    returned row; rows are expected to be mappings (dict-style cursor).
    """
    cursor.execute("SHOW TABLES")
    known_tables = [tuple(row.values())[0] for row in cursor.fetchall()]
    return table in known_tables

def table_has_field(cursor, table: str, field: str) -> bool:
    """Tell whether `table` has a column called `field`.

    Runs `DESC <table>` and looks for `field` among the "Field" values of the
    returned rows; rows are expected to be mappings (dict-style cursor).
    `table` is interpolated directly into the SQL.
    """
    cursor.execute(f"DESC {table}")
    column_names = [column["Field"] for column in cursor.fetchall()]
    return field in column_names

def cleanup_inbred_set_schema(cursor):
    """
    Clean up the InbredSet schema to prevent issues with ForeignKey constraints.

    Every `InbredSet` row matched by the `InbredSetId IS NULL OR
    InbredSetId = ''` filter gets its `InbredSetId` set to the row's own
    `Id`; afterwards the `InbredSetId` column is altered to be `NOT NULL`.

    Args:
        cursor: A database cursor. Rows may be mappings (dict-style cursor,
            as used by this script) or plain sequences (default cursor).
    """
    cursor.execute("SELECT Id, InbredSetId FROM InbredSet "
                   "WHERE InbredSetId IS NULL OR InbredSetId = ''")
    # A dict-style cursor yields mappings keyed by column name, so indexing
    # with `row[0]` would raise a KeyError there; handle both row shapes.
    row_ids = tuple(
        row["Id"] if isinstance(row, dict) else row[0]
        for row in cursor.fetchall())
    fixed_nulls = tuple(
        {"Id": row_id, "InbredSetId": row_id} for row_id in row_ids)
    if fixed_nulls:
        cursor.executemany(
            "UPDATE InbredSet SET InbredSetId=%(InbredSetId)s "
            "WHERE Id=%(Id)s",
            fixed_nulls)

    cursor.execute("""ALTER TABLE InbredSet
    CHANGE COLUMN InbredSetId InbredSetId INT(5) UNSIGNED NOT NULL""")

def create_temp_case_attributes_table(cursor):
    """Create the `CaseAttributeTemp` table unless it is deemed unnecessary.

    The table is created (with `IF NOT EXISTS`) when there is no
    `CaseAttribute` table at all, or when the existing `CaseAttribute` table
    has no column named "InbredSet".
    """
    # NOTE(review): the column looked up here is "InbredSet", whereas the
    # table created below names its column `InbredSetId` -- confirm which
    # name was intended before relying on this check to skip the creation.
    already_migrated = (
        table_exists(cursor, "CaseAttribute")
        and table_has_field(cursor, "CaseAttribute", "InbredSet"))
    if already_migrated:
        return
    cursor.execute(
        """CREATE TABLE IF NOT EXISTS CaseAttributeTemp(
              InbredSetId INT(5) UNSIGNED NOT NULL,
              CaseAttributeId INT(5) UNSIGNED NOT NULL,
              Name VARCHAR(30) NOT NULL,
              Description VARCHAR(250) NOT NULL,
              -- FOREIGN KEY(InbredSetId) REFERENCES InbredSet(InbredSetId)
              --   ON DELETE RESTRICT ON UPDATE CASCADE,
              PRIMARY KEY(InbredSetId, CaseAttributeId)
            ) ENGINE=InnoDB CHARSET=utf8mb4;""")

def create_temp_case_attributes_xref_table(cursor):
    """Create the `CaseAttributeXRefNewTemp` table.

    Nothing is done unless the `CaseAttributeTemp` table already exists.
    """
    if not table_exists(cursor, "CaseAttributeTemp"):
        return
    ddl = """CREATE TABLE IF NOT EXISTS CaseAttributeXRefNewTemp(
              InbredSetId INT(5) UNSIGNED NOT NULL,
              StrainId INT(20) UNSIGNED NOT NULL,
              CaseAttributeId INT(5) UNSIGNED NOT NULL,
              Value VARCHAR(100) NOT NULL,
              -- FOREIGN KEY(InbredSetId) REFERENCES InbredSet(InbredSetId)
              --   ON UPDATE CASCADE ON DELETE RESTRICT,
              -- FOREIGN KEY(StrainId) REFERENCES Strain(Id)
              --   ON UPDATE CASCADE ON DELETE RESTRICT,
              -- FOREIGN KEY (CaseAttributeId) REFERENCES CaseAttribute(CaseAttributeId)
              --   ON UPDATE CASCADE ON DELETE RESTRICT,
              PRIMARY KEY(InbredSetId, StrainId, CaseAttributeId)
            ) ENGINE=InnoDB CHARSET=utf8mb4;"""
    cursor.execute(ddl)

def fetch_case_attribute_data(cursor, limit: int = 1000):
    """Fetch case attribute data in batches of at most `limit` rows.

    Joins `CaseAttribute`, `CaseAttributeXRefNew`, `StrainXRef` and
    `InbredSet`, skipping cross-reference rows whose value is 'x' or NULL.

    Args:
        cursor: A database cursor used to run the paginated query.
        limit: Maximum number of rows to fetch per batch.

    Yields:
        The result of `cursor.fetchall()` for each non-empty page, in order.
        Iteration stops at the first empty page.
    """
    # Paging with LIMIT/OFFSET is only reliable when the row order is fixed:
    # without an ORDER BY the server may return rows in a different order for
    # each page, duplicating some rows and skipping others. The ordering
    # includes `Value` so that any remaining ties are rows that are identical
    # in every selected column.
    query = (
        "SELECT "
        "caxrn.StrainId, caxrn.CaseAttributeId, caxrn.Value, "
        "ca.Name AS CaseAttributeName, "
        "ca.Description AS CaseAttributeDescription, iset.InbredSetId "
        "FROM "
        "CaseAttribute AS ca INNER JOIN CaseAttributeXRefNew AS caxrn "
        "ON ca.Id=caxrn.CaseAttributeId "
        "INNER JOIN "
        "StrainXRef AS sxr "
        "ON caxrn.StrainId=sxr.StrainId "
        "INNER JOIN "
        "InbredSet AS iset "
        "ON sxr.InbredSetId=iset.InbredSetId "
        "WHERE "
        "caxrn.value != 'x' "
        "AND caxrn.value IS NOT NULL "
        "ORDER BY "
        "iset.InbredSetId, caxrn.StrainId, caxrn.CaseAttributeId, "
        "caxrn.Value "
        "LIMIT %s OFFSET %s")
    page_size = int(limit)
    offset = 0
    while True:
        cursor.execute(query, (page_size, offset))
        results = cursor.fetchall()
        if not results:
            break
        yield results
        offset += len(results)

def copy_data(cursor):
    """Copy data from existing tables into new temp tables.

    Does nothing unless both `CaseAttributeTemp` and
    `CaseAttributeXRefNewTemp` exist. For every batch produced by
    `fetch_case_attribute_data`, the distinct case attributes are upserted
    into `CaseAttributeTemp` and every row of the batch is upserted into
    `CaseAttributeXRefNewTemp`. The function then sleeps for a random 5 to 20
    seconds before handling the next batch.
    """
    if not (table_exists(cursor, "CaseAttributeTemp")
            and table_exists(cursor, "CaseAttributeXRefNewTemp")):
        return
    for bunch_of_data in fetch_case_attribute_data(cursor):
        # De-duplicate on the primary key of `CaseAttributeTemp`. Including
        # the per-strain fields (StrainId, Value) in the key would make
        # almost every row unique and send one redundant upsert per
        # cross-reference row.
        case_attributes = {}
        for item in bunch_of_data:
            case_attributes.setdefault(
                (item["InbredSetId"], item["CaseAttributeId"]),
                {
                    "InbredSetId": item["InbredSetId"],
                    "CaseAttributeId": item["CaseAttributeId"],
                    "Name": item["CaseAttributeName"],
                    # The Description column is NOT NULL: fall back to the
                    # name when there is no description.
                    "Description": (item["CaseAttributeDescription"]
                                    or item["CaseAttributeName"]),
                })
        cursor.executemany(
            "INSERT INTO "
            "CaseAttributeTemp("
            "InbredSetId, CaseAttributeId, Name, Description) "
            "VALUES("
            "%(InbredSetId)s, %(CaseAttributeId)s, %(Name)s, "
            "%(Description)s) "
            "ON DUPLICATE KEY UPDATE Name=VALUES(Name)",
            tuple(case_attributes.values()))
        cursor.executemany(
            "INSERT INTO "
            "CaseAttributeXRefNewTemp("
            "InbredSetId, StrainId, CaseAttributeId, Value) "
            "VALUES("
            "%(InbredSetId)s, %(StrainId)s, %(CaseAttributeId)s, %(Value)s) "
            "ON DUPLICATE KEY UPDATE `Value`=VALUES(`Value`)",
            bunch_of_data)
        # Pause between batches (random 5-20 seconds), as the original did.
        time.sleep(random.randint(5, 20))

def rename_table(cursor, table, newtable):
    """Give `table` the new name `newtable`.

    Both names are interpolated directly into the `ALTER TABLE` statement.
    """
    statement = f"ALTER TABLE {table} RENAME TO {newtable}"
    cursor.execute(statement)

def parse_db_url(sql_uri: str) -> tuple:
    """Split `sql_uri` into its connection components.

    Returns a 5-tuple `(host, user, password, database, port)`. The database
    name is the URI path without its first character; components that are
    absent from the URI come back as `None` (as reported by `urlparse`).
    """
    uri = urlparse(sql_uri)
    database = uri.path[1:]
    return (uri.hostname, uri.username, uri.password, database, uri.port)

@click.command(help="Update DB schema for Case-Attributes")
@click.argument("sql_uri")
def main(sql_uri: str) -> None:
    """Run the case-attribute schema update against the database at `sql_uri`.

    Steps, in order: convert `InbredSet` and `Strain` to InnoDB/utf8mb4,
    clean up `InbredSet.InbredSetId`, create and fill the temp tables, rename
    the existing tables to timestamped `BACKUP_*` tables, and rename the temp
    tables into their place.

    On any failure the traceback is printed to stderr, the transaction is
    rolled back, and the process exits with status 1 so that callers can
    detect the failure.
    """
    # for innodb: `SET autocommit=0` to prevent releasing locks immediately
    # after next commit.
    # See https://dev.mysql.com/doc/refman/8.0/en/lock-tables.html
    host, user, passwd, db_name, port = parse_db_url(sql_uri)
    conn = mdb.connect(
        db=db_name, user=user, passwd=passwd or '', host=host, port=port or 3306)
    try:
        cursor = conn.cursor(cursorclass=DictCursor)
        convert_to_innodb(cursor, "InbredSet")
        convert_to_innodb(cursor, "Strain")
        cleanup_inbred_set_schema(cursor)
        create_temp_case_attributes_table(cursor)
        create_temp_case_attributes_xref_table(cursor)
        copy_data(cursor)
        # e.g. 2023-08-18T12:34:56.123456 -> 20230818T12_34_56__123456
        timestr = datetime.now().isoformat().replace(
            "-", "").replace(":", "_").replace(".", "__")
        rename_table(cursor, "CaseAttribute", f"BACKUP_CaseAttribute_{timestr}")
        rename_table(cursor, "CaseAttributeXRefNew", f"BACKUP_CaseAttributeXRefNew_{timestr}")
        rename_table(cursor, "CaseAttributeTemp", "CaseAttribute")
        rename_table(cursor, "CaseAttributeXRefNewTemp", "CaseAttributeXRefNew")
        conn.commit()
    except Exception:  # pylint: disable=broad-except
        print(traceback.format_exc(), file=sys.stderr)
        # NOTE: in MySQL, DDL statements (CREATE/ALTER TABLE) commit
        # implicitly, so this rollback can only undo pending row changes; it
        # does not revert schema changes that were already executed.
        conn.rollback()
        # Report the failure through the exit status rather than exiting 0.
        sys.exit(1)
    finally:
        conn.close()

if __name__ == "__main__":
    # `main` is a click command: click supplies the `sql_uri` argument from
    # the command line, hence the call without arguments and the pylint
    # suppression below.
    # pylint: disable=no-value-for-parameter
    main()