aboutsummaryrefslogtreecommitdiff
path: root/wqflask
diff options
context:
space:
mode:
Diffstat (limited to 'wqflask')
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py59
1 files changed, 54 insertions, 5 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 4106d3f0..4bd2dd9d 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -2,9 +2,37 @@
import json
from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values
from wqflask.correlation.correlation_gn3_api import create_target_this_trait
+from wqflask.correlation.correlation_gn3_api import lit_for_trait_list
+from wqflask.correlation.correlation_gn3_api import do_lit_correlation
+from gn3.computations.correlations import compute_all_lit_correlation
from gn3.computations.rust_correlation import run_correlation
from gn3.computations.rust_correlation import get_sample_corr_data
from gn3.computations.rust_correlation import parse_tissue_corr_data
+from gn3.db_utils import database_connector
+
+
+
+
+
+def compute_top_n_tissue(this_dataset, this_trait, traits, method):
+
+ trait_symbol_dict = dict({trait_name: symbol for (
+ trait_name, symbol) in this_dataset.retrieve_genes("Symbol").items() if traits.get(trait_name)})
+
+ corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
+ symbol_list=list(trait_symbol_dict.values()))
+
+ data = parse_tissue_corr_data(symbol_name=this_trait.symbol,
+ symbol_dict=get_trait_symbol_and_tissue_values(
+ symbol_list=[this_trait.symbol]),
+ dataset_symbols=trait_symbol_dict,
+ dataset_vals=corr_result_tissue_vals_dict)
+
+ if data:
+ return run_correlation(
+ data[1], data[0], method, ",","tissue")
+
+ return {}
def compute_correlation_rust(start_vars: dict, corr_type: str,
@@ -28,9 +56,29 @@ def compute_correlation_rust(start_vars: dict, corr_type: str,
lts = [key] + [str(x) for x in val]
r = ",".join(lts)
target_data.append(r)
+ # breakpoint()
+
+ results_k = run_correlation(target_data, ",".join(
+ [str(x) for x in list(sample_data.values())]), method, ",")
+
+ tissue_top = compute_top_n_tissue(
+ this_dataset, this_trait, results_k, method)
+
+
+ lit_top = compute_top_n_lit(results_k,this_dataset,this_trait)
+
+
+ results = []
+
+ for (key,val) in results_k.items():
+ if key in tissue_top:
+ results_k[key].update(tissue_top[key])
+
+ if key in lit_top:
+ results_k[key].update(lit_top[key])
+
+ results.append({key:results_k[key]})
- results = run_correlation(
- target_data, list(sample_data.values()), method, ",")
if corr_type == "tissue":
@@ -41,15 +89,16 @@ def compute_correlation_rust(start_vars: dict, corr_type: str,
data = parse_tissue_corr_data(symbol_name=this_trait.symbol,
symbol_dict=get_trait_symbol_and_tissue_values(
- symbol_list=[this_trait.symbol]),
+ symbol_list=[this_trait.symbol]
+ ),
dataset_symbols=trait_symbol_dict,
dataset_vals=corr_result_tissue_vals_dict)
if data:
results = run_correlation(
- data[1], data[0], method, ",")
+ data[1], data[0], method, ",","tissue")
- return {"correlation_results": results[0:n_top],
+ return {"correlation_results": results,
"this_trait": this_trait.name,
"target_dataset": start_vars['corr_dataset'],
"return_results": n_top