"""Module containing integration code for rust-gn3."""
import json
from functools import reduce

from gn2.utility.tools import SQL_URI
from gn2.utility.db_tools import mescape
from gn2.utility.db_tools import create_in_clause
from gn2.wqflask.correlation.correlation_functions\
    import get_trait_symbol_and_tissue_values
from gn2.wqflask.correlation.correlation_gn3_api import create_target_this_trait
from gn2.wqflask.correlation.correlation_gn3_api import lit_for_trait_list
from gn2.wqflask.correlation.correlation_gn3_api import do_lit_correlation
from gn2.wqflask.correlation.pre_computes import fetch_text_file
from gn2.wqflask.correlation.pre_computes import read_text_file
from gn2.wqflask.correlation.pre_computes import write_db_to_textfile
from gn2.wqflask.correlation.pre_computes import read_trait_metadata
from gn2.wqflask.correlation.pre_computes import cache_trait_metadata
from gn3.computations.correlations import compute_all_lit_correlation
from gn3.computations.rust_correlation import run_correlation
from gn3.computations.rust_correlation import get_sample_corr_data
from gn3.computations.rust_correlation import parse_tissue_corr_data
from gn3.db_utils import database_connection

from gn2.wqflask.correlation.exceptions import WrongCorrelationType


def query_probes_metadata(dataset, trait_list):
    """Fetch the metadata rows for a batch of ProbeSet traits.

    Returns an empty list when `trait_list` is empty or when
    `dataset.type` is not "ProbeSet".  Otherwise a single query is run,
    filtered by the trait names, `dataset.group.species` and
    `dataset.name`, and the `cursor.fetchall()` result is returned.  Each
    row holds, in order: probeset name, chr, Mb, symbol, mean,
    description, additive, LRS, locus chr and locus Mb.
    """
    if not trait_list or dataset.type != "ProbeSet":
        return []

    with database_connection(SQL_URI) as conn, conn.cursor() as cursor:
        # One "%s" placeholder per trait name; the actual values are
        # bound by the driver when the query is executed.
        placeholders = ", ".join(["%s"] * len(trait_list))
        query_template = """
                    SELECT ProbeSet.Name,ProbeSet.Chr,ProbeSet.Mb,
                    ProbeSet.Symbol,ProbeSetXRef.mean,
                    CONCAT_WS('; ', ProbeSet.description, ProbeSet.Probe_Target_Description) AS description,
                    ProbeSetXRef.additive,ProbeSetXRef.LRS,Geno.Chr, Geno.Mb
                    FROM ProbeSet INNER JOIN ProbeSetXRef
                    ON ProbeSet.Id=ProbeSetXRef.ProbeSetId
                    INNER JOIN Geno
                    ON ProbeSetXRef.Locus = Geno.Name
                    INNER JOIN Species
                    ON Geno.SpeciesId = Species.Id
                    WHERE ProbeSet.Name in ({}) AND
                    Species.Name = %s AND
                    ProbeSetXRef.ProbeSetFreezeId IN (
                      SELECT ProbeSetFreeze.Id
                      FROM ProbeSetFreeze WHERE ProbeSetFreeze.Name = %s)
                """
        # Parameter order mirrors the placeholders: trait names first,
        # then the species name, then the dataset name.
        bound_values = (*trait_list, dataset.group.species, dataset.name)
        cursor.execute(query_template.format(placeholders), bound_values)
        return cursor.fetchall()


