about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py59
1 files changed, 54 insertions, 5 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 4106d3f0..4bd2dd9d 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -2,9 +2,37 @@
 import json
 from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values
 from wqflask.correlation.correlation_gn3_api import create_target_this_trait
+from wqflask.correlation.correlation_gn3_api import lit_for_trait_list
+from wqflask.correlation.correlation_gn3_api import do_lit_correlation
+from gn3.computations.correlations import compute_all_lit_correlation
 from gn3.computations.rust_correlation import run_correlation
 from gn3.computations.rust_correlation import get_sample_corr_data
 from gn3.computations.rust_correlation import parse_tissue_corr_data
+from gn3.db_utils import database_connector
+
+
+
+
+
+def compute_top_n_tissue(this_dataset, this_trait, traits, method):
+
+    trait_symbol_dict = dict({trait_name: symbol for (
+        trait_name, symbol) in this_dataset.retrieve_genes("Symbol").items() if traits.get(trait_name)})
+
+    corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
+        symbol_list=list(trait_symbol_dict.values()))
+
+    data = parse_tissue_corr_data(symbol_name=this_trait.symbol,
+                                  symbol_dict=get_trait_symbol_and_tissue_values(
+                                      symbol_list=[this_trait.symbol]),
+                                  dataset_symbols=trait_symbol_dict,
+                                  dataset_vals=corr_result_tissue_vals_dict)
+
+    if data:
+        return run_correlation(
+            data[1], data[0], method, ",","tissue")
+
+    return {}
 
 
 def compute_correlation_rust(start_vars: dict, corr_type: str,
@@ -28,9 +56,29 @@ def compute_correlation_rust(start_vars: dict, corr_type: str,
             lts = [key] + [str(x) for x in val]
             r = ",".join(lts)
             target_data.append(r)
+        # breakpoint()
+
+        results_k = run_correlation(target_data, ",".join(
+            [str(x) for x in list(sample_data.values())]), method, ",")
+
+        tissue_top = compute_top_n_tissue(
+            this_dataset, this_trait, results_k, method)
+
+
+        lit_top = compute_top_n_lit(results_k,this_dataset,this_trait)
+
+
+        results = []
+
+        for (key,val) in results_k.items():
+            if key in tissue_top:
+                results_k[key].update(tissue_top[key])
+
+            if key in lit_top:
+                results_k[key].update(lit_top[key])
+
+            results.append({key:results_k[key]})
 
-        results = run_correlation(
-            target_data, list(sample_data.values()), method, ",")
 
 
     if corr_type == "tissue":
@@ -41,15 +89,16 @@ def compute_correlation_rust(start_vars: dict, corr_type: str,
 
         data = parse_tissue_corr_data(symbol_name=this_trait.symbol,
                                       symbol_dict=get_trait_symbol_and_tissue_values(
-                                          symbol_list=[this_trait.symbol]),
+                                          symbol_list=[this_trait.symbol]
+                                      ),
                                       dataset_symbols=trait_symbol_dict,
                                       dataset_vals=corr_result_tissue_vals_dict)
 
         if data:
             results = run_correlation(
-                data[1], data[0], method, ",")
+                data[1], data[0], method, ",","tissue")
 
-    return {"correlation_results": results[0:n_top],
+    return {"correlation_results": results,
             "this_trait": this_trait.name,
             "target_dataset": start_vars['corr_dataset'],
             "return_results": n_top