From d037b551b7877b7611fcc161006d82e1e148d6aa Mon Sep 17 00:00:00 2001
From: Frederick Muriuki Muriithi
Date: Thu, 3 Feb 2022 05:23:24 +0300
Subject: Use multiprocessing to speed up computation

This commit refactors the code to make it possible to use multiprocessing to
speed up the computation of the partial correlations.

The major refactor is to move the `__compute_trait_info__` function to the top
level of the module and to pass it all the other necessary context via new
arguments.
---
 gn3/computations/partial_correlations.py | 102 +++++++++++++++----------------
 gn3/settings.py                          |   2 +
 2 files changed, 53 insertions(+), 51 deletions(-)

diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 157024b..52c833f 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -6,6 +6,7 @@ GeneNetwork1.
 """
 
 import math
+import multiprocessing as mp
 from functools import reduce, partial
 from typing import Any, Tuple, Union, Sequence
 
@@ -14,12 +15,12 @@ import pandas
 import pingouin
 from scipy.stats import pearsonr, spearmanr
 
-from gn3.settings import TEXTDIR
 from gn3.random import random_string
 from gn3.function_helpers import compose
 from gn3.data_helpers import parse_csv_line
 from gn3.db.traits import export_informative
 from gn3.db.datasets import retrieve_trait_dataset
+from gn3.settings import TEXTDIR, MULTIPROCESSOR_PROCS
 from gn3.db.partial_correlations import traits_info, traits_data
 from gn3.db.species import species_name, translate_to_mouse_gene_id
 from gn3.db.correlations import (
@@ -281,6 +282,47 @@ def build_data_frame(
         return interm_df.rename(columns={"z0": "z"})
     return interm_df
 
+def compute_trait_info(primary_vals, control_vals, target, method):
+    targ_vals = target[0]
+    targ_name = target[1]
+    primary = [
+        prim for targ, prim in zip(targ_vals, primary_vals)
+        if targ is not None]
+
+    if len(primary) < 3:
+        return None
+
+    def __remove_controls_for_target_nones(cont_targ):
+        return tuple(cont for cont, targ in cont_targ if targ is not None)
+
+    conts_targs = tuple(tuple(
+        zip(control, targ_vals)) for control in control_vals)
+    datafrm = build_data_frame(
+        primary,
+        [targ for targ in targ_vals if targ is not None],
+        [__remove_controls_for_target_nones(cont_targ)
+         for cont_targ in conts_targs])
+    covariates = "z" if datafrm.shape[1] == 3 else [
+        col for col in datafrm.columns if col not in ("x", "y")]
+    ppc = pingouin.partial_corr(
+        data=datafrm, x="x", y="y", covar=covariates, method=(
+            "pearson" if "pearson" in method.lower() else "spearman"))
+    pc_coeff = ppc["r"][0]
+
+    zero_order_corr = pingouin.corr(
+        datafrm["x"], datafrm["y"], method=(
+            "pearson" if "pearson" in method.lower() else "spearman"))
+
+    if math.isnan(pc_coeff):
+        return (
+            targ_name, len(primary), pc_coeff, 1, zero_order_corr["r"][0],
+            zero_order_corr["p-val"][0])
+    return (
+        targ_name, len(primary), pc_coeff,
+        (ppc["p-val"][0] if not math.isnan(ppc["p-val"][0]) else (
+            0 if (abs(pc_coeff - 1) < 0.0000001) else 1)),
+        zero_order_corr["r"][0], zero_order_corr["p-val"][0])
+
 def compute_partial(
         primary_vals, control_vals, target_vals, target_names,
         method: str) -> Tuple[
@@ -296,57 +338,15 @@ def compute_partial(
 
     This implementation reworks the child function `compute_partial` which will
     then be used in the place of `determinPartialsByR`.
-
-    TODO: moving forward, we might need to use the multiprocessing library to
-    speed up the computations, in case they are found to be slow.
""" - # replace the R code with `pingouin.partial_corr` - def __compute_trait_info__(target): - targ_vals = target[0] - targ_name = target[1] - primary = [ - prim for targ, prim in zip(targ_vals, primary_vals) - if targ is not None] - - if len(primary) < 3: - return None - - def __remove_controls_for_target_nones(cont_targ): - return tuple(cont for cont, targ in cont_targ if targ is not None) - - conts_targs = tuple(tuple( - zip(control, targ_vals)) for control in control_vals) - datafrm = build_data_frame( - primary, - [targ for targ in targ_vals if targ is not None], - [__remove_controls_for_target_nones(cont_targ) - for cont_targ in conts_targs]) - covariates = "z" if datafrm.shape[1] == 3 else [ - col for col in datafrm.columns if col not in ("x", "y")] - ppc = pingouin.partial_corr( - data=datafrm, x="x", y="y", covar=covariates, method=( - "pearson" if "pearson" in method.lower() else "spearman")) - pc_coeff = ppc["r"][0] - - zero_order_corr = pingouin.corr( - datafrm["x"], datafrm["y"], method=( - "pearson" if "pearson" in method.lower() else "spearman")) - - if math.isnan(pc_coeff): - return ( - targ_name, len(primary), pc_coeff, 1, zero_order_corr["r"][0], - zero_order_corr["p-val"][0]) - return ( - targ_name, len(primary), pc_coeff, - (ppc["p-val"][0] if not math.isnan(ppc["p-val"][0]) else ( - 0 if (abs(pc_coeff - 1) < 0.0000001) else 1)), - zero_order_corr["r"][0], zero_order_corr["p-val"][0]) - - return tuple( - result for result in ( - __compute_trait_info__(target) - for target in zip(target_vals, target_names)) - if result is not None) + with mp.Pool(MULTIPROCESSOR_PROCS or (mp.cpu_count() - 1)) as pool: + return tuple( + result for result in + pool.starmap( + compute_trait_info, + ((primary_vals, control_vals, (tvals, tname), method) + for tvals, tname in zip(target_vals, target_names))) + if result is not None) def partial_correlations_normal(# pylint: disable=R0913 primary_vals, control_vals, input_trait_gene_id, trait_database, diff --git a/gn3/settings.py b/gn3/settings.py index c945fbf..87e8f4b 100644 --- a/gn3/settings.py +++ b/gn3/settings.py @@ -57,3 +57,5 @@ GNSHARE = os.environ.get("GNSHARE", "/gnshare/gn/") TEXTDIR = f"{GNSHARE}/web/ProbeSetFreeze_DataMatrix" ROUND_TO = 10 + +MULTIPROCESSOR_PROCS = 6 # Number of processes to spawn -- cgit v1.2.3