about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py69
1 files changed, 38 insertions, 31 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 94720f54..2a2ad4a0 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -39,15 +39,14 @@ def chunk_dataset(dataset,steps,name):
         strains = [trait_name] + [str(value) for (trait_name, strain, value) in matrix]
         results.append(",".join(strains))
 
-    breakpoint()
     return results
 
 
 def compute_top_n_sample(start_vars, dataset, trait_list):
-    """only if dataset is of type probeset"""
-
-
+    """check if dataset is of type probeset"""
 
+    if dataset.type!= "Probeset":
+        return  {}
 
     def __fetch_sample_ids__(samples_vals, samples_group):
 
@@ -73,19 +72,9 @@ def compute_top_n_sample(start_vars, dataset, trait_list):
 
             )
 
-            return dict(curr.fetchall())
-
-
-
-
-
+            return (sample_data,dict(curr.fetchall()))
 
-
-
-
-  
-
-    ty = __fetch_sample_ids__(start_vars["sample_vals"], start_vars["corr_samples_group"])
+    (sample_data,sample_ids) = __fetch_sample_ids__(start_vars["sample_vals"], start_vars["corr_samples_group"])
 
 
 
@@ -93,6 +82,8 @@ def compute_top_n_sample(start_vars, dataset, trait_list):
 
         curr = conn.cursor()
 
+        #fetching strain data in bulk
+
         curr.execute(
 
         """
@@ -104,15 +95,14 @@ def compute_top_n_sample(start_vars, dataset, trait_list):
             and ProbeSetFreeze.Name = '{}'
             and ProbeSet.Name in {}
             and ProbeSet.Id = ProbeSetXRef.ProbeSetId)
-           """.format(create_in_clause(list(ty.values())),dataset.name,create_in_clause(trait_list))
+           """.format(create_in_clause(list(sample_ids.values())),dataset.name,create_in_clause(trait_list))
 
 
         )
 
+        corr_data = chunk_dataset(list(curr.fetchall()),len(sample_ids.values()),dataset.name)
 
-
-
-        return chunk_dataset(list(curr.fetchall()),len(ty.values()),dataset.name)
+        return run_correlation(corr_data,list(sample_data.values()),"pearson",",")
 
 
 def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict:
@@ -170,7 +160,10 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
                 **dict_c.get(trait_name, {})
             }
         }
-    return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
+    results = [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
+
+
+    return results
 
 
 def __compute_sample_corr__(
@@ -249,27 +242,41 @@ def compute_correlation_rust(
     }
     results = corr_type_fns[corr_type](
         start_vars, corr_type, method, n_top, target_trait_info)
+
     # END: Replace this with `match ...` once we hit Python 3.10
 
-    top_tissue_results = {}
-    top_lit_results = {}
 
+    top_a = top_b = {}
 
-    results = compute_top_n_sample(start_vars,target_dataset,list(results.keys()))
+    if compute_all:
 
+        if corr_type == "sample":
 
+            top_a = compute_top_n_tissue(
+            this_dataset, this_trait, results, method)
+        
+            top_b = compute_top_n_lit(results, this_dataset, this_trait)
 
-    breakpoint()
 
-    if compute_all:
-        # example compute of compute both correlation
-        top_tissue_results = compute_top_n_tissue(
+        elif corr_type == "lit":
+
+            #currently fails for lit
+
+            top_a = compute_top_n_sample(start_vars,target_dataset,list(results.keys()))
+            top_b =  compute_top_n_tissue(
             this_dataset, this_trait, results, method)
-        top_lit_results = compute_top_n_lit(results, this_dataset, this_trait)
 
-    return {
+        else:
+
+            top_a = compute_top_n_sample(start_vars,target_dataset,list(results.keys()))
+
+            top_b = compute_top_n_lit(results, this_dataset, this_trait)
+
+
+
+    return  {
         "correlation_results": merge_results(
-            results, top_tissue_results, top_lit_results),
+            results, top_a, top_b),
         "this_trait": this_trait.name,
         "target_dataset": start_vars['corr_dataset'],
         "return_results": n_top