import sys
import uuid
import logging
import argparse
from pathlib import Path
from typing import Iterator
from functools import reduce

from MySQLdb.cursors import DictCursor

from gn_libs import jobs, mysqldb, sqlite3

from uploader.phenotypes.models import phenotypes_data_by_ids
from uploader.phenotypes.misc import phenotypes_data_differences
from uploader.phenotypes.views import BULK_EDIT_COMMON_FIELDNAMES

import uploader.publications.pubmed as pmed
from uploader.publications.misc import publications_differences
from uploader.publications.models import (
    update_publications, fetch_phenotype_publications)

logging.basicConfig(
    format="%(asctime)s — %(filename)s:%(lineno)s — %(levelname)s: %(message)s")
logger = logging.getLogger(__name__)


def check_ids(conn, ids: tuple[tuple[int, int], ...]) -> bool:
    """Verify that all the `UniqueIdentifier` values are valid."""
    logger.info("Checking the 'UniqueIdentifier' values.")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        paramstr = ",".join(["(%s, %s)"] * len(ids))
        cursor.execute(
            "SELECT PhenotypeId AS phenotype_id, Id AS xref_id "
            "FROM PublishXRef "
            f"WHERE (PhenotypeId, Id) IN ({paramstr})",
            tuple(item for row in ids for item in row))
        mysqldb.debug_query(cursor, logger)
        found = tuple((row["phenotype_id"], row["xref_id"])
                      for row in cursor.fetchall())

    not_found = tuple(item for item in ids if item not in found)
    if len(not_found) == 0:
        logger.info("All 'UniqueIdentifier' are valid.")
        return True

    for item in not_found:
        logger.error(f"Invalid 'UniqueIdentifier' value: phId:%s::xrId:%s", item[0], item[1])

    return False


def check_for_mandatory_fields(file_contents: tuple[dict, ...]) -> bool:
    """Verify that mandatory fields have values."""
    logger.info("Checking that mandatory fields have values.")
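    # A minimal sketch of the check. `_MANDATORY_FIELDS` below is a
    # hypothetical stand-in; replace it with the actual list of mandatory
    # fields once that is settled.
    _MANDATORY_FIELDS = ("Original_description",)
    missing = tuple(
        (row["phenotype_id"], row["xref_id"], field)
        for row in file_contents
        for field in _MANDATORY_FIELDS
        if row.get(field) is None)
    for phenotype_id, xref_id, field in missing:
        logger.error("Missing mandatory field '%s' for phId:%s::xrId:%s",
                     field, phenotype_id, xref_id)
    return not missing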


def __fetch_phenotypes__(conn, ids: tuple[int, ...]) -> tuple[dict, ...]:
    """Fetch basic (non-numeric) phenotypes data from the database."""
    with conn.cursor(cursorclass=DictCursor) as cursor:
        paramstr = ",".join(["%s"] * len(ids))
        cursor.execute(f"SELECT * FROM Phenotype WHERE Id IN ({paramstr}) "
                       "ORDER BY Id ASC",
                       ids)
        return tuple(dict(row) for row in cursor.fetchall())


def descriptions_differences(file_data, db_data) -> tuple[dict, ...]:
    """Compute differences in the descriptions."""
    logger.info("Computing differences in phenotype descriptions.")
    assert len(file_data) == len(db_data), "The counts of phenotypes differ!"
    description_columns = ("Pre_publication_description",
                           "Post_publication_description",
                           "Original_description",
                           "Pre_publication_abbreviation",
                           "Post_publication_abbreviation")
    diff = tuple()
    for file_row, db_row in zip(file_data, db_data):
        assert file_row["phenotype_id"] == db_row["Id"]
        inner_diff = {
            key: file_row[key]
            for key in description_columns
            if file_row[key] != db_row[key]
        }
        if inner_diff:
            diff = diff + ({
                "phenotype_id": file_row["phenotype_id"],
                **inner_diff
            },)

    return diff


def update_descriptions(conn, desc_diff: tuple[dict, ...]):
    """Update descriptions in the database."""
    logger.info("Updating descriptions")
    # The differences between the database data and the uploaded file are
    # computed by `descriptions_differences(…)` above; only the changed
    # descriptions are written back.
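    # A minimal sketch of the update, assuming every non-"phenotype_id" key
    # in each `desc_diff` row is a column in the `Phenotype` table (which is
    # how `descriptions_differences(…)` builds the rows):
    with conn.cursor(cursorclass=DictCursor) as cursor:
        for row in desc_diff:
            updates = {key: val for key, val in row.items()
                       if key != "phenotype_id"}
            setstr = ", ".join(f"{column}=%s" for column in updates)
            cursor.execute(
                f"UPDATE Phenotype SET {setstr} WHERE Id=%s",
                tuple(updates.values()) + (row["phenotype_id"],))
            mysqldb.debug_query(cursor, logger)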


def link_publications(conn, pub_diff: tuple[dict, ...]):
    """Link phenotypes to relevant publications."""
    logger.info("Linking phenotypes to publications.")
    # Publications whose PubMed_ID did not exist in the database are created
    # by `update_publications(…)` in `run(…)` below; this function only
    # re-points the cross-references.
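    # A minimal sketch, assuming each `pub_diff` row carries "phenotype_id",
    # "xref_id" and "PublicationId" keys, and that `PublishXRef` has a
    # `PublicationId` column; adjust to the actual shape of the rows
    # returned by `publications_differences(…)`:
    with conn.cursor(cursorclass=DictCursor) as cursor:
        for row in pub_diff:
            cursor.execute(
                "UPDATE PublishXRef SET PublicationId=%(PublicationId)s "
                "WHERE PhenotypeId=%(phenotype_id)s AND Id=%(xref_id)s",
                row)
            mysqldb.debug_query(cursor, logger)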


def update_values(conn, data_diff: tuple[dict, ...]):
    """Update the phenotype values."""
    logger.info("Updating phenotypes values.")
    # The differences between the database data and the uploaded file are
    # computed by `phenotypes_data_differences(…)` in `run(…)` below; only
    # the changed data is written back.
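    # A minimal sketch, assuming the numeric values live in a `PublishData`
    # table and that each `data_diff` row provides "DataId", "StrainId" and
    # the new "value" (all hypothetical names; adjust to the actual schema
    # and to the shape of `phenotypes_data_differences(…)` output):
    with conn.cursor(cursorclass=DictCursor) as cursor:
        for row in data_diff:
            cursor.execute(
                "UPDATE PublishData SET value=%(value)s "
                "WHERE Id=%(DataId)s AND StrainId=%(StrainId)s",
                row)
            mysqldb.debug_query(cursor, logger)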


def parse_args():
    """Parse the command-line arguments."""
    parser = argparse.ArgumentParser(
        prog="Phenotypes Bulk-Edit Processor",
        description="Process the bulk-edits to phenotype data and descriptions.")
    parser.add_argument("db_uri", type=str, help="MariaDB/MySQL connection URL")
    parser.add_argument(
        "jobs_db_path", type=Path, help="Path to jobs' SQLite database.")
    parser.add_argument("job_id", type=uuid.UUID, help="ID of the running job")
    parser.add_argument(
        "--log-level",
        type=str,
        help="Determines what is logged out.",
        choices=("debug", "info", "warning", "error", "critical"),
        default="info")
    return parser.parse_args()


def read_file(filepath: Path) -> Iterator[str]:
    """Read the file, one line at a time."""
    with filepath.open(mode="r", encoding="utf-8") as infile:
        count = 0
        headers = None
        for line in infile:
            if line.startswith("#"):  # skip comment lines
                continue

            fields = line.strip().split("\t")
            if count == 0:
                headers = fields
                count = count + 1
                continue

            _dict = dict(zip(
                headers,
                ((None if item.strip() == "" else item.strip())
                 for item in fields)))
            _pheno, _xref = _dict.pop("UniqueIdentifier").split("::")
            _dict = {
                key: ((float(val) if bool(val) else val)
                      if key not in BULK_EDIT_COMMON_FIELDNAMES
                      else val)
                for key, val in _dict.items()
            }
            _dict["phenotype_id"] = int(_pheno.split(":")[1])
            _dict["xref_id"] = int(_xref.split(":")[1])
            if _dict["PubMed_ID"] is not None:
                _dict["PubMed_ID"] = int(_dict["PubMed_ID"])

            yield _dict
            count = count + 1


def run(conn, job):
    """Process the data and update it."""
    file_contents = tuple(sorted(read_file(Path(job["metadata"]["edit-file"])),
                                 key=lambda item: item["phenotype_id"]))
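    # Collect, in a single pass over the file contents: every phenotype id,
    # every (phenotype_id, xref_id) pair, and the set of PubMed IDs.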
    pheno_ids, pheno_xref_ids, pubmed_ids = reduce(
        lambda coll, curr: (
            coll[0] + (curr["phenotype_id"],),
            coll[1] + ((curr["phenotype_id"], curr["xref_id"]),),
            coll[2].union(set([curr["PubMed_ID"]]))),
        file_contents,
        (tuple(), tuple(), set([None])))
    if not all((check_ids(conn, pheno_xref_ids),
                check_for_mandatory_fields(file_contents))):
        logger.error("The uploaded file contains errors. Aborting.")
        return 1

    ### Compute differences
    logger.info("Computing differences.")
    # 1. Basic Phenotype data differences
    #    a. Descriptions differences
    _desc_diff = descriptions_differences(
        file_contents, __fetch_phenotypes__(conn, pheno_ids))
    logger.debug("DESCRIPTIONS DIFFERENCES: %s", _desc_diff)

    #    b. Publications differences
    _db_publications = fetch_phenotype_publications(conn, pheno_xref_ids)
    logger.debug("DB PUBLICATIONS: %s", _db_publications)

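    # Map each PubMed ID in the file to its "phenotype_id::xref_id" key, and
    # each such key to the publication currently linked in the database, so
    # that newly-fetched publications can be assigned to the right records.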
    _pubmed_map = {
        (int(row["PubMed_ID"]) if bool(row["PubMed_ID"]) else None): (
            f"{row['phenotype_id']}::{row['xref_id']}")
        for row in file_contents
    }
    _pub_id_map = {
        f"{pub['PhenotypeId']}::{pub['xref_id']}": pub["PublicationId"]
        for pub in _db_publications
    }

    _missing_pubmed_ids = tuple(
        pubmed_id for pubmed_id in pubmed_ids
        if pubmed_id not in tuple(
            row["PubMed_ID"] for row in _db_publications))
    _new_publications = update_publications(
        conn, tuple({
            **pub, "publication_id": _pub_id_map[_pubmed_map[pub["pubmed_id"]]]
        } for pub in pmed.fetch_publications(_missing_pubmed_ids)))
    _pub_diff = publications_differences(
        file_contents, _db_publications, {
            row["PubMed_ID" if "PubMed_ID" in row else "pubmed_id"]: row[
                "PublicationId" if "PublicationId" in row else "publication_id"]
            for row in _db_publications + _new_publications})
    logger.debug("Publications diff: %s", _pub_diff)
    # 2. Data differences
    _db_pheno_data = phenotypes_data_by_ids(conn, tuple({
        "population_id": job["metadata"]["population-id"],
        "phenoid": row[0],
        "xref_id": row[1]
    } for row in pheno_xref_ids))

    data_diff = phenotypes_data_differences(
        ({
            "phenotype_id": row["phenotype_id"],
            "xref_id": row["xref_id"],
            "data": {
                key: val for key, val in row.items()
                if key not in BULK_EDIT_COMMON_FIELDNAMES + [
                        "phenotype_id", "xref_id"]
            }
        } for row in file_contents),
        ({
            **row,
            "PhenotypeId": row["Id"],
            "data": {
                dataitem["StrainName"]: dataitem
                for dataitem in row["data"].values()
            }
        } for row in _db_pheno_data))
    logger.debug("Data differences: %s", data_diff)
    ### END: Compute differences
    update_descriptions(conn, _desc_diff)
    link_publications(conn, _pub_diff)
    update_values(conn, data_diff)
    return 0


def main():
    """Entry-point for this script."""
    args = parse_args()
    logger.setLevel(args.log_level.upper())
    logger.debug("Arguments: %s", args)

    logging.getLogger("uploader.phenotypes.misc").setLevel(args.log_level.upper())
    logging.getLogger("uploader.phenotypes.models").setLevel(args.log_level.upper())
    logging.getLogger("uploader.publications.models").setLevel(args.log_level.upper())

    with (mysqldb.database_connection(args.db_uri) as conn,
          sqlite3.connection(args.jobs_db_path) as jobs_conn):
        return run(conn, jobs.job(jobs_conn, args.job_id))


if __name__ == "__main__":
    sys.exit(main())