def get_metadata(dataset, traits):
    """Return trait metadata, fetching and caching whatever is missing.

    The metadata cached for `dataset.name` is read first.  If every
    entry of `traits` is already cached, the cached mapping is returned
    unchanged.  Otherwise the missing traits are looked up with
    `query_probes_metadata`, combined with the cached entries, written
    back through `cache_trait_metadata` and returned.
    """
    def __format_location__(chromosome, megabases):
        # A falsy Mb value is rendered as "???" instead of a number.
        return (f"Chr{chromosome}: {megabases:.6f}" if megabases
                else f"Chr{chromosome}: ???")

    cached_metadata = read_trait_metadata(dataset.name)
    missing_traits = list(
        set(traits).difference(list(cached_metadata.keys())))
    if not missing_traits:
        return cached_metadata

    results = {}
    for (trait_name, probe_chr, probe_mb, symbol, mean, description,
         additive, lrs, chr_score, mb) in query_probes_metadata(
             dataset, missing_traits):
        results[trait_name] = {
            "name": trait_name,
            "view": True,
            "symbol": symbol,
            "dataset": dataset.name,
            "dataset_name": dataset.shortname,
            "mean": mean,
            "description": description,
            "additive": additive,
            "lrs_score": f"{lrs:3.1f}" if lrs else "",
            "location": __format_location__(probe_chr, probe_mb),
            "chr": probe_chr,
            "mb": probe_mb,
            # A truthy Mb is shown with six decimals; a falsy one goes
            # through the default formatting.
            "lrs_location":
                f"Chr{chr_score}: {format(mb, '.6f' if mb else '')}",
            "lrs_chr": chr_score,
            "lrs_mb": mb
        }

    # Cached entries are layered on top of the freshly fetched ones.
    results.update(cached_metadata)
    cache_trait_metadata(dataset.name, results)
    return results


def chunk_dataset(dataset, steps, name):
    """Split flat data rows into one `[trait name, value, ...]` row per chunk.

    Args:
        dataset: sequence of 3-item rows.  The first item of the first
            row of every chunk is looked up as a `ProbeSetXRef.DataId`;
            the third item of every row is the value.
        steps: number of consecutive rows that make up one chunk.
        name: `ProbeSetFreeze.Name` used to build the
            DataId -> `ProbeSet.Name` lookup.

    Returns:
        A list with one entry per chunk: the trait name followed by the
        chunk's values converted to `str`.  Empty when `dataset` is empty.

    Raises:
        KeyError: if a chunk's leading DataId is not part of the dataset
            called `name`.
    """
    # Nothing to chunk: skip the database round-trip entirely (this also
    # avoids `range()` rejecting a zero step when there are no samples).
    if not dataset:
        return []

    # The dataset name is bound as a parameter instead of being spliced
    # into the SQL text.
    query = """
            SELECT ProbeSetXRef.DataId,ProbeSet.Name
            FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze
            WHERE ProbeSetFreeze.Name = %s AND
                  ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND
                  ProbeSetXRef.ProbeSetId = ProbeSet.Id
    """

    with database_connection(SQL_URI) as conn:
        with conn.cursor() as curr:
            curr.execute(query, (name,))
            traits_name_dict = dict(curr.fetchall())

    # NOTE(review): slicing by a fixed size assumes the rows arrive
    # grouped per trait with exactly `steps` rows each -- confirm
    # against the query that produces `dataset`.
    results = []
    for start in range(0, len(dataset), steps):
        chunk = list(dataset[start:start + steps])
        results.append(
            [traits_name_dict[chunk[0][0]]]
            + [str(value) for (_, _, value) in chunk])
    return results


def compute_top_n_sample(start_vars, dataset, trait_list):
    """Compute pearson sample correlations for the traits in `trait_list`.

    Args:
        start_vars: request variables; reads "sample_vals" (a JSON
            string) and "corr_samples_group".
        dataset: the dataset the traits belong to.
        trait_list: names of the traits to correlate against.

    Returns:
        The result of `run_correlation`, or `{}` when `dataset` is not a
        ProbeSet dataset, `trait_list` is empty, there is no sample data,
        or none of the samples is found in the `Strain` table.
    """
    # Cheap guards first so no JSON parsing or database work is done
    # when there is nothing to correlate.
    if dataset.type.lower() != "probeset" or not trait_list:
        return {}

    sample_data = get_sample_corr_data(
        sample_type=start_vars["corr_samples_group"],
        sample_data=json.loads(start_vars["sample_vals"]),
        dataset_samples=dataset.group.all_samples_ordered())
    if not sample_data:
        # An empty IN (...) list would be invalid SQL.
        return {}

    sample_names = tuple(sample_data.keys())

    with database_connection(SQL_URI) as conn:
        with conn.cursor() as curr:
            # Resolve the sample names to strain ids; every value is
            # bound as a parameter.
            curr.execute(
                "SELECT Strain.Name, Strain.Id FROM Strain, Species "
                "WHERE Strain.Name IN "
                f"({', '.join(['%s'] * len(sample_names))}) "
                "AND Strain.SpeciesId = Species.Id "
                "AND Species.name = %s",
                sample_names + (dataset.group.species,))
            sample_ids = dict(curr.fetchall())
            if not sample_ids:
                return {}

            # fetching strain data in bulk
            query = (
                "SELECT * from ProbeSetData "
                f"WHERE StrainID IN ({', '.join(['%s'] * len(sample_ids))}) "
                "AND Id IN ("
                "  SELECT ProbeSetXRef.DataId "
                "  FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) "
                "  WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id "
                "  AND ProbeSetFreeze.Name = %s "
                "  AND ProbeSet.Name "
                f" IN ({', '.join(['%s'] * len(trait_list))}) "
                "  AND ProbeSet.Id = ProbeSetXRef.ProbeSetId"
                ")")
            curr.execute(
                query,
                tuple(sample_ids.values()) + (dataset.name,)
                + tuple(trait_list))
            rows = list(curr.fetchall())

    # The connection is released before the (separate) name lookup done
    # by `chunk_dataset` and before the correlation itself runs.
    corr_data = chunk_dataset(rows, len(sample_ids), dataset.name)

    # NOTE(review): the per-trait value rows and `sample_data.values()`
    # are assumed to line up sample for sample; nothing here enforces
    # that ordering -- confirm.
    return run_correlation(
        corr_data, list(sample_data.values()), "pearson", ",")


