"""Load phenotypes and their data provided in files into the database."""
import sys
import uuid
import json
import time
import logging
import argparse
import datetime
from typing import Any, Iterator
from pathlib import Path
from zipfile import ZipFile
from urllib.parse import urljoin
from functools import reduce, partial

from MySQLdb.cursors import DictCursor

from gn_libs import jobs, mysqldb, sqlite3, monadic_requests as mrequests

from r_qtl import r_qtl2 as rqtl2
from uploader.species.models import species_by_id
from uploader.population.models import population_by_species_and_id
from uploader.samples.models import samples_by_species_and_population
from uploader.phenotypes.models import (
    dataset_by_id,
    save_phenotypes_data,
    create_new_phenotypes,
    quick_save_phenotypes_data)
from uploader.publications.models import fetch_publication_by_id

from scripts.rqtl2.bundleutils import build_line_joiner, build_line_splitter

from functional_tools import take

logging.basicConfig(
    format="%(asctime)s — %(filename)s:%(lineno)s — %(levelname)s: %(message)s")
logger = logging.getLogger(__name__)



def __replace_na_strings__(line, na_strings):
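    # Illustrative behaviour (na_strings values assumed): with
    # na_strings = ("NA", "N/A"), the line ("BXD1", "12.3", "NA")
    # yields ("BXD1", "12.3", None).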
    return ((None if value in na_strings else value) for value in line)


def save_phenotypes(
        conn: mysqldb.Connection,
        control_data: dict[str, Any],
        population_id,
        publication_id,
        filesdir: Path
) -> tuple[dict, ...]:
    """Read `phenofiles` and save the phenotypes therein."""
    phenofiles = tuple(filesdir.joinpath(_file) for _file in control_data["phenocovar"])
    if len(phenofiles) <= 0:
        return tuple()

    if control_data["phenocovar_transposed"]:
        logger.info("Undoing transposition of the files rows and columns.")
        phenofiles = (
            rqtl2.transpose_csv_with_rename(
                _file,
                build_line_splitter(control_data),
                build_line_joiner(control_data))
            for _file in phenofiles)

    _headers = rqtl2.read_csv_file_headers(
        phenofiles[0],
        False, # Any transposed files have been un-transposed by this point
        control_data["sep"],
        control_data["comment.char"])
    return create_new_phenotypes(
        conn,
        population_id,
        publication_id,
        (dict(zip(_headers,
                  __replace_na_strings__(line, control_data["na.strings"])))
         for filecontent
         in (rqtl2.read_csv_file(path,
                                 separator=control_data["sep"],
                                 comment_char=control_data["comment.char"])
             for path in phenofiles)
         for idx, line in enumerate(filecontent)
         if idx != 0))


def __row_to_dataitems__(
        sample_row: dict,
        dataidmap: dict,
        pheno_name2id: dict[str, int],
        samples: dict
) -> Iterator[dict]:
    samplename = sample_row["id"]
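    # Each yielded item ties one (phenotype, sample) pair to its value, e.g.
    # (values illustrative): {"phenotype_id": 25, "data_id": 8967044,
    #  "sample_name": "BXD1", "sample_id": 1, "value": 12.3}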

    return ({
        "phenotype_id": dataidmap[pheno_name2id[phenoname]]["phenotype_id"],
        "data_id": dataidmap[pheno_name2id[phenoname]]["data_id"],
        "sample_name": samplename,
        "sample_id": samples[samplename]["Id"],
        "value": phenovalue
    } for phenoname, phenovalue in sample_row.items() if phenoname != "id")


def __build_dataitems__(
        phenofiles,
        control_data,
        samples,
        dataidmap,
        pheno_name2id
):
    _headers = rqtl2.read_csv_file_headers(
        phenofiles[0],
        False, # Any transposed files have been un-transposed by this point
        control_data["sep"],
        control_data["comment.char"])
    _filescontents = (
        rqtl2.read_csv_file(path,
                            separator=control_data["sep"],
                            comment_char=control_data["comment.char"])
        for path in phenofiles)
    _linescontents = (
        __row_to_dataitems__(
            dict(zip(("id",) + _headers[1:],
                     __replace_na_strings__(line, control_data["na.strings"]))),
            dataidmap,
            pheno_name2id,
            samples)
        for linenum, line in (enumline for filecontent in _filescontents
                              for enumline in enumerate(filecontent))
        if linenum > 0)
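    # Flatten the per-line generators into a single stream of data items,
    # dropping any item whose value resolved to None (i.e. an NA string).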
    return (item for items in _linescontents
            for item in items
            if item["value"] is not None)


def save_numeric_data(# pylint: disable=[too-many-positional-arguments,too-many-arguments]
        conn: mysqldb.Connection,
        dataidmap: dict,
        pheno_name2id: dict[str, int],
        samples: tuple[dict, ...],
        control_data: dict,
        filesdir: Path,
        filetype: str,
        table: str
):
    """Read data from files and save to the database."""
    phenofiles = tuple(
        filesdir.joinpath(_file) for _file in control_data[filetype])
    if len(phenofiles) <= 0:
        return tuple()

    if control_data[f"{filetype}_transposed"]:
        logger.info("Undoing transposition of the files rows and columns.")
        phenofiles = tuple(
            rqtl2.transpose_csv_with_rename(
                _file,
                build_line_splitter(control_data),
                build_line_joiner(control_data))
            for _file in phenofiles)

    try:
        logger.debug("Attempt quick save with `LOAD … INFILE`.")
        return quick_save_phenotypes_data(
            conn,
            table,
            __build_dataitems__(
                phenofiles,
                control_data,
                samples,
                dataidmap,
                pheno_name2id),
            filesdir)
    except Exception as _exc:# pylint: disable=[broad-exception-caught]
        logger.debug("Could not use `LOAD … INFILE`, using raw query",
                     exc_info=True)
        time.sleep(60)
        return save_phenotypes_data(
            conn,
            table,
            __build_dataitems__(
                phenofiles,
                control_data,
                samples,
                dataidmap,
                pheno_name2id))


save_pheno_data = partial(save_numeric_data,
                          filetype="pheno",
                          table="PublishData")


save_phenotypes_se = partial(save_numeric_data,
                             filetype="phenose",
                             table="PublishSE")


save_phenotypes_n = partial(save_numeric_data,
                            filetype="phenonum",
                            table="NStrain")


def update_auth(# pylint: disable=[too-many-locals,too-many-positional-arguments,too-many-arguments]
        authserver,
        token,
        species,
        population,
        dataset,
        xrefdata):
    """Grant the user access to their data."""
    _tries = 0
    _delay = 1
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }
    def authserveruri(endpoint):
        return urljoin(authserver, endpoint)

    def __fetch_user_details__():
        logger.debug("… Fetching user details")
        return mrequests.get(
            authserveruri("/auth/user/"),
            headers=headers
        )

    def __link_data__(user):
        logger.debug("… linking uploaded data to user's group")
        return mrequests.post(
            authserveruri("/auth/data/link/phenotype"),
            headers=headers,
            json={
                "species_name": species["Name"],
                "group_id": user["group"]["group_id"],
                "selected": [
                    {
                        "SpeciesId": species["SpeciesId"],
                        "InbredSetId": population["Id"],
                        "PublishFreezeId": dataset["Id"],
                        "dataset_name": dataset["Name"],
                        "dataset_fullname": dataset["FullName"],
                        "dataset_shortname": dataset["ShortName"],
                        "PublishXRefId": item["xref_id"]
                    }
                    for item in xrefdata
                ],
                "using-raw-ids": "on"
            }).then(lambda ld_results: (user, ld_results))

    def __fetch_phenotype_category_details__(user, linkeddata):
        logger.debug("… fetching phenotype category details")
        return mrequests.get(
            authserveruri("/auth/resource/categories"),
            headers=headers
        ).then(
            lambda categories: (
                user,
                linkeddata,
                next(category for category in categories
                     if category["resource_category_key"] == "phenotype"))
        )

    def __create_resource__(user, linkeddata, category):
        logger.debug("… creating authorisation resource object")
        now = datetime.datetime.now().isoformat()
        return mrequests.post(
            authserveruri("/auth/resource/create"),
            headers=headers,
            json={
                "resource_category": category["resource_category_id"],
                "resource_name": (f"{user['email']}—{dataset['Name']}—{now}—"
                                  f"{len(xrefdata)} phenotypes"),
                "public": "off"
            }).then(lambda cr_results: (user, linkeddata, cr_results))

    def __attach_data_to_resource__(user, linkeddata, resource):
        logger.debug("… attaching data to authorisation resource object")
        return mrequests.post(
            authserveruri("/auth/resource/data/link"),
            headers=headers,
            json={
                "dataset_type": "phenotype",
                "resource_id": resource["resource_id"],
                "data_link_ids": [
                    item["data_link_id"] for item in linkeddata["traits"]]
            }).then(lambda attc: (user, linkeddata, resource, attc))

    def __handle_error__(resp):
        error = resp.json()
        if error.get("error") == "IntegrityError":
            # This is hacky. If the auth already exists, something went wrong
            # somewhere.
            # This needs investigation to recover correctly.
            logger.info(
                "The authorisation for the data was already set up.")
            return 0
        logger.error("ERROR: Updating the authorisation for the data failed.")
        logger.debug(
            "ERROR: The response from the authorisation server was:\n\t%s",
            error)
        return 1

    def __handle_success__(_val):
        logger.info(
            "The authorisation for the data has been updated successfully.")
        return 0

    return __fetch_user_details__().then(__link_data__).then(
        lambda result: __fetch_phenotype_category_details__(*result)
    ).then(
        lambda result: __create_resource__(*result)
    ).then(
        lambda result: __attach_data_to_resource__(*result)
    ).either(__handle_error__, __handle_success__)


def load_data(conn: mysqldb.Connection, job: dict) -> tuple:#pylint: disable=[too-many-locals]
    """Load the data attached to the given job."""
    _job_metadata = job["metadata"]
    # Steps
    # 0. Read data from the files: can be multiple files per type
    #
    _species = species_by_id(conn, int(_job_metadata["species_id"]))
    _population = population_by_species_and_id(
        conn,
        _species["SpeciesId"],
        int(_job_metadata["population_id"]))
    _dataset = dataset_by_id(
        conn,
        _species["SpeciesId"],
        _population["Id"],
        int(_job_metadata["dataset_id"]))
    # 1. Just retrieve the publication: Don't create publications for now.
    _publication = fetch_publication_by_id(
        conn, int(_job_metadata.get("publication_id", "0"))) or {"Id": 0}
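    # When no publication is provided (or found), fall back to Id 0 so the
    # cross-reference records below still carry a publication_id value.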
    # 2. Save all new phenotypes:
    #     -> return phenotype IDs
    bundle = Path(_job_metadata["bundle_file"])
    _control_data = rqtl2.control_data(bundle)
    logger.info("Extracting the zipped bundle of files.")
    _outdir = Path(bundle.parent, f"bundle_{bundle.stem}")
    with ZipFile(str(bundle), "r") as zfile:
        _files = rqtl2.extract(zfile, _outdir)
    logger.info("Saving new phenotypes.")
    _phenos = save_phenotypes(conn,
                              _control_data,
                              _population["Id"],
                              _publication["Id"],
                              _outdir)

    def __build_phenos_maps__(accumulator, row):
        return ({
            **accumulator[0],
            row["phenotype_id"]: {
                "population_id": _population["Id"],
                "phenotype_id": row["phenotype_id"],
                "data_id": row["data_id"],
                "publication_id": row["publication_id"],
            }
        }, {
            **accumulator[1],
            row["pre_publication_abbreviation"]: row["phenotype_id"]
        }, (
            accumulator[2] + ({
                "xref_id": row["xref_id"],
                "population_id": row["population_id"],
                "phenotype_id": row["phenotype_id"],
                "publication_id": row["publication_id"],
                "data_id": row["data_id"]
            },)))
    dataidmap, pheno_name2id, _xrefs = reduce(
        __build_phenos_maps__, _phenos, ({}, {}, tuple()))
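    # The reduction above yields three structures:
    #   dataidmap     : phenotype_id → {population_id, phenotype_id, data_id,
    #                   publication_id}
    #   pheno_name2id : pre-publication abbreviation → phenotype_id
    #   _xrefs        : cross-reference records carrying the new xref_id values
    #                   used for the authorisation update later on.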
    # 3. a. Fetch the strain names and IDS: create name->ID map
    samples = {
        row["Name"]: row
        for row in samples_by_species_and_population(
                conn, _species["SpeciesId"], _population["Id"])}
    #    b. Save all the data items, using the DataIds generated in step 2.
    logger.info("Saving new phenotype data.")
    _num_data_rows = save_pheno_data(conn=conn,
                                     dataidmap=dataidmap,
                                     pheno_name2id=pheno_name2id,
                                     samples=samples,
                                     control_data=_control_data,
                                     filesdir=_outdir)
    logger.info("Saved %s new phenotype data rows.", _num_data_rows)

    # 4. If standard errors and N exist, save them too
    #    (use IDs returned in `3. b.` above).
    if _control_data.get("phenose"):
        logger.info("Saving new phenotypes standard errors.")
        _num_se_rows = save_phenotypes_se(conn=conn,
                                          dataidmap=dataidmap,
                                          pheno_name2id=pheno_name2id,
                                          samples=samples,
                                          control_data=_control_data,
                                          filesdir=_outdir)
        logger.info("Saved %s new phenotype standard error rows.", _num_se_rows)

    if _control_data.get("phenonum"):
        logger.info("Saving new phenotypes sample counts.")
        _num_n_rows = save_phenotypes_n(conn=conn,
                                        dataidmap=dataidmap,
                                        pheno_name2id=pheno_name2id,
                                        samples=samples,
                                        control_data=_control_data,
                                        filesdir=_outdir)
        logger.info("Saved %s new phenotype sample counts rows.", _num_n_rows)

    return (_species, _population, _dataset, _xrefs)


def update_means(
        conn: mysqldb.Connection,
        population_id: int,
        xref_ids: tuple[int, ...]
):
    """Compute the means from the data and update them in the database."""
    logger.info("Computing means for %02d phenotypes.", len(xref_ids))
    query = (
        "UPDATE PublishXRef SET mean = "
        "(SELECT AVG(value) FROM PublishData"
        " WHERE PublishData.Id=PublishXRef.DataId) "
        "WHERE PublishXRef.Id=%(xref_id)s "
        "AND PublishXRef.InbredSetId=%(population_id)s")
    _xref_iterator = iter(xref_ids)
    with conn.cursor(cursorclass=DictCursor) as cursor:
        while True:
            batch = take(_xref_iterator, 10000)
            if len(batch) == 0:
                break
            logger.info("\tComputing means for batch of %02d phenotypes.", len(batch))
            cursor.executemany(
                query,
                tuple({
                    "population_id": population_id,
                    "xref_id": _xref_id
                } for _xref_id in batch))


if __name__ == "__main__":
    def parse_args():
        """Setup command-line arguments."""
        parser = argparse.ArgumentParser(
            prog="load_phenotypes_to_db",
            description="Process the phenotypes' data and load it into the database.")
        parser.add_argument("db_uri", type=str, help="MariaDB/MySQL connection URL")
        parser.add_argument(
            "jobs_db_path", type=Path, help="Path to jobs' SQLite database.")
        parser.add_argument("job_id", type=uuid.UUID, help="ID of the running job")
        parser.add_argument(
            "--log-level",
            type=str,
            help="Determines what is logged out.",
            choices=("debug", "info", "warning", "error", "critical"),
            default="info")
        return parser.parse_args()

    def setup_logging(log_level: str):
        """Setup logging for the script."""
        logger.setLevel(log_level)
        logging.getLogger("uploader.phenotypes.models").setLevel(log_level)


    def main():
        """Entry-point for this script."""
        args = parse_args()
        setup_logging(args.log_level.upper())

        with (mysqldb.database_connection(args.db_uri) as conn,
              conn.cursor(cursorclass=DictCursor) as cursor,
              sqlite3.connection(args.jobs_db_path) as jobs_conn):
            job = jobs.job(jobs_conn, args.job_id)

            # Lock the PublishXRef/PublishData/PublishSE/NStrain tables. Why?
            #     The `DataId` values are sequential, but not auto-increment,
            #     and `PublishXRef`.`DataId` cannot be converted to an
            #     AUTO_INCREMENT column; new IDs are derived from
            #     `SELECT MAX(DataId) FROM PublishXRef;`, so the tables are
            #     locked to avoid race conditions.
            #     To check for a table lock, see:
            #     https://oracle-base.com/articles/mysql/mysql-identify-locked-tables
            #     `SHOW OPEN TABLES LIKE 'Publish%';`
            _db_tables_ = (
                "Species",
                "InbredSet",
                "Strain",
                "StrainXRef",
                "Publication",
                "Phenotype",
                "PublishXRef",
                "PublishFreeze",
                "PublishData",
                "PublishSE",
                "NStrain")

            logger.debug(
                ("Locking database tables for the connection:" +
                 "".join("\n\t- %s" for _ in _db_tables_) + "\n"),
                *_db_tables_)
            cursor.execute(# Lock the tables to avoid race conditions
                "LOCK TABLES " + ", ".join(
                    f"{_table} WRITE" for _table in _db_tables_))

            db_results = load_data(conn, job)
            _xref_ids = tuple(xref["xref_id"] for xref in db_results[3])
            jobs.update_metadata(
                jobs_conn,
                args.job_id,
                "xref_ids",
                json.dumps(_xref_ids))

            logger.info("Unlocking all database tables.")
            cursor.execute("UNLOCK TABLES")

            logger.info("Updating means.")
            update_means(conn, db_results[1]["Id"], _xref_ids)

        # Update authorisations. TODO: break this down; possibly retry until it succeeds.
        logger.info("Updating authorisation.")
        _job_metadata = job["metadata"]
        return update_auth(_job_metadata["authserver"],
                           _job_metadata["token"],
                           *db_results)


    try:
        sys.exit(main())
    except Exception as _exc:# pylint: disable=[broad-exception-caught]
        logger.debug("Data loading failed… Halting!",
                     exc_info=True)
        sys.exit(1)