aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py39
1 files changed, 19 insertions, 20 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 4bd2dd9d..161215c5 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -13,6 +13,24 @@ from gn3.db_utils import database_connector
+def compute_top_n_lit(corr_results, this_dataset, this_trait):
+ (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
+ this_trait, this_dataset)
+
+ geneid_dict = {trait_name: geneid for (trait_name, geneid) in geneid_dict.items() if
+ corr_results.get(trait_name)}
+
+ conn = database_connector()
+
+ with conn:
+
+ correlation_results = compute_all_lit_correlation(
+ conn=conn, trait_lists=list(geneid_dict.items()),
+ species=species, gene_id=this_trait_geneid)
+
+ return correlation_results
+
+
def compute_top_n_tissue(this_dataset, this_trait, traits, method):
@@ -56,29 +74,10 @@ def compute_correlation_rust(start_vars: dict, corr_type: str,
lts = [key] + [str(x) for x in val]
r = ",".join(lts)
target_data.append(r)
- # breakpoint()
- results_k = run_correlation(target_data, ",".join(
+ results = run_correlation(target_data, ",".join(
[str(x) for x in list(sample_data.values())]), method, ",")
- tissue_top = compute_top_n_tissue(
- this_dataset, this_trait, results_k, method)
-
-
- lit_top = compute_top_n_lit(results_k,this_dataset,this_trait)
-
-
- results = []
-
- for (key,val) in results_k.items():
- if key in tissue_top:
- results_k[key].update(tissue_top[key])
-
- if key in lit_top:
- results_k[key].update(lit_top[key])
-
- results.append({key:results_k[key]})
-
if corr_type == "tissue":