about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-07-28 09:39:39 +0300
committerFrederick Muriuki Muriithi2022-07-28 09:39:39 +0300
commit3c8da2cae39efc25b320b78e2a1ed16afc1c5b8a (patch)
tree4599cdaf70d7d3400cec1096f3f3284be276100d
parentab354ac46bf7f84ed2504c6e0061ede808ab6ee1 (diff)
downloadgenenetwork3-3c8da2cae39efc25b320b78e2a1ed16afc1c5b8a.tar.gz
Update sample correlations code to use multiprocessing
* To help speed up the processing of the correlations, convert the
`compute_all_sample_correlation` function to use the multiprocessing module.
-rw-r--r--gn3/computations/correlations.py67
1 files changed, 35 insertions, 32 deletions
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
index a0da2c4..ea1a862 100644
--- a/gn3/computations/correlations.py
+++ b/gn3/computations/correlations.py
@@ -2,6 +2,7 @@
 import math
 import multiprocessing
 from contextlib import closing
+from multiprocessing import Pool, cpu_count
 
 from typing import List
 from typing import Tuple
@@ -161,27 +162,12 @@ def fast_compute_all_sample_correlation(this_trait,
         corr_results,
         key=lambda trait_name: -abs(list(trait_name.values())[0]["corr_coefficient"]))
 
-
-def compute_all_sample_correlation(this_trait,
-                                   target_dataset,
-                                   corr_method="pearson") -> List:
-    """Temp function to benchmark with compute_all_sample_r alternative to
-    compute_all_sample_r where we use multiprocessing
-
-    """
-    this_trait_samples = this_trait["trait_sample_data"]
-    corr_results = []
-    for target_trait in target_dataset:
-        trait_name = target_trait.get("trait_id")
-        target_trait_data = target_trait["trait_sample_data"]
-
-        try:
-            this_vals, target_vals = list(zip(*list(filter_shared_sample_keys(
-                this_trait_samples, target_trait_data))))
-
-        except ValueError:
-            # case where no matching strain names
-            continue
+def __corr_compute__(trait_samples, target_trait, corr_method):
+    trait_name = target_trait.get("trait_id")
+    target_trait_data = target_trait["trait_sample_data"]
+    try:
+        this_vals, target_vals = list(zip(*list(filter_shared_sample_keys(
+            trait_samples, target_trait_data))))
 
         sample_correlation = compute_sample_r_correlation(
             trait_name=trait_name,
@@ -191,17 +177,34 @@ def compute_all_sample_correlation(this_trait,
         if sample_correlation is not None:
             (trait_name, corr_coefficient,
              p_value, num_overlap) = sample_correlation
-        else:
-            continue
-        corr_result = {
-            "corr_coefficient": corr_coefficient,
-            "p_value": p_value,
-            "num_overlap": num_overlap
-        }
-        corr_results.append({trait_name: corr_result})
-    return sorted(
-        corr_results,
-        key=lambda trait_name: -abs(list(trait_name.values())[0]["corr_coefficient"]))
+            return {trait_name: {
+                "corr_coefficient": corr_coefficient,
+                "p_value": p_value,
+                "num_overlap": num_overlap
+            }}
+    except ValueError:
+        # case where no matching strain names
+        return None
+    return None
+
+def compute_all_sample_correlation(this_trait,
+                                   target_dataset,
+                                   corr_method="pearson") -> List:
+    """Temp function to benchmark with compute_all_sample_r alternative to
+    compute_all_sample_r where we use multiprocessing
+
+    """
+    this_trait_samples = this_trait["trait_sample_data"]
+    with Pool(processes=(cpu_count() - 1)) as pool:
+        return sorted(
+            (
+                corr for corr in
+                pool.starmap(
+                    __corr_compute__,
+                    ((this_trait_samples, trait, corr_method) for trait in target_dataset))
+                if corr is not None),
+            key=lambda trait_name: -abs(
+                list(trait_name.values())[0]["corr_coefficient"]))
 
 
 def tissue_correlation_for_trait(