diff options
-rw-r--r-- | wqflask/wqflask/correlation/rust_correlation.py | 135 | ||||
-rw-r--r-- | wqflask/wqflask/correlation/show_corr_results.py | 6 | ||||
-rw-r--r-- | wqflask/wqflask/views.py | 2 |
3 files changed, 125 insertions, 18 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py index 79d08a59..5c22efbf 100644 --- a/wqflask/wqflask/correlation/rust_correlation.py +++ b/wqflask/wqflask/correlation/rust_correlation.py @@ -1,7 +1,10 @@ """module contains integration code for rust-gn3""" import json from functools import reduce -from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values +from utility.db_tools import mescape +from utility.db_tools import create_in_clause +from wqflask.correlation.correlation_functions\ + import get_trait_symbol_and_tissue_values from wqflask.correlation.correlation_gn3_api import create_target_this_trait from wqflask.correlation.correlation_gn3_api import lit_for_trait_list from wqflask.correlation.correlation_gn3_api import do_lit_correlation @@ -12,11 +15,92 @@ from gn3.computations.rust_correlation import parse_tissue_corr_data from gn3.db_utils import database_connector +def chunk_dataset(dataset, steps, name): + + results = [] + + query = """ + SELECT ProbeSetXRef.DataId,ProbeSet.Name + FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze + WHERE ProbeSetFreeze.Name = '{}' AND + ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND + ProbeSetXRef.ProbeSetId = ProbeSet.Id + """.format(name) + + with database_connector() as conn: + with conn.cursor() as curr: + curr.execute(query) + traits_name_dict = dict(curr.fetchall()) + + for i in range(0, len(dataset), steps): + matrix = list(dataset[i:i + steps]) + trait_name = traits_name_dict[matrix[0][0]] + + strains = [trait_name] + [str(value) + for (trait_name, strain, value) in matrix] + results.append(",".join(strains)) + + return results + + +def compute_top_n_sample(start_vars, dataset, trait_list): + """check if dataset is of type probeset""" + + if dataset.type.lower() != "probeset": + return {} + + def __fetch_sample_ids__(samples_vals, samples_group): + all_samples = json.loads(samples_vals) + sample_data = get_sample_corr_data( + sample_type=samples_group, all_samples=all_samples, + dataset_samples=dataset.group.all_samples_ordered()) + + with database_connector() as conn: + with conn.cursor() as curr: + curr.execute( + """ + SELECT Strain.Name, Strain.Id FROM Strain, Species + WHERE Strain.Name IN {} + and Strain.SpeciesId=Species.Id + and Species.name = '{}' + """.format(create_in_clause(list(sample_data.keys())), + *mescape(dataset.group.species))) + return (sample_data, dict(curr.fetchall())) + + (sample_data, sample_ids) = __fetch_sample_ids__( + start_vars["sample_vals"], start_vars["corr_samples_group"]) + + with database_connector() as conn: + with conn.cursor() as curr: + # fetching strain data in bulk + curr.execute( + """ + SELECT * from ProbeSetData + where StrainID in {} + and id in (SELECT ProbeSetXRef.DataId + FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) + WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id + and ProbeSetFreeze.Name = '{}' + and ProbeSet.Name in {} + and ProbeSet.Id = ProbeSetXRef.ProbeSetId) + """.format( + create_in_clause(list(sample_ids.values())), + dataset.name, + create_in_clause(trait_list))) + + corr_data = chunk_dataset( + list(curr.fetchall()), len(sample_ids.values()), dataset.name) + + return run_correlation( + corr_data, list(sample_data.values()), "pearson", ",") + + def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict: (this_trait_geneid, geneid_dict, species) = do_lit_correlation( this_trait, this_dataset) - geneid_dict = {trait_name: geneid for (trait_name, geneid) in geneid_dict.items() if + geneid_dict = {trait_name: geneid for (trait_name, geneid) + in geneid_dict.items() if corr_results.get(trait_name)} with database_connector() as conn: return reduce( @@ -69,6 +153,7 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]: } return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()] + def __compute_sample_corr__( start_vars: dict, corr_type: str, method: str, n_top: int, target_trait_info: tuple): @@ -86,11 +171,11 @@ def __compute_sample_corr__( r = ",".join(lts) target_data.append(r) - return run_correlation( target_data, list(sample_data.values()), method, ",", corr_type, n_top) + def __compute_tissue_corr__( start_vars: dict, corr_type: str, method: str, n_top: int, target_trait_info: tuple): @@ -111,6 +196,7 @@ def __compute_tissue_corr__( return run_correlation(data[1], data[0], method, ",", "tissue") return {} + def __compute_lit_corr__( start_vars: dict, corr_type: str, method: str, n_top: int, target_trait_info: tuple): @@ -130,15 +216,16 @@ def __compute_lit_corr__( {}) return {} + def compute_correlation_rust( start_vars: dict, corr_type: str, method: str = "pearson", - n_top: int = 500, compute_all: bool = False): + n_top: int = 500, should_compute_all: bool = False): """function to compute correlation""" target_trait_info = create_target_this_trait(start_vars) (this_dataset, this_trait, target_dataset, sample_data) = ( target_trait_info) - ## Replace this with `match ...` once we hit Python 3.10 + # Replace this with `match ...` once we hit Python 3.10 corr_type_fns = { "sample": __compute_sample_corr__, "tissue": __compute_tissue_corr__, @@ -146,19 +233,39 @@ def compute_correlation_rust( } results = corr_type_fns[corr_type]( start_vars, corr_type, method, n_top, target_trait_info) - ## END: Replace this with `match ...` once we hit Python 3.10 - top_tissue_results = {} - top_lit_results = {} - if compute_all: - # example compute of compute both correlation - top_tissue_results = compute_top_n_tissue( - this_dataset,this_trait,results,method) - top_lit_results = compute_top_n_lit(results,this_dataset,this_trait) + # END: Replace this with `match ...` once we hit Python 3.10 + + top_a = top_b = {} + + if should_compute_all: + + if corr_type == "sample": + + top_a = compute_top_n_tissue( + this_dataset, this_trait, results, method) + + top_b = compute_top_n_lit(results, this_dataset, this_trait) + + elif corr_type == "lit": + + # currently fails for lit + + top_a = compute_top_n_sample( + start_vars, target_dataset, list(results.keys())) + top_b = compute_top_n_tissue( + this_dataset, this_trait, results, method) + + else: + + top_a = compute_top_n_sample( + start_vars, target_dataset, list(results.keys())) + + top_b = compute_top_n_lit(results, this_dataset, this_trait) return { "correlation_results": merge_results( - results, top_tissue_results, top_lit_results), + results, top_a, top_b), "this_trait": this_trait.name, "target_dataset": start_vars['corr_dataset'], "return_results": n_top diff --git a/wqflask/wqflask/correlation/show_corr_results.py b/wqflask/wqflask/correlation/show_corr_results.py index 1c391386..f5fdd9b3 100644 --- a/wqflask/wqflask/correlation/show_corr_results.py +++ b/wqflask/wqflask/correlation/show_corr_results.py @@ -121,9 +121,9 @@ def correlation_json_for_table(correlation_data, this_trait, this_dataset, targe results_dict['dataset'] = target_dataset['name'] results_dict['hmac'] = hmac.data_hmac( '{}:{}'.format(target_trait['name'], target_dataset['name'])) - results_dict['sample_r'] = f"{float(trait['corr_coefficient']):.3f}" - results_dict['num_overlap'] = trait['num_overlap'] - results_dict['sample_p'] = f"{float(trait['p_value']):.3e}" + results_dict['sample_r'] = f"{float(trait.get('corr_coefficient',0.0)):.3f}" + results_dict['num_overlap'] = trait.get('num_overlap',0) + results_dict['sample_p'] = f"{float(trait.get('p_value',0)):.3e}" if target_dataset['type'] == "ProbeSet": results_dict['symbol'] = target_trait['symbol'] results_dict['description'] = "N/A" diff --git a/wqflask/wqflask/views.py b/wqflask/wqflask/views.py index 2e13451d..e054cd49 100644 --- a/wqflask/wqflask/views.py +++ b/wqflask/wqflask/views.py @@ -876,7 +876,7 @@ def test_corr_compute_page(): correlation_results = compute_correlation_rust(start_vars, start_vars["corr_type"], start_vars['corr_sample_method'], - int(start_vars.get("corr_return_results", 500))) + int(start_vars.get("corr_return_results", 500)),True) correlation_results = set_template_vars(request.form, correlation_results) |