"""Module that calls the GN3 APIs to compute correlations."""
import json
import time
from functools import wraps

from gn2.utility.tools import SQL_URI

from gn2.wqflask.correlation import correlation_functions
from gn2.base import data_set

from gn2.base.trait import create_trait
from gn2.base.trait import retrieve_sample_data

from gn3.db_utils import database_connection
from gn3.commands import run_sample_corr_cmd
from gn3.computations.correlations import map_shared_keys_to_values
from gn3.computations.correlations import compute_all_lit_correlation
from gn3.computations.correlations import compute_tissue_correlation
from gn3.computations.correlations import fast_compute_all_sample_correlation


def create_target_this_trait(start_vars):
    """Build the primary dataset, primary trait and target dataset.

    Reads ``dataset``, ``corr_dataset`` and ``trait_id`` from ``start_vars``
    (plus ``group`` when the primary dataset is ``"Temp"``).

    Returns a 4-tuple ``(this_dataset, this_trait, target_dataset,
    sample_data)``; ``sample_data`` is always an empty tuple here.
    """
    primary_name = start_vars['dataset']

    # "Temp" datasets are created with an explicit type and group name.
    if primary_name == "Temp":
        primary_dataset = data_set.create_dataset(
            dataset_name="Temp",
            dataset_type="Temp",
            group_name=start_vars['group'])
    else:
        primary_dataset = data_set.create_dataset(dataset_name=primary_name)

    correlated_dataset = data_set.create_dataset(
        dataset_name=start_vars['corr_dataset'])
    primary_trait = create_trait(
        dataset=primary_dataset, name=start_vars['trait_id'])

    return (primary_dataset, primary_trait, correlated_dataset, ())


def test_process_data(this_trait, dataset, start_vars):
    """Collect sample values for the sample group chosen in ``start_vars``.

    ``start_vars["corr_samples_group"]`` selects what is returned:

    * ``"samples_primary"`` -- only the group's primary samples
      (``samplelist`` plus ``parlist`` and ``f1list`` when they are set);
    * ``"samples_other"`` -- samples of ``this_trait.data`` that are not in
      the group's ``samplelist`` (parents and F1s are not excluded);
    * anything else -- both of the above, merged into one dict.

    Returns the ``{sample_name: float}`` dict built by ``process_samples``.

    Changes from the previous implementation:

    * ``dataset.group.samplelist`` is copied instead of being extended in
      place, so repeated calls no longer grow the group's own list;
    * for the "all samples" case the primary samples are kept instead of
      being overwritten by the second ``process_samples`` call;
    * a ``None`` ``parlist``/``f1list`` no longer raises ``TypeError`` in the
      ``"samples_other"`` branch.
    """
    corr_samples_group = start_vars["corr_samples_group"]
    group = dataset.group

    # Parents and F1s are optional; either attribute may be None.
    parents_and_f1s = []
    for members in (group.parlist, group.f1list):
        if members is not None:
            parents_and_f1s.extend(members)

    # Work on a copy so the group's own sample list is left untouched.
    primary_samples = list(group.samplelist) + parents_and_f1s

    sample_data = {}

    # If either BXD/whatever Only or All Samples, add all of that group's
    # samplelist.
    if corr_samples_group != 'samples_other':
        sample_data.update(process_samples(start_vars, primary_samples))

    # If either Non-BXD/whatever or All Samples, take all samples from
    # this_trait.data and exclude the primary samples (they were already
    # added above when the user selected All Samples).
    if corr_samples_group != 'samples_primary':
        if corr_samples_group == 'samples_other':
            # Parents/F1s are not part of the exclusion list in this mode.
            extras = set(parents_and_f1s)
            primary_samples = [
                sample for sample in primary_samples if sample not in extras]
        sample_data.update(process_samples(
            start_vars, list(this_trait.data.keys()), primary_samples))

    return sample_data


def process_samples(start_vars, sample_names=None, excluded_samples=None):
    """Parse the submitted sample values into a ``{sample: float}`` dict.

    Arguments:
    start_vars -- form input; ``start_vars["sample_vals"]`` must be a JSON
        object mapping sample names to values
    sample_names -- samples to consider, in output order; when empty or
        ``None`` every sample present in ``sample_vals`` is considered
    excluded_samples -- sample names to leave out (must be hashable)

    Samples that are missing from ``sample_vals``, are excluded, or whose
    value is the placeholder ``"x"`` (any case, surrounding whitespace
    ignored) are skipped.  Values may be numeric strings or JSON numbers.

    Raises ``ValueError`` if a remaining value cannot be converted to float.
    """
    sample_vals_dict = json.loads(start_vars["sample_vals"])
    # Build the exclusion set once so the membership test is O(1) per sample.
    excluded = frozenset(excluded_samples or ())
    # No explicit names means "every submitted sample".
    candidates = sample_names if sample_names else sample_vals_dict

    sample_data = {}
    for sample in candidates:
        if sample not in sample_vals_dict or sample in excluded:
            continue
        val = sample_vals_dict[sample]
        if isinstance(val, str):
            val = val.strip()
            if val.lower() == "x":
                continue
        sample_data[str(sample)] = float(val)
    return sample_data


