diff options
-rw-r--r-- | wqflask/wqflask/correlation/rust_correlation.py | 59 |
1 files changed, 54 insertions, 5 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py index 4106d3f0..4bd2dd9d 100644 --- a/wqflask/wqflask/correlation/rust_correlation.py +++ b/wqflask/wqflask/correlation/rust_correlation.py @@ -2,9 +2,37 @@ import json from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values from wqflask.correlation.correlation_gn3_api import create_target_this_trait +from wqflask.correlation.correlation_gn3_api import lit_for_trait_list +from wqflask.correlation.correlation_gn3_api import do_lit_correlation +from gn3.computations.correlations import compute_all_lit_correlation from gn3.computations.rust_correlation import run_correlation from gn3.computations.rust_correlation import get_sample_corr_data from gn3.computations.rust_correlation import parse_tissue_corr_data +from gn3.db_utils import database_connector + + + + + +def compute_top_n_tissue(this_dataset, this_trait, traits, method): + + trait_symbol_dict = dict({trait_name: symbol for ( + trait_name, symbol) in this_dataset.retrieve_genes("Symbol").items() if traits.get(trait_name)}) + + corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values( + symbol_list=list(trait_symbol_dict.values())) + + data = parse_tissue_corr_data(symbol_name=this_trait.symbol, + symbol_dict=get_trait_symbol_and_tissue_values( + symbol_list=[this_trait.symbol]), + dataset_symbols=trait_symbol_dict, + dataset_vals=corr_result_tissue_vals_dict) + + if data: + return run_correlation( + data[1], data[0], method, ",","tissue") + + return {} def compute_correlation_rust(start_vars: dict, corr_type: str, @@ -28,9 +56,29 @@ def compute_correlation_rust(start_vars: dict, corr_type: str, lts = [key] + [str(x) for x in val] r = ",".join(lts) target_data.append(r) + # breakpoint() + + results_k = run_correlation(target_data, ",".join( + [str(x) for x in list(sample_data.values())]), method, ",") + + tissue_top = compute_top_n_tissue( + this_dataset, this_trait, results_k, method) + + + lit_top = compute_top_n_lit(results_k,this_dataset,this_trait) + + + results = [] + + for (key,val) in results_k.items(): + if key in tissue_top: + results_k[key].update(tissue_top[key]) + + if key in lit_top: + results_k[key].update(lit_top[key]) + + results.append({key:results_k[key]}) - results = run_correlation( - target_data, list(sample_data.values()), method, ",") if corr_type == "tissue": @@ -41,15 +89,16 @@ def compute_correlation_rust(start_vars: dict, corr_type: str, data = parse_tissue_corr_data(symbol_name=this_trait.symbol, symbol_dict=get_trait_symbol_and_tissue_values( - symbol_list=[this_trait.symbol]), + symbol_list=[this_trait.symbol] + ), dataset_symbols=trait_symbol_dict, dataset_vals=corr_result_tissue_vals_dict) if data: results = run_correlation( - data[1], data[0], method, ",") + data[1], data[0], method, ",","tissue") - return {"correlation_results": results[0:n_top], + return {"correlation_results": results, "this_trait": this_trait.name, "target_dataset": start_vars['corr_dataset'], "return_results": n_top |