about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-08-11 06:57:11 +0300
committerFrederick Muriuki Muriithi2022-08-11 06:57:11 +0300
commit90195f350758902cecbd184bfa500eac2a39a263 (patch)
tree8e9e0ab1a1e81408f574650205d55d796b19eef6
parentac9941b3cff7605500dc6ec6fda7a4f5db664a0a (diff)
downloadgenenetwork2-90195f350758902cecbd184bfa500eac2a39a263.tar.gz
Update format to prevent tissue correlation from failing
Update the data format of returned values so that it conforms with
expectatitions.
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py53
1 files changed, 24 insertions, 29 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 8a5021cc..4a22af72 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -1,5 +1,6 @@
 """module contains integration code for rust-gn3"""
 import json
+from functools import reduce
 from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values
 from wqflask.correlation.correlation_gn3_api import create_target_this_trait
 from wqflask.correlation.correlation_gn3_api import lit_for_trait_list
@@ -11,22 +12,21 @@ from gn3.computations.rust_correlation import parse_tissue_corr_data
 from gn3.db_utils import database_connector
 
 
-def compute_top_n_lit(corr_results, this_dataset, this_trait):
+def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict:
     (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
         this_trait, this_dataset)
 
     geneid_dict = {trait_name: geneid for (trait_name, geneid) in geneid_dict.items() if
                    corr_results.get(trait_name)}
+    with database_connector() as conn:
+        return reduce(
+            lambda acc, corr: {**acc, **corr},
+            compute_all_lit_correlation(
+                conn=conn, trait_lists=list(geneid_dict.items()),
+                species=species, gene_id=this_trait_geneid),
+            {})
 
-    conn = database_connector()
-
-    with conn:
-
-        correlation_results = compute_all_lit_correlation(
-            conn=conn, trait_lists=list(geneid_dict.items()),
-            species=species, gene_id=this_trait_geneid)
-
-    return correlation_results
+    return {}
 
 
 def compute_top_n_tissue(this_dataset, this_trait, traits, method):
@@ -55,25 +55,19 @@ def compute_top_n_tissue(this_dataset, this_trait, traits, method):
     return {}
 
 
-def merge_results(dict_a, dict_b, dict_c):
+def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
     """code to merge diff corr  into individual dicts
     a"""
 
-    correlation_results = []
-
-    for (key, val) in dict_a.items():
-
-        if key in dict_b:
-
-            dict_a[key].update(dict_b[key])
-
-        if key in dict_c:
-
-            dict_a[key].update(dict_c[key])
-
-        correlation_results.append({key: dict_a[key]})
-
-    return correlation_results
+    def __merge__(trait_name, trait_corrs):
+        return {
+            trait_name: {
+                **trait_corrs,
+                **dict_b.get(trait_name, {}),
+                **dict_c.get(trait_name, {})
+            }
+        }
+    return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
 
 
 def compute_correlation_rust(
@@ -109,7 +103,7 @@ def compute_correlation_rust(
         top_lit_results = compute_top_n_lit(results,this_dataset,this_trait)
 
         # merging the results
-        results = merge_results(results,top_tissue_results,top_lit_results)
+        results = merge_results(results, top_tissue_results, top_lit_results)
 
     if corr_type == "tissue":
 
@@ -125,8 +119,9 @@ def compute_correlation_rust(
                                       dataset_vals=corr_result_tissue_vals_dict)
 
         if data:
-            results = run_correlation(
-                data[1], data[0], method, ",", "tissue")
+            results = merge_results(
+                run_correlation(data[1], data[0], method, ",", "tissue"),
+                {}, {})
 
     return {
         "correlation_results": results,