def merge_correlation_results(correlation_results, target_correlation_results):
    """Fold ``target_correlation_results`` into ``correlation_results``.

    Both arguments are iterables of ``{trait_name: values}`` dicts.  For every
    trait in ``correlation_results`` that also has a non-empty entry in
    ``target_correlation_results``, the trait's values dict is updated in
    place with the target values.  When a trait appears several times in
    ``target_correlation_results`` the last entry wins.

    Returns the same ``correlation_results`` object that was passed in.
    """
    extra_by_trait = {}
    for entry in target_correlation_results:
        extra_by_trait.update(entry.items())

    for entry in correlation_results:
        for name in entry:
            extra = extra_by_trait.get(name)
            if extra:
                entry[name].update(extra)

    return correlation_results


def sample_for_trait_lists(corr_results, target_dataset,
                           this_trait, this_dataset, start_vars):
    """Run a pearson sample correlation for the top-results step.

    The inputs are prepared by ``fetch_sample_data`` and handed to
    ``run_sample_corr_cmd``, whose result is returned unchanged.
    ``corr_results`` is accepted but not used by this function.
    """
    primary_payload, target_payload = fetch_sample_data(
        start_vars, this_trait, this_dataset, target_dataset)

    return run_sample_corr_cmd(
        corr_method="pearson",
        this_trait=primary_payload,
        target_dataset=target_payload)


def tissue_for_trait_lists(corr_results, this_dataset, this_trait):
    """Compute pearson tissue correlations for the traits in ``corr_results``.

    Each item of ``corr_results`` contributes its first key as a trait name.
    The dataset's ``"Symbol"`` mapping is restricted to those traits and
    passed to ``get_tissue_correlation_input``.

    Returns the result of ``compute_tissue_correlation``, or ``None`` when no
    tissue input is available.
    """
    wanted_traits = {list(corr_result)[0] for corr_result in corr_results}

    all_symbols = this_dataset.retrieve_genes("Symbol")
    selected_symbols = {
        name: symbol
        for name, symbol in all_symbols.items()
        if name in wanted_traits
    }

    tissue_input = get_tissue_correlation_input(this_trait, selected_symbols)
    if tissue_input is None:
        return None

    primary_tissue_data, target_tissue_data = tissue_input
    return compute_tissue_correlation(
        primary_tissue_dict=primary_tissue_data,
        target_tissues_data=target_tissue_data,
        corr_method="pearson")


def lit_for_trait_list(corr_results, this_dataset, this_trait):
    """Compute literature correlations for the traits in ``corr_results``.

    Each item of ``corr_results`` contributes its first key as a trait name.
    The gene-id mapping from ``do_lit_correlation`` is restricted to those
    traits and passed, as a list of ``(trait_name, gene_id)`` pairs, to
    ``compute_all_lit_correlation`` over a connection opened on ``SQL_URI``.

    Returns whatever ``compute_all_lit_correlation`` returns.
    """
    primary_geneid, all_geneids, species = do_lit_correlation(
        this_trait, this_dataset)

    wanted_traits = {list(corr_result)[0] for corr_result in corr_results}
    selected_pairs = [
        (name, geneid)
        for name, geneid in all_geneids.items()
        if name in wanted_traits
    ]

    with database_connection(SQL_URI) as conn:
        correlation_results = compute_all_lit_correlation(
            conn=conn,
            trait_lists=selected_pairs,
            species=species,
            gene_id=primary_geneid)

    return correlation_results


def fetch_sample_data(start_vars, this_trait, this_dataset, target_dataset):
    """Prepare the inputs for a sample correlation.

    ``start_vars["corr_samples_group"]`` selects the samples:
    ``"samples_primary"`` restricts to the group's ``samplelist``,
    ``"samples_other"`` excludes that list, and any other value uses
    ``all_samples_ordered()``.

    Returns ``(this_trait_data, results)`` where ``this_trait_data`` holds the
    selected sample values and the trait id, and ``results`` is the output of
    ``map_shared_keys_to_values`` for the target dataset.
    """
    selection = start_vars["corr_samples_group"]
    group = this_dataset.group

    if selection == "samples_primary":
        sample_data = process_samples(start_vars, group.samplelist)
    elif selection == "samples_other":
        sample_data = process_samples(
            start_vars, excluded_samples=group.samplelist)
    else:
        sample_data = process_samples(
            start_vars, group.all_samples_ordered())

    # Load the target dataset's trait data for the selected samples.
    target_dataset.get_trait_data(list(sample_data.keys()))
    # The return value is not used here; the call is kept for whatever
    # effect it has on the trait -- NOTE(review): confirm it is still needed.
    retrieve_sample_data(this_trait, this_dataset)

    this_trait_data = {
        "trait_sample_data": sample_data,
        "trait_id": start_vars["trait_id"],
    }
    target_values = map_shared_keys_to_values(
        target_dataset.samplelist, target_dataset.trait_data)

    return (this_trait_data, target_values)