def compute_top_n_lit(corr_results, target_dataset, this_trait) -> dict:
    """Compute literature correlations for traits already in `corr_results`.

    Args:
        corr_results: mapping keyed by trait name; only traits with a
            truthy entry here are correlated.
        target_dataset: the dataset being correlated against.
        this_trait: the primary trait.

    Returns:
        A single dict merging every per-trait result produced by
        `compute_all_lit_correlation`, or `{}` when the datasets are not
        compatible with literature correlation.
    """
    if not __datasets_compatible_p__(this_trait.dataset, target_dataset, "lit"):
        return {}

    (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
        this_trait, target_dataset)

    trait_lists = [
        (trait_name, geneid)
        for (trait_name, geneid) in geneid_dict.items()
        if corr_results.get(trait_name)]

    # Accumulate in place: re-spreading the accumulator for every item
    # copies it each time and is quadratic in the number of results.
    merged = {}
    with database_connection(SQL_URI) as conn:
        for corr in compute_all_lit_correlation(
                conn=conn, trait_lists=trait_lists,
                species=species, gene_id=this_trait_geneid):
            merged.update(corr)
    return merged


def compute_top_n_tissue(target_dataset, this_trait, traits, method):
    """Compute tissue correlations restricted to the traits in `traits`.

    Only target traits with a truthy entry in `traits` are considered.
    Returns the result of `run_correlation`, or `{}` when the datasets
    are not compatible with tissue correlation or when no tissue values
    are available for the primary trait.
    """
    compatible = __datasets_compatible_p__(
        this_trait.dataset, target_dataset, "tissue")
    if not compatible:
        return {}

    # Keep the gene symbol of every target trait that is also present
    # (with a truthy value) in `traits`.
    symbols_by_trait = {}
    for trait_name, symbol in target_dataset.retrieve_genes("Symbol").items():
        if traits.get(trait_name):
            symbols_by_trait[trait_name] = symbol

    target_tissue_vals = get_trait_symbol_and_tissue_values(
        symbol_list=list(symbols_by_trait.values()))
    primary_tissue_vals = get_trait_symbol_and_tissue_values(
        symbol_list=[this_trait.symbol])

    parsed = parse_tissue_corr_data(
        symbol_name=this_trait.symbol,
        symbol_dict=primary_tissue_vals,
        dataset_symbols=symbols_by_trait,
        dataset_vals=target_tissue_vals)

    if not parsed or not parsed[0]:
        return {}

    # parsed[0] holds the primary trait's values, parsed[1] the rows to
    # correlate them against.
    return run_correlation(parsed[1], parsed[0], method, ",", "tissue")


def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
    """Merge the different correlation results trait by trait.

    For every trait in `dict_a` a single-key dict `{trait_name: values}`
    is produced, where `values` starts from the trait's entry in
    `dict_a` and is then overlaid with its entries from `dict_b` and
    `dict_c` (later ones win).  Traits found only in `dict_b` or
    `dict_c` are left out.
    """
    merged = []
    for trait_name, base_corrs in dict_a.items():
        combined = dict(base_corrs)
        combined.update(dict_b.get(trait_name, {}))
        combined.update(dict_c.get(trait_name, {}))
        merged.append({trait_name: combined})
    return merged


