diff options
-rw-r--r-- | gn3/computations/correlations.py | 31 |
1 files changed, 14 insertions, 17 deletions
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py index c930df0..8eaa523 100644 --- a/gn3/computations/correlations.py +++ b/gn3/computations/correlations.py @@ -49,13 +49,9 @@ def normalize_values(a_values: List, ([2.3, 4.1, 5], [3.4, 6.2, 4.1], 3) """ - a_new = [] - b_new = [] for a_val, b_val in zip(a_values, b_values): if (a_val and b_val is not None): - a_new.append(a_val) - b_new.append(b_val) - return a_new, b_new, len(a_new) + yield a_val, b_val def compute_corr_coeff_p_value(primary_values: List, target_values: List, @@ -81,8 +77,10 @@ def compute_sample_r_correlation(trait_name, corr_method, trait_vals, correlation coeff and p value """ - (sanitized_traits_vals, sanitized_target_vals, - num_overlap) = normalize_values(trait_vals, target_samples_vals) + + sanitized_traits_vals, sanitized_target_vals = list( + zip(*list(normalize_values(trait_vals, target_samples_vals)))) + num_overlap = len(sanitized_traits_vals) if num_overlap > 5: @@ -114,13 +112,9 @@ def filter_shared_sample_keys(this_samplelist, filter the values using the shared keys """ - this_vals = [] - target_vals = [] for key, value in target_samplelist.items(): if key in this_samplelist: - target_vals.append(value) - this_vals.append(this_samplelist[key]) - return (this_vals, target_vals) + yield value, this_samplelist[key] def fast_compute_all_sample_correlation(this_trait, @@ -139,9 +133,10 @@ def fast_compute_all_sample_correlation(this_trait, for target_trait in target_dataset: trait_name = target_trait.get("trait_id") target_trait_data = target_trait["trait_sample_data"] - processed_values.append((trait_name, corr_method, *filter_shared_sample_keys( - this_trait_samples, target_trait_data))) - with multiprocessing.Pool(4) as pool: + processed_values.append((trait_name, corr_method, *list(zip(*list(filter_shared_sample_keys( + this_trait_samples, target_trait_data)))) + )) + with multiprocessing.Pool() as pool: results = pool.starmap(compute_sample_r_correlation, processed_values) for sample_correlation in results: @@ -172,8 +167,10 @@ def compute_all_sample_correlation(this_trait, for target_trait in target_dataset: trait_name = target_trait.get("trait_id") target_trait_data = target_trait["trait_sample_data"] - this_vals, target_vals = filter_shared_sample_keys( - this_trait_samples, target_trait_data) + this_vals, target_vals = list(zip(*list(filter_shared_sample_keys( + this_trait_samples, target_trait_data)))) + # this_vals, target_vals = filter_shared_sample_keys( + # this_trait_samples, target_trait_data) sample_correlation = compute_sample_r_correlation( trait_name=trait_name, |