aboutsummaryrefslogtreecommitdiff
path: root/gn3/computations/correlations.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/computations/correlations.py')
-rw-r--r--gn3/computations/correlations.py67
1 files changed, 35 insertions, 32 deletions
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
index a0da2c4..ea1a862 100644
--- a/gn3/computations/correlations.py
+++ b/gn3/computations/correlations.py
@@ -2,6 +2,7 @@
import math
import multiprocessing
from contextlib import closing
+from multiprocessing import Pool, cpu_count
from typing import List
from typing import Tuple
@@ -161,27 +162,12 @@ def fast_compute_all_sample_correlation(this_trait,
corr_results,
key=lambda trait_name: -abs(list(trait_name.values())[0]["corr_coefficient"]))
-
-def compute_all_sample_correlation(this_trait,
- target_dataset,
- corr_method="pearson") -> List:
- """Temp function to benchmark with compute_all_sample_r alternative to
- compute_all_sample_r where we use multiprocessing
-
- """
- this_trait_samples = this_trait["trait_sample_data"]
- corr_results = []
- for target_trait in target_dataset:
- trait_name = target_trait.get("trait_id")
- target_trait_data = target_trait["trait_sample_data"]
-
- try:
- this_vals, target_vals = list(zip(*list(filter_shared_sample_keys(
- this_trait_samples, target_trait_data))))
-
- except ValueError:
- # case where no matching strain names
- continue
+def __corr_compute__(trait_samples, target_trait, corr_method):
+ trait_name = target_trait.get("trait_id")
+ target_trait_data = target_trait["trait_sample_data"]
+ try:
+ this_vals, target_vals = list(zip(*list(filter_shared_sample_keys(
+ trait_samples, target_trait_data))))
sample_correlation = compute_sample_r_correlation(
trait_name=trait_name,
@@ -191,17 +177,34 @@ def compute_all_sample_correlation(this_trait,
if sample_correlation is not None:
(trait_name, corr_coefficient,
p_value, num_overlap) = sample_correlation
- else:
- continue
- corr_result = {
- "corr_coefficient": corr_coefficient,
- "p_value": p_value,
- "num_overlap": num_overlap
- }
- corr_results.append({trait_name: corr_result})
- return sorted(
- corr_results,
- key=lambda trait_name: -abs(list(trait_name.values())[0]["corr_coefficient"]))
+ return {trait_name: {
+ "corr_coefficient": corr_coefficient,
+ "p_value": p_value,
+ "num_overlap": num_overlap
+ }}
+ except ValueError:
+ # case where no matching strain names
+ return None
+ return None
+
+def compute_all_sample_correlation(this_trait,
+ target_dataset,
+ corr_method="pearson") -> List:
+ """Temp function to benchmark with compute_all_sample_r alternative to
+ compute_all_sample_r where we use multiprocessing
+
+ """
+ this_trait_samples = this_trait["trait_sample_data"]
+ with Pool(processes=(cpu_count() - 1)) as pool:
+ return sorted(
+ (
+ corr for corr in
+ pool.starmap(
+ __corr_compute__,
+ ((this_trait_samples, trait, corr_method) for trait in target_dataset))
+ if corr is not None),
+ key=lambda trait_name: -abs(
+ list(trait_name.values())[0]["corr_coefficient"]))
def tissue_correlation_for_trait(