def __compute_sample_corr__(
        start_vars: dict, corr_type: str, method: str, n_top: int,
        target_trait_info: tuple):
    """Compute the sample correlations.

    Args:
        start_vars: request variables; reads "corr_samples_group",
            "sample_vals" (a JSON string) and, optionally, "use_cache".
        corr_type: forwarded to `run_correlation`.
        method: forwarded to `run_correlation`.
        n_top: forwarded to `run_correlation`.
        target_trait_info: 4-tuple `(this_dataset, this_trait,
            target_dataset, sample_data)`; only the two datasets are
            used here.

    Returns:
        The result of `run_correlation`, or `{}` when there is no sample
        data or no target trait data to correlate against.
    """
    (this_dataset, _this_trait, target_dataset, _sample_data) = (
        target_trait_info)

    # NOTE(review): `+=` extends `group.samplelist` in place, so calling
    # this repeatedly on the same group object would append the F1 and
    # parent samples again -- confirm whether group objects are shared.
    if this_dataset.group.f1list is not None:
        this_dataset.group.samplelist += this_dataset.group.f1list

    if this_dataset.group.parlist is not None:
        this_dataset.group.samplelist += this_dataset.group.parlist

    sample_data = get_sample_corr_data(
        sample_type=start_vars["corr_samples_group"],
        sample_data=json.loads(start_vars["sample_vals"]),
        dataset_samples=this_dataset.group.all_samples_ordered())

    if not sample_data:
        return {}

    if target_dataset.type == "ProbeSet" and start_vars.get("use_cache") == "true":
        with database_connection(SQL_URI) as conn:
            file_path = fetch_text_file(target_dataset.name, conn)
            if not file_path:
                # No cached text file yet: generate it and look again.
                write_db_to_textfile(target_dataset.name, conn)
                file_path = fetch_text_file(target_dataset.name, conn)
            if file_path:
                (sample_vals, target_data) = read_text_file(
                    sample_data, file_path)

                return run_correlation(target_data, sample_vals,
                                       method, ",", corr_type, n_top)
        # Still no file: fall through to reading from the database.

    target_dataset.get_trait_data(list(sample_data.keys()))

    # One `[trait name, value, ...]` row per trait that has any values.
    # Built with append rather than repeated list concatenation, which
    # is quadratic in the number of traits.
    target_data = []
    for trait_name, trait_values in target_dataset.trait_data.items():
        values = list(trait_values)
        if values:
            target_data.append([trait_name] + values)

    if not target_data:
        return {}

    return run_correlation(
        target_data, list(sample_data.values()), method, ",", corr_type,
        n_top)


def __datasets_compatible_p__(trait_dataset, target_dataset, corr_method):
    """Tell whether the two datasets can be correlated with `corr_method`.

    Only one combination is rejected: a tissue or literature correlation
    from a "ProbeSet" trait dataset against a "Publish" or "Geno" target
    dataset.  Everything else is reported as compatible.
    """
    if corr_method not in ("tissue", "Tissue r", "Literature r", "lit"):
        return True
    if trait_dataset.type != "ProbeSet":
        return True
    return target_dataset.type not in ("Publish", "Geno")


def __compute_tissue_corr__(
        start_vars: dict, corr_type: str, method: str, n_top: int,
        target_trait_info: tuple):
    """Compute the tissue correlations.

    Args:
        start_vars: request variables (not used by this correlation).
        corr_type: correlation type, used for the compatibility check.
        method: forwarded to `run_correlation`.
        n_top: number of results requested, forwarded to
            `run_correlation`.
        target_trait_info: 4-tuple `(this_dataset, this_trait,
            target_dataset, sample_data)`.

    Returns:
        The result of `run_correlation`, or `{}` when
        `parse_tissue_corr_data` yields nothing.

    Raises:
        WrongCorrelationType: if the datasets cannot be used for
            `corr_type`.
    """
    (this_dataset, this_trait, target_dataset, _sample_data) = (
        target_trait_info)
    if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type):
        raise WrongCorrelationType(this_trait, target_dataset, corr_type)

    trait_symbol_dict = target_dataset.retrieve_genes("Symbol")
    corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
        symbol_list=list(trait_symbol_dict.values()))

    data = parse_tissue_corr_data(
        symbol_name=this_trait.symbol,
        symbol_dict=get_trait_symbol_and_tissue_values(
            symbol_list=[this_trait.symbol]),
        dataset_symbols=trait_symbol_dict,
        dataset_vals=corr_result_tissue_vals_dict)

    if data:
        # Forward n_top like the sample correlation does; it used to be
        # dropped here, so the requested result count was ignored.
        return run_correlation(
            data[1], data[0], method, ",", "tissue", n_top)
    return {}


