"""Module that calls the GN3 APIs to do the correlation."""
import json
import time
from functools import wraps

from gn2.utility.tools import SQL_URI
from gn2.wqflask.correlation import correlation_functions
from gn2.base import data_set

from gn2.base.trait import create_trait
from gn2.base.trait import retrieve_sample_data

from gn3.db_utils import database_connection
from gn3.commands import run_sample_corr_cmd
from gn3.computations.correlations import map_shared_keys_to_values
from gn3.computations.correlations import compute_all_lit_correlation
from gn3.computations.correlations import compute_tissue_correlation
from gn3.computations.correlations import fast_compute_all_sample_correlation


def create_target_this_trait(start_vars):
    """Create the primary dataset, primary trait and target dataset.

    Reads the ``dataset``, ``corr_dataset`` and ``trait_id`` keys of
    ``start_vars`` (plus ``group`` when the dataset is ``"Temp"``).

    Returns a ``(this_dataset, this_trait, target_dataset, sample_data)``
    tuple; ``sample_data`` is always an empty tuple.
    """
    if start_vars['dataset'] == "Temp":
        this_dataset = data_set.create_dataset(
            dataset_name="Temp",
            dataset_type="Temp",
            group_name=start_vars['group'])
    else:
        this_dataset = data_set.create_dataset(
            dataset_name=start_vars['dataset'])

    target_dataset = data_set.create_dataset(
        dataset_name=start_vars['corr_dataset'])
    this_trait = create_trait(dataset=this_dataset,
                              name=start_vars['trait_id'])
    sample_data = ()
    return (this_dataset, this_trait, target_dataset, sample_data)


def test_process_data(this_trait, dataset, start_vars):
    """Test function for BXD, all and other sample data.

    Collects sample values according to ``start_vars["corr_samples_group"]``
    (``'samples_primary'``, ``'samples_other'``, or anything else meaning
    all samples) and returns them as a ``{sample_name: float}`` dict.
    """
    corr_samples_group = start_vars["corr_samples_group"]

    # Work on a copy: extending the group's own samplelist in place would
    # leak the parent/F1 samples into every later user of the dataset.
    primary_samples = list(dataset.group.samplelist)
    if dataset.group.parlist is not None:
        primary_samples += dataset.group.parlist
    if dataset.group.f1list is not None:
        primary_samples += dataset.group.f1list

    sample_data = {}

    # If either BXD/whatever Only or All Samples, append all of that
    # group's samplelist
    if corr_samples_group != 'samples_other':
        sample_data.update(process_samples(start_vars, primary_samples))

    # If either Non-BXD/whatever or All Samples, get all samples from
    # this_trait.data and exclude the primary samples (because they would
    # have been added in the previous if statement if the user selected
    # All Samples)
    if corr_samples_group != 'samples_primary':
        if corr_samples_group == 'samples_other':
            # parlist / f1list may be None (see the checks above).
            parents_and_f1 = ((dataset.group.parlist or [])
                              + (dataset.group.f1list or []))
            primary_samples = [
                x for x in primary_samples if x not in parents_and_f1]
        sample_data.update(process_samples(
            start_vars, list(this_trait.data.keys()), primary_samples))

    return sample_data


def process_samples(start_vars, sample_names=None, excluded_samples=None):
    """Fetch the numeric sample values submitted with the form.

    ``start_vars["sample_vals"]`` is a JSON object mapping sample names to
    string values. When ``sample_names`` is non-empty only those samples
    are considered (in that order); otherwise every submitted sample is.
    Samples listed in ``excluded_samples`` and samples whose value is
    ``"x"`` (case-insensitive, ignoring surrounding whitespace) are skipped.

    Returns a ``{str(sample): float(value)}`` dict.
    """
    sample_vals_dict = json.loads(start_vars["sample_vals"])
    excluded = set(excluded_samples or ())
    candidates = sample_names if sample_names else sample_vals_dict.keys()

    sample_data = {}
    for sample in candidates:
        if sample not in sample_vals_dict or sample in excluded:
            continue
        val = sample_vals_dict[sample]
        if val.strip().lower() != "x":
            sample_data[str(sample)] = float(val)
    return sample_data


