Diffstat (limited to 'gn3/computations/correlations.py')
-rw-r--r--  gn3/computations/correlations.py  41
1 file changed, 18 insertions, 23 deletions
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
index bb13ff1..c5c56db 100644
--- a/gn3/computations/correlations.py
+++ b/gn3/computations/correlations.py
@@ -1,6 +1,7 @@
 """module contains code for correlations"""
 import math
 import multiprocessing
+from contextlib import closing
 
 from typing import List
 from typing import Tuple
@@ -8,7 +9,7 @@ from typing import Optional
 from typing import Callable
 
 import scipy.stats
-from gn3.computations.biweight import calculate_biweight_corr
+import pingouin as pg
 
 
 def map_shared_keys_to_values(target_sample_keys: List,
@@ -49,13 +50,9 @@ def normalize_values(a_values: List,
     ([2.3, 4.1, 5], [3.4, 6.2, 4.1], 3)
 
     """
-    a_new = []
-    b_new = []
     for a_val, b_val in zip(a_values, b_values):
         if (a_val and b_val is not None):
-            a_new.append(a_val)
-            b_new.append(b_val)
-    return a_new, b_new, len(a_new)
+            yield a_val, b_val
 
 
 def compute_corr_coeff_p_value(primary_values: List, target_values: List,
@@ -81,8 +78,10 @@ def compute_sample_r_correlation(trait_name, corr_method, trait_vals,
     correlation coeff and p value
 
     """
-    (sanitized_traits_vals, sanitized_target_vals,
-     num_overlap) = normalize_values(trait_vals, target_samples_vals)
+
+    sanitized_traits_vals, sanitized_target_vals = list(
+        zip(*list(normalize_values(trait_vals, target_samples_vals))))
+    num_overlap = len(sanitized_traits_vals)
 
     if num_overlap > 5:
 
@@ -102,11 +101,10 @@ package :not packaged in guix
 
     """
-    try:
-        results = calculate_biweight_corr(x_val, y_val)
-        return results
-    except Exception as error:
-        raise error
+    results = pg.corr(x_val, y_val, method="bicor")
+    corr_coeff = results["r"].values[0]
+    p_val = results["p-val"].values[0]
+    return (corr_coeff, p_val)
 
 
 def filter_shared_sample_keys(this_samplelist,
@@ -115,13 +113,9 @@ def filter_shared_sample_keys(this_samplelist,
     filter the values using the shared keys
 
     """
-    this_vals = []
-    target_vals = []
     for key, value in target_samplelist.items():
         if key in this_samplelist:
-            target_vals.append(value)
-            this_vals.append(this_samplelist[key])
-    return (this_vals, target_vals)
+            yield this_samplelist[key], value
 
 
 def fast_compute_all_sample_correlation(this_trait,
@@ -140,9 +134,10 @@ def fast_compute_all_sample_correlation(this_trait,
     for target_trait in target_dataset:
         trait_name = target_trait.get("trait_id")
         target_trait_data = target_trait["trait_sample_data"]
-        processed_values.append((trait_name, corr_method, *filter_shared_sample_keys(
-            this_trait_samples, target_trait_data)))
-    with multiprocessing.Pool(4) as pool:
+        processed_values.append((trait_name, corr_method,
+                                 list(zip(*list(filter_shared_sample_keys(
+                                     this_trait_samples, target_trait_data))))))
+    with closing(multiprocessing.Pool()) as pool:
         results = pool.starmap(compute_sample_r_correlation, processed_values)
 
         for sample_correlation in results:
@@ -173,8 +168,8 @@ def compute_all_sample_correlation(this_trait,
     for target_trait in target_dataset:
         trait_name = target_trait.get("trait_id")
         target_trait_data = target_trait["trait_sample_data"]
-        this_vals, target_vals = filter_shared_sample_keys(
-            this_trait_samples, target_trait_data)
+        this_vals, target_vals = list(zip(*list(filter_shared_sample_keys(
+            this_trait_samples, target_trait_data))))
 
         sample_correlation = compute_sample_r_correlation(
             trait_name=trait_name,
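Note on the do_bicor replacement: pingouin's pg.corr returns a one-row pandas
DataFrame (indexed by method name) rather than a bare tuple, which is why the
new code pulls scalars out of the "r" and "p-val" columns. A minimal sketch of
the same extraction, using made-up sample values:

    import pingouin as pg

    # Made-up value lists with enough overlapping points to correlate.
    x_vals = [1.0, 2.4, 3.1, 4.7, 5.2, 6.9]
    y_vals = [6.1, 5.3, 4.0, 3.2, 2.8, 1.5]

    # pg.corr returns a single-row DataFrame; "r" and "p-val" are columns.
    results = pg.corr(x_vals, y_vals, method="bicor")
    corr_coeff = results["r"].values[0]
    p_val = results["p-val"].values[0]
    print(corr_coeff, p_val)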
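Note on the generator rewrites: normalize_values and filter_shared_sample_keys
now yield (this, target) pairs instead of building two parallel lists, so every
call site transposes the pairs with zip(*...). A sketch of the consumption
pattern, with the generator body copied from the patch and hypothetical inputs:

    def filter_shared_sample_keys(this_samplelist, target_samplelist):
        # Generator form, as introduced by this change.
        for key, value in target_samplelist.items():
            if key in this_samplelist:
                yield this_samplelist[key], value

    this_vals, target_vals = list(zip(*list(filter_shared_sample_keys(
        {"A": 1.2, "B": 3.4, "C": 9.0},
        {"B": 5.6, "C": 7.8, "D": 0.1}))))
    # this_vals == (3.4, 9.0), target_vals == (5.6, 7.8)

One caveat worth flagging: when the generator yields nothing (no shared keys,
or no pair passes the filter), zip(*[]) is empty and the two-name unpacking
raises ValueError, whereas the old versions returned two empty lists.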
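Note on the pool change: "with multiprocessing.Pool() as pool" calls
terminate() on exit, while wrapping the pool in contextlib.closing calls
close() instead, which stops new submissions but lets queued tasks finish.
A minimal runnable sketch of the pattern (square is a stand-in worker, not a
function from this module):

    from contextlib import closing
    import multiprocessing


    def square(num):
        """Stand-in worker; any picklable top-level function works."""
        return num * num


    if __name__ == "__main__":
        # closing() invokes pool.close() on exit rather than
        # Pool.__exit__'s terminate().
        with closing(multiprocessing.Pool()) as pool:
            results = pool.starmap(square, [(num,) for num in range(5)])
        print(results)  # [0, 1, 4, 9, 16]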