def __compute_lit_corr__(
        start_vars: dict, corr_type: str, method: str, n_top: int,
        target_trait_info: tuple):
    """Compute the literature correlations.

    Args:
        start_vars: request variables (not used by this correlation).
        corr_type: correlation type, used for the compatibility check.
        method: not used by this correlation.
        n_top: maximum number of per-trait results that are merged.
        target_trait_info: 4-tuple `(this_dataset, this_trait,
            target_dataset, sample_data)`.

    Returns:
        A single dict merging the first `n_top` results produced by
        `compute_all_lit_correlation`.

    Raises:
        WrongCorrelationType: if the datasets cannot be used for
            `corr_type`.
    """
    (this_dataset, this_trait, target_dataset, _sample_data) = (
        target_trait_info)
    if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type):
        raise WrongCorrelationType(this_trait, target_dataset, corr_type)

    (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
        this_trait, target_dataset)

    # Accumulate in place: re-spreading the accumulator for every item
    # copies it each time and is quadratic in the number of results.
    merged = {}
    with database_connection(SQL_URI) as conn:
        for lit in compute_all_lit_correlation(
                conn=conn, trait_lists=list(geneid_dict.items()),
                species=species, gene_id=this_trait_geneid)[:n_top]:
            merged.update(lit)
    return merged


def compute_correlation_rust(
        start_vars: dict, corr_type: str, method: str = "pearson",
        n_top: int = 500, should_compute_all: bool = False):
    """Run a correlation and package the results with trait metadata.

    Args:
        start_vars: request variables; handed to
            `create_target_this_trait` and the per-type compute
            functions, and read for "corr_dataset".
        corr_type: one of "sample", "tissue" or "lit".
        method: correlation method handed to the compute functions.
        n_top: number of results requested.
        should_compute_all: when true, the other correlation types are
            also computed for the traits in the primary results.

    Returns:
        A dict with the keys "correlation_results", "this_trait",
        "target_dataset", "traits_metadata" and "return_results".

    Raises:
        WrongCorrelationType: if the datasets cannot be used for
            `corr_type`.
        KeyError: if `corr_type` is not one of the supported types.
    """
    target_trait_info = create_target_this_trait(start_vars)
    this_dataset, this_trait, target_dataset, _ = target_trait_info
    if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type):
        raise WrongCorrelationType(this_trait, target_dataset, corr_type)

    # Dispatch table; could become `match ...` once on Python 3.10.
    compute_fn = {
        "sample": __compute_sample_corr__,
        "tissue": __compute_tissue_corr__,
        "lit": __compute_lit_corr__
    }[corr_type]
    results = compute_fn(
        start_vars, corr_type, method, n_top, target_trait_info)

    extra_a, extra_b = {}, {}
    if should_compute_all and corr_type == "sample":
        # The extra correlations are only computed for ProbeSet datasets.
        if this_dataset.type == "ProbeSet":
            extra_a = compute_top_n_tissue(
                target_dataset, this_trait, results, method)
            extra_b = compute_top_n_lit(results, target_dataset, this_trait)
    elif should_compute_all and corr_type == "lit":
        # currently fails for lit
        extra_a = compute_top_n_sample(
            start_vars, target_dataset, list(results.keys()))
        extra_b = compute_top_n_tissue(
            target_dataset, this_trait, results, method)
    elif should_compute_all:
        extra_a = compute_top_n_sample(
            start_vars, target_dataset, list(results.keys()))

    merged_results = merge_results(results, extra_a, extra_b)
    return {
        "correlation_results": merged_results,
        "this_trait": this_trait.name,
        "target_dataset": start_vars['corr_dataset'],
        "traits_metadata": get_metadata(target_dataset, list(results.keys())),
        "return_results": n_top
    }