about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py21
1 files changed, 11 insertions, 10 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 88133b31..5c4d0b8a 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -99,7 +99,7 @@ def chunk_dataset(dataset, steps, name):
     for i in range(0, len(dataset), steps):
         matrix = list(dataset[i:i + steps])
         results.append([traits_name_dict[matrix[0][0]]] + [str(value)
-                                  for (trait_name, strain, value) in matrix])        
+                                                           for (trait_name, strain, value) in matrix])
     return results
 
 
@@ -159,9 +159,9 @@ def compute_top_n_sample(start_vars, dataset, trait_list):
             corr_data, list(sample_data.values()), "pearson", ",")
 
 
-def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict:
+def compute_top_n_lit(corr_results, target_dataset, this_trait) -> dict:
     (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
-        this_trait, this_dataset)
+        this_trait, target_dataset)
 
     geneid_dict = {trait_name: geneid for (trait_name, geneid)
                    in geneid_dict.items() if
@@ -177,14 +177,14 @@ def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict:
     return {}
 
 
-def compute_top_n_tissue(this_dataset, this_trait, traits, method):
+def compute_top_n_tissue(target_dataset, this_trait, traits, method):
 
     # refactor lots of rpt
 
     trait_symbol_dict = dict({
         trait_name: symbol
         for (trait_name, symbol)
-        in this_dataset.retrieve_genes("Symbol").items()
+        in target_dataset.retrieve_genes("Symbol").items()
         if traits.get(trait_name)})
 
     corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
@@ -248,7 +248,6 @@ def __compute_sample_corr__(
 
     target_dataset.get_trait_data(list(sample_data.keys()))
 
-
     def __merge_key_and_values__(rows, current):
         wo_nones = [value for value in current[1] if value is not None]
         if len(wo_nones) > 0:
@@ -265,12 +264,14 @@ def __compute_sample_corr__(
         target_data, list(sample_data.values()), method, ",", corr_type,
         n_top)
 
+
 def __datasets_compatible_p__(trait_dataset, target_dataset, corr_method):
     return not (
         corr_method in ("tissue", "Tissue r", "Literature r", "lit")
         and (trait_dataset.type == "ProbeSet" and
              target_dataset.type in ("Publish", "Geno")))
 
+
 def __compute_tissue_corr__(
         start_vars: dict, corr_type: str, method: str, n_top: int,
         target_trait_info: tuple):
@@ -344,9 +345,9 @@ def compute_correlation_rust(
         if corr_type == "sample":
 
             top_a = compute_top_n_tissue(
-                this_dataset, this_trait, results, method)
+                target_dataset, this_trait, results, method)
 
-            top_b = compute_top_n_lit(results, this_dataset, this_trait)
+            top_b = compute_top_n_lit(results, target_dataset, this_trait)
 
         elif corr_type == "lit":
 
@@ -355,14 +356,14 @@ def compute_correlation_rust(
             top_a = compute_top_n_sample(
                 start_vars, target_dataset, list(results.keys()))
             top_b = compute_top_n_tissue(
-                this_dataset, this_trait, results, method)
+                target_dataset, this_trait, results, method)
 
         else:
 
             top_a = compute_top_n_sample(
                 start_vars, target_dataset, list(results.keys()))
 
-            top_b = compute_top_n_lit(results, this_dataset, this_trait)
+            top_b = compute_top_n_lit(results, target_dataset, this_trait)
 
     return {
         "correlation_results": merge_results(