import sys
import uuid
import logging
import argparse
from pathlib import Path
from typing import Iterator
from functools import reduce

import requests
from MySQLdb.cursors import DictCursor

from gn_libs import jobs, mysqldb, sqlite3

logging.basicConfig(
    format="%(asctime)s — %(filename)s:%(lineno)s — %(levelname)s: %(message)s")
logger = logging.getLogger(__name__)


def check_ids(conn, ids: tuple[tuple[int, int], ...]) -> bool:
    """Verify that all the `UniqueIdentifier` values are valid."""
    logger.info("Checking the 'UniqueIdentifier' values.")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        paramstr = ",".join(["(%s, %s)"] * len(ids))
        cursor.execute(
            "SELECT PhenotypeId AS phenotype_id, Id AS xref_id "
            "FROM PublishXRef "
            f"WHERE (PhenotypeId, Id) IN ({paramstr})",
            tuple(item for row in ids for item in row))
        mysqldb.debug_query(cursor, logger)
        found = tuple((row["phenotype_id"], row["xref_id"])
                      for row in cursor.fetchall())

    not_found = tuple(item for item in ids if item not in found)
    if len(not_found) == 0:
        logger.info("All 'UniqueIdentifier' are valid.")
        return True

    for item in not_found:
        logger.error(f"Invalid 'UniqueIdentifier' value: phId:%s::xrId:%s", item[0], item[1])

    return False


def check_for_mandatory_fields():
    """Verify that mandatory fields have values."""
    pass
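

# A minimal sketch of what the mandatory-fields check could look like. The
# actual list of mandatory columns is not defined in this module, so the
# `mandatory` tuple below is a hypothetical placeholder.
def __check_for_mandatory_fields_sketch__(file_contents) -> bool:
    """Sketch: verify that assumed-mandatory fields have non-empty values."""
    mandatory = ("Original_description",)  # hypothetical: use the real list
    missing = tuple(
        (row["phenotype_id"], row["xref_id"], field)
        for row in file_contents
        for field in mandatory
        if row.get(field) is None)  # `read_file` turns empty fields into None
    for pheno_id, xref_id, field in missing:
        logger.error("Missing mandatory field '%s' for phId:%s::xrId:%s",
                     field, pheno_id, xref_id)
    return len(missing) == 0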


def __fetch_phenotypes__(conn, ids: tuple[int, ...]) -> tuple[dict, ...]:
    """Fetch basic (non-numeric) phenotypes data from the database."""
    with conn.cursor(cursorclass=DictCursor) as cursor:
        paramstr = ",".join(["%s"] * len(ids))
        cursor.execute(f"SELECT * FROM Phenotype WHERE Id IN ({paramstr}) "
                       "ORDER BY Id ASC",
                       ids)
        return tuple(dict(row) for row in cursor.fetchall())


def descriptions_differences(file_data, db_data) -> tuple[dict, ...]:
    """Compute differences in the descriptions."""
    logger.info("Computing differences in phenotype descriptions.")
    assert len(file_data) == len(db_data), "The counts of phenotypes differ!"
    description_columns = ("Pre_publication_description",
                           "Post_publication_description",
                           "Original_description",
                           "Pre_publication_abbreviation",
                           "Post_publication_abbreviation")
    diff = tuple()
    for file_row, db_row in zip(file_data, db_data):
        assert file_row["phenotype_id"] == db_row["Id"]
        inner_diff = {
            key: file_row[key]
                for key in description_columns
                if not file_row[key] == db_row[key]
        }
        if bool(inner_diff):
            diff = diff + ({
                "phenotype_id": file_row["phenotype_id"],
                **inner_diff
            },)

    return diff
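

# Example shape of the diff returned by `descriptions_differences` above
# (illustrative values only):
# ({"phenotype_id": 10, "Original_description": "the new description"},)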


def __fetch_publications__(conn, ids):
    """Fetch publication from database by ID."""
    paramstr = ",".join(["(%s, %s)"] * len(ids))
    query = (
        "SELECT "
        "pxr.PhenotypeId, pxr.Id AS xref_id, pxr.PublicationId, pub.PubMed_ID "
        "FROM PublishXRef AS pxr INNER JOIN Publication AS pub "
        "ON pxr.PublicationId=pub.Id "
        f"WHERE (pxr.PhenotypeId, pxr.Id) IN ({paramstr})")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(query, tuple(item for row in ids for item in row))
        return tuple(dict(row) for row in cursor.fetchall())


def __process_pubmed_publication_data__(text):
    """Process the data from PubMed into usable data."""
    # Process with lxml
    pass
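

# A minimal sketch of the lxml-based processing mentioned above, assuming the
# standard efetch XML layout (PubmedArticleSet/PubmedArticle). The exact
# fields the database needs are not specified here, so the extracted keys are
# illustrative assumptions.
def __process_pubmed_publication_data_sketch__(text) -> tuple[dict, ...]:
    """Sketch: extract basic publication details from efetch XML."""
    from lxml import etree  # local import since this is only a sketch

    tree = etree.fromstring(text.encode("utf-8"))
    return tuple({
        "PubMed_ID": article.findtext("MedlineCitation/PMID"),
        "Title": article.findtext("MedlineCitation/Article/ArticleTitle"),
        "Journal": article.findtext("MedlineCitation/Article/Journal/Title"),
        "Year": article.findtext(
            "MedlineCitation/Article/Journal/JournalIssue/PubDate/Year")
    } for article in tree.findall("PubmedArticle"))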


def __fetch_new_pubmed_ids__(pubmed_ids):
    """Retrieve data on new publications from NCBI."""
    # See whether we can retrieve multiple publications in one go
    # Parse data and save to DB
    # Return PublicationId(s) for new publication(s).
    logger.info("Fetching publications data for the following PubMed IDs: %s",
                ", ".join(pubmed_ids))

    # Should we, perhaps, pass this in from a config variable?
    uri = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    try:
        response = requests.get(
            uri,
            params={
                "db": "pubmed",
                "retmode": "xml",
                "id": ",".join(str(item) for item in pubmed_ids)
            })

        if response.status_code == 200:
            return __process_pubmed_publication_data__(response.text)

        logger.error(
            "Could not fetch the new publication from %s (status code: %s)",
            uri,
            response.status_code)
    except requests.exceptions.ConnectionError:
        logger.error("Could not find the domain %s", uri)

    return tuple()


def publications_differences(file_data, db_data, pubmed_ids):
    """Compute differences in the publications."""
    logger.info("Computing differences in publications.")
    db_pubmed_ids = reduce(lambda coll, curr: coll.union(set([curr["PubMed_ID"]])),
                           db_data,
                           set([None]))
    new_pubmeds = __fetch_new_pubmed_ids__(tuple(
        pubmed_ids.difference(db_pubmed_ids)))
    # The comparison itself is still pending; see the sketch below.
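

# A minimal sketch of how the publications differences could be reported,
# keying the database rows by (PhenotypeId, xref_id). The output shape is an
# assumption; note too that file values are strings while the database
# returns integers, so real code would normalise the types before comparing.
def __publications_differences_sketch__(file_data, db_data) -> tuple[dict, ...]:
    """Sketch: file rows whose PubMed_ID differs from the database's."""
    db_pubmeds = {(row["PhenotypeId"], row["xref_id"]): row["PubMed_ID"]
                  for row in db_data}
    return tuple(
        {
            "phenotype_id": row["phenotype_id"],
            "xref_id": row["xref_id"],
            "PubMed_ID": row["PubMed_ID"]
        }
        for row in file_data
        if row["PubMed_ID"] != db_pubmeds[
            (row["phenotype_id"], row["xref_id"])])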


def compute_differences(
        conn,
        file_contents,
        pheno_ids,
        pheno_xref_ids,
        pubmed_ids
) -> tuple[tuple[dict, ...], tuple[dict, ...], tuple[dict, ...]]:
    """Compute differences between data in DB and edited data."""
    logger.info("Computing differences.")
    # 1. Basic Phenotype data differences
    #    a. Descriptions differences
    desc_diff = descriptions_differences(file_contents, __fetch_phenotypes__(conn, pheno_ids))
    logger.debug("DESCRIPTIONS DIFFERENCES: %s", desc_diff)
    #    b. Publications differences
    # pub_diff = publications_differences(...)
    # 2. Data differences
    # data_diff = data_differences(...)
    # The publications and data differences are not computed yet; return
    # empty tuples in their place so the return value matches the annotation.
    return (desc_diff, tuple(), tuple())


def update_descriptions():
    """Update descriptions in the database"""
    logger.info("Updating descriptions")
    # Compute differences between db data and uploaded file
    # Only run query for changed descriptions
    pass
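

# A minimal sketch of the descriptions update, assuming the diff structure
# produced by `descriptions_differences` above (a "phenotype_id" key plus
# only the changed description columns, which are all columns of the
# Phenotype table).
def __update_descriptions_sketch__(conn, desc_diff):
    """Sketch: apply only the changed description fields per phenotype."""
    with conn.cursor(cursorclass=DictCursor) as cursor:
        for row in desc_diff:
            changes = {key: val for key, val in row.items()
                       if key != "phenotype_id"}
            setstr = ", ".join(f"{field}=%s" for field in changes)
            cursor.execute(
                f"UPDATE Phenotype SET {setstr} WHERE Id=%s",
                tuple(changes.values()) + (row["phenotype_id"],))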


def link_publications():
    """Link phenotypes to relevant publications."""
    logger.info("Linking phenotypes to publications.")
    # Create publication if PubMed_ID doesn't exist in db
    pass
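

# A minimal sketch of the linking step, assuming a mapping from
# (phenotype_id, xref_id) to the (possibly newly created) PublicationId. The
# PublishXRef columns are the ones already used in `__fetch_publications__`.
def __link_publications_sketch__(conn, publication_ids: dict):
    """Sketch: point each phenotype cross-reference at its publication."""
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.executemany(
            "UPDATE PublishXRef SET PublicationId=%s "
            "WHERE PhenotypeId=%s AND Id=%s",
            tuple((pub_id, pheno_id, xref_id)
                  for (pheno_id, xref_id), pub_id in publication_ids.items()))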


def update_values():
    """Update the phenotype values."""
    logger.info("Updating phenotypes values.")
    # Compute differences between db data and uploaded file
    # Only run query for changed data
    pass
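

# A minimal sketch of the values update. The numeric-data table is not
# referenced anywhere in this module, so the `PublishData` table and its
# `Id`/`StrainId`/`value` columns below are assumptions modelled on the
# GeneNetwork schema; verify them against the real schema before use.
def __update_values_sketch__(conn, value_diff):
    """Sketch: apply only the changed values, one (DataId, StrainId) each."""
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.executemany(
            "UPDATE PublishData SET value=%s WHERE Id=%s AND StrainId=%s",
            tuple((row["value"], row["data_id"], row["strain_id"])
                  for row in value_diff))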


def parse_args():
    """Parse and return the command-line arguments."""
    parser = argparse.ArgumentParser(
        prog="Phenotypes Bulk-Edit Processor",
        description="Process the bulk-edits to phenotype data and descriptions.")
    parser.add_argument("db_uri", type=str, help="MariaDB/MySQL connection URL")
    parser.add_argument(
        "jobs_db_path", type=Path, help="Path to jobs' SQLite database.")
    parser.add_argument("job_id", type=uuid.UUID, help="ID of the running job")
    parser.add_argument(
        "--log-level",
        type=str,
        help="Determines what is logged out.",
        choices=("debug", "info", "warning", "error", "critical"),
        default="info")
    return parser.parse_args()


def read_file(filepath: Path) -> Iterator[str]:
    """Read the file, one line at a time."""
    with filepath.open(mode="r", encoding="utf-8") as infile:
        count = 0
        headers = None
        for line in infile:
            if line.startswith("#"): # ignore comments
                continue

            fields = line.strip().split("\t")
            if count == 0:
                headers = fields
                count = count + 1
                continue

            _dict = dict(zip(
                headers,
                ((None if item.strip() == "" else item.strip())
                 for item in fields)))
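            # "UniqueIdentifier" has the form "phId:<phenotype-id>::xrId:<xref-id>"
            # (matching the format logged in `check_ids` above).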
            _pheno, _xref = _dict.pop("UniqueIdentifier").split("::")
            _dict["phenotype_id"] = int(_pheno.split(":")[1])
            _dict["xref_id"] = int(_xref.split(":")[1])
            yield _dict
            count = count + 1


def run(conn, job):
    """Process the data and update it."""
    file_contents = tuple(sorted(read_file(Path(job["metadata"]["edit-file"])),
                                 key=lambda item: item["phenotype_id"]))
    pheno_ids, pheno_xref_ids, pubmed_ids = reduce(
        lambda coll, curr: (
            coll[0] + (curr["phenotype_id"],),
            coll[1] + ((curr["phenotype_id"], curr["xref_id"]),),
            coll[2].union(set([curr["PubMed_ID"]]))),
        file_contents,
        (tuple(), tuple(), set([None])))
    identifiers_valid = check_ids(conn, pheno_xref_ids)
    check_for_mandatory_fields()
    if not identifiers_valid:
        # Stop running here since errors were found.
        return 1
    compute_differences(conn,
                        file_contents,
                        pheno_ids,
                        pheno_xref_ids,
                        pubmed_ids)
    update_descriptions()
    link_publications()
    update_values()
    return 0


def main():
    """Entry-point for this script."""
    args = parse_args()
    logger.setLevel(args.log_level.upper())
    logger.debug("Arguments: %s", args)

    with (mysqldb.database_connection(args.db_uri) as conn,
          sqlite3.connection(args.jobs_db_path) as jobs_conn):
        return run(conn, jobs.job(jobs_conn, args.job_id))


if __name__ == "__main__":
    sys.exit(main())