diff options
Diffstat (limited to 'gn3/computations')
-rw-r--r-- | gn3/computations/correlations.py | 67 |
1 files changed, 35 insertions, 32 deletions
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py index a0da2c4..ea1a862 100644 --- a/gn3/computations/correlations.py +++ b/gn3/computations/correlations.py @@ -2,6 +2,7 @@ import math import multiprocessing from contextlib import closing +from multiprocessing import Pool, cpu_count from typing import List from typing import Tuple @@ -161,27 +162,12 @@ def fast_compute_all_sample_correlation(this_trait, corr_results, key=lambda trait_name: -abs(list(trait_name.values())[0]["corr_coefficient"])) - -def compute_all_sample_correlation(this_trait, - target_dataset, - corr_method="pearson") -> List: - """Temp function to benchmark with compute_all_sample_r alternative to - compute_all_sample_r where we use multiprocessing - - """ - this_trait_samples = this_trait["trait_sample_data"] - corr_results = [] - for target_trait in target_dataset: - trait_name = target_trait.get("trait_id") - target_trait_data = target_trait["trait_sample_data"] - - try: - this_vals, target_vals = list(zip(*list(filter_shared_sample_keys( - this_trait_samples, target_trait_data)))) - - except ValueError: - # case where no matching strain names - continue +def __corr_compute__(trait_samples, target_trait, corr_method): + trait_name = target_trait.get("trait_id") + target_trait_data = target_trait["trait_sample_data"] + try: + this_vals, target_vals = list(zip(*list(filter_shared_sample_keys( + trait_samples, target_trait_data)))) sample_correlation = compute_sample_r_correlation( trait_name=trait_name, @@ -191,17 +177,34 @@ def compute_all_sample_correlation(this_trait, if sample_correlation is not None: (trait_name, corr_coefficient, p_value, num_overlap) = sample_correlation - else: - continue - corr_result = { - "corr_coefficient": corr_coefficient, - "p_value": p_value, - "num_overlap": num_overlap - } - corr_results.append({trait_name: corr_result}) - return sorted( - corr_results, - key=lambda trait_name: -abs(list(trait_name.values())[0]["corr_coefficient"])) + return {trait_name: { + "corr_coefficient": corr_coefficient, + "p_value": p_value, + "num_overlap": num_overlap + }} + except ValueError: + # case where no matching strain names + return None + return None + +def compute_all_sample_correlation(this_trait, + target_dataset, + corr_method="pearson") -> List: + """Temp function to benchmark with compute_all_sample_r alternative to + compute_all_sample_r where we use multiprocessing + + """ + this_trait_samples = this_trait["trait_sample_data"] + with Pool(processes=(cpu_count() - 1)) as pool: + return sorted( + ( + corr for corr in + pool.starmap( + __corr_compute__, + ((this_trait_samples, trait, corr_method) for trait in target_dataset)) + if corr is not None), + key=lambda trait_name: -abs( + list(trait_name.values())[0]["corr_coefficient"])) def tissue_correlation_for_trait( |