def merge_correlation_results(correlation_results, target_correlation_results):
    """Merge ``target_correlation_results`` into ``correlation_results``.

    Both arguments are iterables of ``{trait_name: values_dict}`` dicts.
    For every trait in ``correlation_results`` that also has a truthy entry
    in ``target_correlation_results``, its values dict is updated in place
    with the target values. Returns ``correlation_results``.
    """
    corr_dict = {}
    for trait_dict in target_correlation_results:
        for trait_name, values in trait_dict.items():
            corr_dict[trait_name] = values

    for trait_dict in correlation_results:
        for trait_name in trait_dict:
            extra_values = corr_dict.get(trait_name)
            if extra_values:
                trait_dict[trait_name].update(extra_values)
    return correlation_results


def sample_for_trait_lists(corr_results, target_dataset,
                           this_trait, this_dataset, start_vars):
    """Interface function for sample correlation on top results.

    ``corr_results`` is accepted for signature parity with the other
    ``*_for_trait_list(s)`` helpers but is not used here.
    """
    (this_trait_data, target_dataset) = fetch_sample_data(
        start_vars, this_trait, this_dataset, target_dataset)
    correlation_results = run_sample_corr_cmd(
        corr_method="pearson",
        this_trait=this_trait_data,
        target_dataset=target_dataset)
    return correlation_results


def tissue_for_trait_lists(corr_results, this_dataset, this_trait):
    """Interface function for doing tissue correlation on a trait list.

    Restricts the dataset's trait->symbol mapping to the traits named in
    ``corr_results`` (the first key of each result dict) and runs a pearson
    tissue correlation. Returns ``None`` when no tissue input is available
    for ``this_trait``.
    """
    trait_names = {list(corr_result)[0] for corr_result in corr_results}
    traits_symbol_dict = {
        trait_name: symbol
        for (trait_name, symbol) in this_dataset.retrieve_genes(
            "Symbol").items()
        if trait_name in trait_names
    }
    tissue_input = get_tissue_correlation_input(
        this_trait, traits_symbol_dict)

    if tissue_input is not None:
        (primary_tissue_data, target_tissue_data) = tissue_input
        corr_results = compute_tissue_correlation(
            primary_tissue_dict=primary_tissue_data,
            target_tissues_data=target_tissue_data,
            corr_method="pearson")
        return corr_results
    return None


def lit_for_trait_list(corr_results, this_dataset, this_trait):
    """Interface function for doing literature correlation on a trait list.

    Restricts the dataset's trait->gene-id mapping to the traits named in
    ``corr_results`` (the first key of each result dict) and computes the
    literature correlation against ``this_trait``'s gene id.
    """
    (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
        this_trait, this_dataset)

    trait_names = {list(corr_result)[0] for corr_result in corr_results}
    geneid_dict = {
        trait_name: geneid
        for (trait_name, geneid) in geneid_dict.items()
        if trait_name in trait_names
    }

    with database_connection(SQL_URI) as conn:
        correlation_results = compute_all_lit_correlation(
            conn=conn,
            trait_lists=list(geneid_dict.items()),
            species=species,
            gene_id=this_trait_geneid)
    return correlation_results


def fetch_sample_data(start_vars, this_trait, this_dataset, target_dataset):
    """Build the sample-correlation inputs for ``this_trait``.

    Selects the primary trait's sample values according to
    ``start_vars["corr_samples_group"]``, loads the target dataset's trait
    data for those samples, and returns ``(this_trait_data, results)``
    where ``this_trait_data`` holds the sample values and trait id and
    ``results`` is the output of ``map_shared_keys_to_values``.
    """
    corr_samples_group = start_vars["corr_samples_group"]
    if corr_samples_group == "samples_primary":
        sample_data = process_samples(
            start_vars, this_dataset.group.samplelist)
    elif corr_samples_group == "samples_other":
        sample_data = process_samples(
            start_vars, excluded_samples=this_dataset.group.samplelist)
    else:
        sample_data = process_samples(
            start_vars, this_dataset.group.all_samples_ordered())

    target_dataset.get_trait_data(list(sample_data.keys()))
    # Return value is unused below; the call is kept from the original.
    this_trait = retrieve_sample_data(this_trait, this_dataset)
    this_trait_data = {
        "trait_sample_data": sample_data,
        "trait_id": start_vars["trait_id"]
    }
    results = map_shared_keys_to_values(
        target_dataset.samplelist, target_dataset.trait_data)

    return (this_trait_data, results)


