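"""Load the phenotypes data in an R/qtl2-style bundle into the database.

This script is intended to run as a background job. A typical invocation (a
sketch; adjust the script path to your checkout) looks like:

    python scripts/load_phenotypes_to_db.py <db_uri> <jobs_db_path> <job_id> \
        --log-level debug

where `db_uri` is the MariaDB/MySQL connection URL, `jobs_db_path` is the path
to the jobs' SQLite database, and `job_id` is the UUID of the running job.
"""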
import sys
import uuid
import json
import logging
import argparse
import datetime
from pathlib import Path
from zipfile import ZipFile
from typing import Any, Union
from urllib.parse import urljoin
from functools import reduce, partial

from MySQLdb.cursors import DictCursor

from gn_libs import jobs, mysqldb, sqlite3, monadic_requests as mrequests

from r_qtl import r_qtl2 as rqtl2
from uploader.species.models import species_by_id
from uploader.population.models import population_by_species_and_id
from uploader.samples.models import samples_by_species_and_population
from uploader.phenotypes.models import (
    dataset_by_id,
    save_phenotypes_data,
    create_new_phenotypes,
    quick_save_phenotypes_data)
from uploader.publications.models import (
    create_new_publications,
    fetch_publication_by_id)

from scripts.rqtl2.bundleutils import build_line_joiner, build_line_splitter

logging.basicConfig(
    format="%(asctime)s — %(filename)s:%(lineno)s — %(levelname)s: %(message)s")
logger = logging.getLogger(__name__)



def __replace_na_strings__(line, na_strings):
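    """Replace any value that appears in `na_strings` with `None`."""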
    return ((None if value in na_strings else value) for value in line)


def save_phenotypes(
        conn: mysqldb.Connection,
        control_data: dict[str, Any],
        filesdir: Path
) -> tuple[dict, ...]:
    """Read `phenofiles` and save the phenotypes therein."""
    ## TODO: Replace with something like this: ##
    # phenofiles = control_data["phenocovar"] + control_data.get(
    #     "gn-metadata", {}).get("pheno", [])
    #
    # This is meant to load (and merge) data from the "phenocovar" and
    # "gn-metadata -> pheno" files into a single collection of phenotypes.
    phenofiles = tuple(filesdir.joinpath(_file) for _file in control_data["phenocovar"])
    if not phenofiles:
        return tuple()

    if control_data["phenocovar_transposed"]:
        logger.info("Undoing transposition of the files' rows and columns.")
        phenofiles = tuple(
            rqtl2.transpose_csv_with_rename(
                _file,
                build_line_splitter(control_data),
                build_line_joiner(control_data))
            for _file in phenofiles)

    _headers = rqtl2.read_csv_file_headers(phenofiles[0],
                                           control_data["phenocovar_transposed"],
                                           control_data["sep"],
                                           control_data["comment.char"])
    return create_new_phenotypes(
        conn,
        (dict(zip(_headers,
                  __replace_na_strings__(line, control_data["na.strings"])))
         for filecontent
         in (rqtl2.read_csv_file(path,
                                 separator=control_data["sep"],
                                 comment_char=control_data["comment.char"])
             for path in phenofiles)
         for idx, line in enumerate(filecontent)
         if idx != 0))


def __fetch_next_dataid__(conn: mysqldb.Connection) -> int:
    """Fetch the next available DataId value from the database."""
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(
            "SELECT MAX(DataId) AS CurrentMaxDataId FROM PublishXRef")
        return int(cursor.fetchone()["CurrentMaxDataId"]) + 1


def __row_to_dataitems__(
        sample_row: dict,
        dataidmap: dict,
        pheno_name2id: dict[str, int],
        samples: dict
) -> tuple[dict, ...]:
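    """Convert one sample row into data items, one per phenotype value."""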
    samplename = sample_row["id"]

    return ({
        "phenotype_id": dataidmap[pheno_name2id[phenoname]]["phenotype_id"],
        "data_id": dataidmap[pheno_name2id[phenoname]]["data_id"],
        "sample_name": samplename,
        "sample_id": samples[samplename]["Id"],
        "value": phenovalue
    } for phenoname, phenovalue in sample_row.items() if phenoname != "id")


def __build_dataitems__(
        filetype,
        phenofiles,
        control_data,
        samples,
        dataidmap,
        pheno_name2id
):
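    """Generate the data items to save from `phenofiles`, skipping header
    rows and any values that are `None`."""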
    _headers = rqtl2.read_csv_file_headers(
        phenofiles[0],
        False, # Any transposed files have been un-transposed by this point
        control_data["sep"],
        control_data["comment.char"])
    _filescontents = (
        rqtl2.read_csv_file(path,
                            separator=control_data["sep"],
                            comment_char=control_data["comment.char"])
        for path in phenofiles)
    _linescontents = (
        __row_to_dataitems__(
            dict(zip(("id",) + _headers[1:],
                     __replace_na_strings__(line, control_data["na.strings"]))),
            dataidmap,
            pheno_name2id,
            samples)
        for linenum, line in (enumline for filecontent in _filescontents
                              for enumline in enumerate(filecontent))
        if linenum > 0)
    return (item for items in _linescontents
            for item in items
            if item["value"] is not None)


def save_numeric_data(
        conn: mysqldb.Connection,
        dataidmap: dict,
        pheno_name2id: dict[str, int],
        samples: tuple[dict, ...],
        control_data: dict,
        filesdir: Path,
        filetype: str,
        table: str
):
    """Read data from files and save to the database."""
    phenofiles = tuple(
        filesdir.joinpath(_file) for _file in control_data[filetype])
    if not phenofiles:
        return tuple()

    if control_data[f"{filetype}_transposed"]:
        logger.info("Undoing transposition of the files' rows and columns.")
        phenofiles = tuple(
            rqtl2.transpose_csv_with_rename(
                _file,
                build_line_splitter(control_data),
                build_line_joiner(control_data))
            for _file in phenofiles)

    try:
        logger.debug("Attempt quick save with `LOAD … INFILE`.")
        return quick_save_phenotypes_data(
            conn,
            table,
            __build_dataitems__(
                filetype,
                phenofiles,
                control_data,
                samples,
                dataidmap,
                pheno_name2id),
            filesdir)
    except Exception:
        logger.debug("Could not use `LOAD … INFILE`; falling back to raw queries.",
                     exc_info=True)
        return save_phenotypes_data(
            conn,
            table,
            __build_dataitems__(
                filetype,
                phenofiles,
                control_data,
                samples,
                dataidmap,
                pheno_name2id))


save_pheno_data = partial(save_numeric_data,
                          filetype="pheno",
                          table="PublishData")


save_phenotypes_se = partial(save_numeric_data,
                             filetype="phenose",
                             table="PublishSE")


save_phenotypes_n = partial(save_numeric_data,
                            filetype="phenonum",
                            table="NStrain")


def cross_reference_phenotypes_publications_and_data(
        conn: mysqldb.Connection, xref_data: tuple[dict, ...]
):
    """Crossreference the phenotypes, publication and data."""
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute("SELECT MAX(Id) CurrentMaxId FROM PublishXRef")
        _nextid = int(cursor.fetchone()["CurrentMaxId"]) + 1
        _params = tuple({**row, "xref_id": _id}
                        for _id, row in enumerate(xref_data, start=_nextid))
        cursor.executemany(
            ("INSERT INTO PublishXRef("
             "Id, InbredSetId, PhenotypeId, PublicationId, DataId, comments"
             ") "
             "VALUES ("
             "%(xref_id)s, %(population_id)s, %(phenotype_id)s, "
             "%(publication_id)s, %(data_id)s, 'Upload of new data.'"
             ")"),
            _params)
        cursor.executemany(
            "UPDATE PublishXRef SET mean="
            "(SELECT AVG(value) FROM PublishData WHERE PublishData.Id=PublishXRef.DataId) "
            "WHERE PublishXRef.Id=%(xref_id)s AND "
            "InbredSetId=%(population_id)s",
            _params)
        return _params


def update_auth(authserver, token, species, population, dataset, xrefdata):
    """Grant the user access to their data."""
    # TODO Call into the auth server to:
    #      1. Link the phenotypes with a user group
    #         - fetch group: http://localhost:8081/auth/user/group
    #         - link data to group: http://localhost:8081/auth/data/link/phenotype
    #         - *might need code update in gn-auth: remove restriction, perhaps*
    #      2. Create resource (perhaps?)
    #         - Get resource categories: http://localhost:8081/auth/resource/categories
    #         - Create a new resource: http://localhost:8081/auth/resource/create
    #           - single resource for all phenotypes
    #           - resource name from user, species, population, dataset, datetime?
    #         - User will have "ownership" of resource by default
    #      3. Link data to the resource: http://localhost:8081/auth/resource/data/link
    #         - Update code to allow linking multiple items in a single request
    _tries = 0  # TODO: use this to limit the number of retries before bailing out
    _delay = 1
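    # Sketch of a possible retry scheme using `_tries`/`_delay` (not wired in
    # yet; `MAX_TRIES` is a hypothetical constant):
    #
    #     while _tries < MAX_TRIES:
    #         ...  # run the request chain below; break on success
    #         time.sleep(_delay)
    #         _delay *= 2
    #         _tries += 1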
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }
    def authserveruri(endpoint):
        return urljoin(authserver, endpoint)

    def __fetch_user_details__():
        logger.debug("… Fetching user details")
        return mrequests.get(
            authserveruri("/auth/user/"),
            headers=headers
        )

    def __link_data__(user):
        logger.debug("… linking uploaded data to user's group")
        return mrequests.post(
            authserveruri("/auth/data/link/phenotype"),
            headers=headers,
            json={
                "species_name": species["Name"],
                "group_id": user["group"]["group_id"],
                "selected": [
                    {
                        "SpeciesId": species["SpeciesId"],
                        "InbredSetId": population["Id"],
                        "PublishFreezeId": dataset["Id"],
                        "dataset_name": dataset["Name"],
                        "dataset_fullname": dataset["FullName"],
                        "dataset_shortname": dataset["ShortName"],
                        "PublishXRefId": item["xref_id"]
                    }
                    for item in xrefdata
                ],
                "using-raw-ids": "on"
            }).then(lambda ld_results: (user, ld_results))

    def __fetch_phenotype_category_details__(user, linkeddata):
        logger.debug("… fetching phenotype category details")
        return mrequests.get(
            authserveruri("/auth/resource/categories"),
            headers=headers
        ).then(
            lambda categories: (
                user,
                linkeddata,
                next(category for category in categories
                     if category["resource_category_key"] == "phenotype"))
        )

    def __create_resource__(user, linkeddata, category):
        logger.debug("… creating authorisation resource object")
        now = datetime.datetime.now().isoformat()
        return mrequests.post(
            authserveruri("/auth/resource/create"),
            headers=headers,
            json={
                "resource_category": category["resource_category_id"],
                "resource_name": (f"{user['email']}—{dataset['Name']}—{now}—"
                                  f"{len(xrefdata)} phenotypes"),
                "public": "off"
            }).then(lambda cr_results: (user, linkeddata, cr_results))

    def __attach_data_to_resource__(user, linkeddata, resource):
        logger.debug("… attaching data to authorisation resource object")
        return mrequests.post(
            authserveruri("/auth/resource/data/link"),
            headers=headers,
            json={
                "dataset_type": "phenotype",
                "resource_id": resource["resource_id"],
                "data_link_ids": [
                    item["data_link_id"] for item in linkeddata["traits"]]
            }).then(lambda attc: (user, linkeddata, resource, attc))

    def __handle_error__(resp):
        logger.error("ERROR: Updating the authorisation for the data failed.")
        logger.debug(
            "ERROR: The response from the authorisation server was:\n\t%s",
            resp.json())
        return 1

    def __handle_success__(_val):
        logger.info(
            "The authorisation for the data has been updated successfully.")
        return 0

    return __fetch_user_details__().then(__link_data__).then(
        lambda result: __fetch_phenotype_category_details__(*result)
    ).then(
        lambda result: __create_resource__(*result)
    ).then(
        lambda result: __attach_data_to_resource__(*result)
    ).either(__handle_error__, __handle_success__)


def load_data(conn: mysqldb.Connection, job: dict) -> tuple:
    """Load the data attached in the given job."""
    _job_metadata = job["metadata"]
    # Steps
    # 0. Read data from the files: can be multiple files per type
    #
    _species = species_by_id(conn, int(_job_metadata["species_id"]))
    _population = population_by_species_and_id(
        conn,
        _species["SpeciesId"],
        int(_job_metadata["population_id"]))
    _dataset = dataset_by_id(
        conn,
        _species["SpeciesId"],
        _population["Id"],
        int(_job_metadata["dataset_id"]))
    # 1. Just retrieve the publication: Don't create publications for now.
    _publication = fetch_publication_by_id(
        conn, int(_job_metadata.get("publication_id", "0"))) or {"Id": 0}
    # 2. Save all new phenotypes:
    #     -> return phenotype IDs
    bundle = Path(_job_metadata["bundle_file"])
    _control_data = rqtl2.control_data(bundle)
    logger.info("Extracting the zipped bundle of files.")
    _outdir = Path(bundle.parent, f"bundle_{bundle.stem}")
    with ZipFile(str(bundle), "r") as zfile:
        _files = rqtl2.extract(zfile, _outdir)
    logger.info("Saving new phenotypes.")
    _phenos = save_phenotypes(conn, _control_data, _outdir)
    def __build_phenos_maps__(accumulator, current):
        dataid, row = current
        return ({
            **accumulator[0],
            row["phenotype_id"]: {
                "population_id": _population["Id"],
                "phenotype_id": row["phenotype_id"],
                "data_id": dataid,
                "publication_id": _publication["Id"],
            }
        }, {
            **accumulator[1],
            row["id"]: row["phenotype_id"]
        })
    dataidmap, pheno_name2id = reduce(
        __build_phenos_maps__,
        enumerate(_phenos, start=__fetch_next_dataid__(conn)),
        ({},{}))
    # 3. a. Fetch the strain names and IDs: create a name->ID map
    samples = {
        row["Name"]: row
        for row in samples_by_species_and_population(
                conn, _species["SpeciesId"], _population["Id"])}
    #    b. Save all the data items (DataIds are assigned manually, not auto-incremented), return new IDs
    logger.info("Saving new phenotypes data.")
    _num_data_rows = save_pheno_data(conn=conn,
                                     dataidmap=dataidmap,
                                     pheno_name2id=pheno_name2id,
                                     samples=samples,
                                     control_data=_control_data,
                                     filesdir=_outdir)
    logger.info("Saved %s new phenotype data rows.", _num_data_rows)
    # 4. Cross-reference Phenotype, Publication, and PublishData in PublishXRef
    logger.info("Cross-referencing new phenotypes to their data and publications.")
    _xrefs = cross_reference_phenotypes_publications_and_data(
        conn, tuple(dataidmap.values()))
    # 5. If standard errors and N exist, save them too
    #    (use IDs returned in `3. b.` above).
    if _control_data.get("phenose"):
        logger.info("Saving new phenotypes standard errors.")
        _num_se_rows = save_phenotypes_se(conn=conn,
                                          dataidmap=dataidmap,
                                          pheno_name2id=pheno_name2id,
                                          samples=samples,
                                          control_data=_control_data,
                                          filesdir=_outdir)
        logger.info("Saved %s new phenotype standard error rows.", _num_se_rows)

    if _control_data.get("phenonum"):
        logger.info("Saving new phenotypes sample counts.")
        _num_n_rows = save_phenotypes_n(conn=conn,
                                        dataidmap=dataidmap,
                                        pheno_name2id=pheno_name2id,
                                        samples=samples,
                                        control_data=_control_data,
                                        filesdir=_outdir)
        logger.info("Saved %s new phenotype sample counts rows.", _num_n_rows)

    return (_species, _population, _dataset, _xrefs)


if __name__ == "__main__":
    def parse_args():
        """Setup command-line arguments."""
        parser = argparse.ArgumentParser(
            prog="load_phenotypes_to_db",
            description="Process the phenotypes' data and load it into the database.")
        parser.add_argument("db_uri", type=str, help="MariaDB/MySQL connection URL")
        parser.add_argument(
            "jobs_db_path", type=Path, help="Path to jobs' SQLite database.")
        parser.add_argument("job_id", type=uuid.UUID, help="ID of the running job")
        parser.add_argument(
            "--log-level",
            type=str,
            help="Determines what is logged out.",
            choices=("debug", "info", "warning", "error", "critical"),
            default="info")
        return parser.parse_args()

    def setup_logging(log_level: str):
        """Setup logging for the script."""
        logger.setLevel(log_level)
        logging.getLogger("uploader.phenotypes.models").setLevel(log_level)


    def main():
        """Entry-point for this script."""
        args = parse_args()
        setup_logging(args.log_level.upper())

        with (mysqldb.database_connection(args.db_uri) as conn,
              conn.cursor(cursorclass=DictCursor) as cursor,
              sqlite3.connection(args.jobs_db_path) as jobs_conn):
            job = jobs.job(jobs_conn, args.job_id)

            # Lock the PublishXRef/PublishData/PublishSE/NStrain tables here. Why?
            #     The `DataId` values are sequential but not auto-incremented, and
            #     `PublishXRef`.`DataId` cannot be converted to AUTO_INCREMENT, so
            #     the next value is computed with `SELECT MAX(DataId) FROM PublishXRef;`.
            #     Locking the tables prevents concurrent jobs from computing and
            #     using the same `DataId`.
            #     To check for a table lock, see
            #     https://oracle-base.com/articles/mysql/mysql-identify-locked-tables
            #     or run `SHOW OPEN TABLES LIKE 'Publish%';`.
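            #     The `cursor.execute(...)` call below builds the lock from
            #     `_db_tables_` and expands to roughly:
            #     `LOCK TABLES Species WRITE, InbredSet WRITE, …, NStrain WRITE`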
            _db_tables_ = (
                "Species",
                "InbredSet",
                "Strain",
                "StrainXRef",
                "Publication",
                "Phenotype",
                "PublishXRef",
                "PublishFreeze",
                "PublishData",
                "PublishSE",
                "NStrain")

            logger.debug(
                ("Locking database tables for the connection:" +
                 "".join("\n\t- %s" for _ in _db_tables_) + "\n"),
                *_db_tables_)
            # Lock the tables to avoid race conditions.
            cursor.execute(
                "LOCK TABLES " + ", ".join(
                    f"{_table} WRITE" for _table in _db_tables_))

            db_results = load_data(conn, job)
            jobs.update_metadata(
                jobs_conn,
                args.job_id,
                "xref_ids",
                json.dumps([xref["xref_id"] for xref in db_results[3]]))

            logger.info("Unlocking all database tables.")
            cursor.execute("UNLOCK TABLES")

        # Update authorisations (break this down) — maybe loop until it works?
        logger.info("Updating authorisation.")
        _job_metadata = job["metadata"]
        return update_auth(_job_metadata["authserver"],
                           _job_metadata["token"],
                           *db_results)


    try:
        sys.exit(main())
    except Exception:
        logger.error("Data loading failed… Halting!",
                     exc_info=True)
        sys.exit(1)