def compute_correlation(start_vars, method="pearson", compute_all=False):
    """Compute correlations by delegating to ``compute_correlation_rust``.

    Keyword arguments:
    start_vars -- All input from form; includes things like the trait/dataset
        names.  ``corr_type`` is required; ``corr_sample_method`` and
        ``corr_return_results`` are optional.
    method -- Correlation method to be used (pearson, spearman, or bicor).
        Only used when ``start_vars`` has no ``corr_sample_method`` entry;
        the form value takes precedence.  (Previously this argument was
        always discarded and a missing form entry raised ``KeyError``.)
    compute_all -- Include sample, tissue, and literature correlations (when
        applicable)

    Returns the result of ``compute_correlation_rust``.  The number of
    results requested defaults to 100.
    """
    # Imported at call time rather than at module level, as in the original.
    from gn2.wqflask.correlation.rust_correlation import compute_correlation_rust

    corr_type = start_vars['corr_type']
    method = start_vars.get('corr_sample_method', method)
    corr_return_results = int(start_vars.get("corr_return_results", 100))
    return compute_correlation_rust(
        start_vars, corr_type, method, corr_return_results, compute_all)


def compute_corr_for_top_results(start_vars,
                                 correlation_results,
                                 this_trait,
                                 this_dataset,
                                 target_dataset,
                                 corr_type):
    """Add the other correlation types to the top results.

    Nothing is computed unless both datasets have type ``"ProbeSet"``.
    Otherwise the tissue, literature and sample correlations are computed, in
    that order, skipping the one named by ``corr_type``, and each non-empty
    result is merged into ``correlation_results``.

    Returns the (possibly updated) ``correlation_results``.
    """
    both_probeset = (this_dataset.type == "ProbeSet"
                     and target_dataset.type == "ProbeSet")
    if not both_probeset:
        return correlation_results

    if corr_type != "tissue":
        tissue_result = tissue_for_trait_lists(
            correlation_results, this_dataset, this_trait)
        if tissue_result:
            correlation_results = merge_correlation_results(
                correlation_results, tissue_result)

    if corr_type != "lit":
        lit_result = lit_for_trait_list(
            correlation_results, this_dataset, this_trait)
        if lit_result:
            correlation_results = merge_correlation_results(
                correlation_results, lit_result)

    if corr_type != "sample":
        sample_result = sample_for_trait_lists(
            correlation_results, target_dataset, this_trait, this_dataset,
            start_vars)
        if sample_result:
            correlation_results = merge_correlation_results(
                correlation_results, sample_result)

    return correlation_results


def do_lit_correlation(this_trait, this_dataset):
    """Gather the inputs needed for a literature correlation.

    Returns ``(trait_geneid, geneid_dict, species)``: the primary trait's
    ``geneid``, the dataset's ``"GeneId"`` mapping, and the group's species
    lower-cased (a falsy species is returned unchanged).
    """
    gene_ids = this_dataset.retrieve_genes("GeneId")
    raw_species = this_dataset.group.species
    species = raw_species.lower() if raw_species else raw_species
    return (this_trait.geneid, gene_ids, species)


def get_tissue_correlation_input(this_trait, trait_symbol_dict):
    """Gather tissue values for the primary trait and for the target traits.

    Returns ``(primary_tissue_data, target_tissue_data)`` when the primary
    trait has a symbol whose lower-cased form is present in the tissue lookup;
    otherwise returns ``None``.
    """
    symbol = this_trait.symbol
    primary_lookup = correlation_functions.get_trait_symbol_and_tissue_values(
        symbol_list=[symbol])

    # No symbol, or no tissue values for it: nothing to correlate against.
    if not symbol or symbol.lower() not in primary_lookup:
        return None

    primary_values = primary_lookup[symbol.lower()]
    target_lookup = correlation_functions.get_trait_symbol_and_tissue_values(
        symbol_list=list(trait_symbol_dict.values()))

    primary_tissue_data = {
        "this_id": this_trait.name,
        "tissue_values": primary_values,
    }
    target_tissue_data = {
        "trait_symbol_dict": trait_symbol_dict,
        "symbol_tissue_vals_dict": target_lookup,
    }
    return (primary_tissue_data, target_tissue_data)