def compute_correlation(start_vars, method="pearson", compute_all=False):
    """Compute correlations using GN3 API

    Keyword arguments:
    start_vars -- All input from form; includes things like the trait/dataset names
    method -- Correlation method to be used (pearson, spearman, or bicor);
              used only when start_vars has no 'corr_sample_method' entry
    compute_all -- Include sample, tissue, and literature correlations (when applicable)
    """
    # Imported here rather than at module level, as in the original.
    from gn2.wqflask.correlation.rust_correlation import compute_correlation_rust

    corr_type = start_vars['corr_type']
    # The form value wins; fall back to the argument instead of raising
    # KeyError when the form did not supply a method.
    method = start_vars.get('corr_sample_method', method)
    corr_return_results = int(start_vars.get("corr_return_results", 100))
    return compute_correlation_rust(
        start_vars, corr_type, method, corr_return_results, compute_all)


def compute_corr_for_top_results(start_vars,
                                 correlation_results,
                                 this_trait,
                                 this_dataset,
                                 target_dataset,
                                 corr_type):
    """Add the other correlation types to the top results.

    Only applies when both datasets are of type ``"ProbeSet"``. For each
    of tissue, lit and sample correlation that is not ``corr_type`` itself,
    the correlation is computed for the traits in ``correlation_results``
    and, when it yields a truthy result, merged in. Returns the (possibly
    updated) ``correlation_results``.
    """
    if this_dataset.type != "ProbeSet" or target_dataset.type != "ProbeSet":
        return correlation_results

    if corr_type != "tissue":
        tissue_result = tissue_for_trait_lists(
            correlation_results, this_dataset, this_trait)
        if tissue_result:
            correlation_results = merge_correlation_results(
                correlation_results, tissue_result)

    if corr_type != "lit":
        lit_result = lit_for_trait_list(
            correlation_results, this_dataset, this_trait)
        if lit_result:
            correlation_results = merge_correlation_results(
                correlation_results, lit_result)

    if corr_type != "sample":
        sample_result = sample_for_trait_lists(
            correlation_results, target_dataset, this_trait,
            this_dataset, start_vars)
        if sample_result:
            correlation_results = merge_correlation_results(
                correlation_results, sample_result)

    return correlation_results


def do_lit_correlation(this_trait, this_dataset):
    """Fetch the inputs for literature correlation.

    Returns ``(trait_geneid, geneid_dict, species)`` where ``species`` is
    lower-cased when it is truthy.
    """
    geneid_dict = this_dataset.retrieve_genes("GeneId")
    species = this_dataset.group.species
    if species:
        species = species.lower()
    trait_geneid = this_trait.geneid
    return (trait_geneid, geneid_dict, species)


def get_tissue_correlation_input(this_trait, trait_symbol_dict):
    """Get tissue expression values for the primary trait and target traits.

    Returns ``(primary_tissue_data, target_tissue_data)``, or ``None`` when
    the primary trait has no symbol or no tissue values were found for it.
    """
    primary_trait_tissue_vals_dict = (
        correlation_functions.get_trait_symbol_and_tissue_values(
            symbol_list=[this_trait.symbol]))

    if (this_trait.symbol
            and this_trait.symbol.lower() in primary_trait_tissue_vals_dict):
        primary_trait_tissue_values = primary_trait_tissue_vals_dict[
            this_trait.symbol.lower()]
        corr_result_tissue_vals_dict = (
            correlation_functions.get_trait_symbol_and_tissue_values(
                symbol_list=list(trait_symbol_dict.values())))
        primary_tissue_data = {
            "this_id": this_trait.name,
            "tissue_values": primary_trait_tissue_values
        }
        target_tissue_data = {
            "trait_symbol_dict": trait_symbol_dict,
            "symbol_tissue_vals_dict": corr_result_tissue_vals_dict
        }
        return (primary_tissue_data, target_tissue_data)
    return None