-rw-r--r--   gn3/computations/partial_correlations.py | 102
-rw-r--r--   gn3/settings.py                           |   2
2 files changed, 53 insertions, 51 deletions
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 157024b..52c833f 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -6,6 +6,7 @@ GeneNetwork1.
 """
 
 import math
+import multiprocessing as mp
 from functools import reduce, partial
 from typing import Any, Tuple, Union, Sequence
 
@@ -14,12 +15,12 @@ import pandas
 import pingouin
 from scipy.stats import pearsonr, spearmanr
 
-from gn3.settings import TEXTDIR
 from gn3.random import random_string
 from gn3.function_helpers import compose
 from gn3.data_helpers import parse_csv_line
 from gn3.db.traits import export_informative
 from gn3.db.datasets import retrieve_trait_dataset
+from gn3.settings import TEXTDIR, MULTIPROCESSOR_PROCS
 from gn3.db.partial_correlations import traits_info, traits_data
 from gn3.db.species import species_name, translate_to_mouse_gene_id
 from gn3.db.correlations import (
@@ -281,6 +282,47 @@ def build_data_frame(
         return interm_df.rename(columns={"z0": "z"})
     return interm_df
 
+def compute_trait_info(primary_vals, control_vals, target, method):
+    targ_vals = target[0]
+    targ_name = target[1]
+    primary = [
+        prim for targ, prim in zip(targ_vals, primary_vals)
+        if targ is not None]
+
+    if len(primary) < 3:
+        return None
+
+    def __remove_controls_for_target_nones(cont_targ):
+        return tuple(cont for cont, targ in cont_targ if targ is not None)
+
+    conts_targs = tuple(tuple(
+        zip(control, targ_vals)) for control in control_vals)
+    datafrm = build_data_frame(
+        primary,
+        [targ for targ in targ_vals if targ is not None],
+        [__remove_controls_for_target_nones(cont_targ)
+         for cont_targ in conts_targs])
+    covariates = "z" if datafrm.shape[1] == 3 else [
+        col for col in datafrm.columns if col not in ("x", "y")]
+    ppc = pingouin.partial_corr(
+        data=datafrm, x="x", y="y", covar=covariates, method=(
+            "pearson" if "pearson" in method.lower() else "spearman"))
+    pc_coeff = ppc["r"][0]
+
+    zero_order_corr = pingouin.corr(
+        datafrm["x"], datafrm["y"], method=(
+            "pearson" if "pearson" in method.lower() else "spearman"))
+
+    if math.isnan(pc_coeff):
+        return (
+            targ_name, len(primary), pc_coeff, 1, zero_order_corr["r"][0],
+            zero_order_corr["p-val"][0])
+    return (
+        targ_name, len(primary), pc_coeff,
+        (ppc["p-val"][0] if not math.isnan(ppc["p-val"][0]) else (
+            0 if (abs(pc_coeff - 1) < 0.0000001) else 1)),
+        zero_order_corr["r"][0], zero_order_corr["p-val"][0])
+
 def compute_partial(
         primary_vals, control_vals, target_vals, target_names,
         method: str) -> Tuple[
@@ -296,57 +338,15 @@ def compute_partial(
 
     This implementation reworks the child function `compute_partial` which will
     then be used in the place of `determinPartialsByR`.
-
-    TODO: moving forward, we might need to use the multiprocessing library to
-    speed up the computations, in case they are found to be slow.
     """
-    # replace the R code with `pingouin.partial_corr`
-    def __compute_trait_info__(target):
-        targ_vals = target[0]
-        targ_name = target[1]
-        primary = [
-            prim for targ, prim in zip(targ_vals, primary_vals)
-            if targ is not None]
-
-        if len(primary) < 3:
-            return None
-
-        def __remove_controls_for_target_nones(cont_targ):
-            return tuple(cont for cont, targ in cont_targ if targ is not None)
-
-        conts_targs = tuple(tuple(
-            zip(control, targ_vals)) for control in control_vals)
-        datafrm = build_data_frame(
-            primary,
-            [targ for targ in targ_vals if targ is not None],
-            [__remove_controls_for_target_nones(cont_targ)
-             for cont_targ in conts_targs])
-        covariates = "z" if datafrm.shape[1] == 3 else [
-            col for col in datafrm.columns if col not in ("x", "y")]
-        ppc = pingouin.partial_corr(
-            data=datafrm, x="x", y="y", covar=covariates, method=(
-                "pearson" if "pearson" in method.lower() else "spearman"))
-        pc_coeff = ppc["r"][0]
-
-        zero_order_corr = pingouin.corr(
-            datafrm["x"], datafrm["y"], method=(
-                "pearson" if "pearson" in method.lower() else "spearman"))
-
-        if math.isnan(pc_coeff):
-            return (
-                targ_name, len(primary), pc_coeff, 1, zero_order_corr["r"][0],
-                zero_order_corr["p-val"][0])
-        return (
-            targ_name, len(primary), pc_coeff,
-            (ppc["p-val"][0] if not math.isnan(ppc["p-val"][0]) else (
-                0 if (abs(pc_coeff - 1) < 0.0000001) else 1)),
-            zero_order_corr["r"][0], zero_order_corr["p-val"][0])
-
-    return tuple(
-        result for result in (
-            __compute_trait_info__(target)
-            for target in zip(target_vals, target_names))
-        if result is not None)
+    with mp.Pool(MULTIPROCESSOR_PROCS or (mp.cpu_count() - 1)) as pool:
+        return tuple(
+            result for result in
+            pool.starmap(
+                compute_trait_info,
+                ((primary_vals, control_vals, (tvals, tname), method)
+                 for tvals, tname in zip(target_vals, target_names)))
+            if result is not None)
 
 def partial_correlations_normal(# pylint: disable=R0913
         primary_vals, control_vals, input_trait_gene_id, trait_database,
diff --git a/gn3/settings.py b/gn3/settings.py
index c945fbf..87e8f4b 100644
--- a/gn3/settings.py
+++ b/gn3/settings.py
@@ -57,3 +57,5 @@ GNSHARE = os.environ.get("GNSHARE", "/gnshare/gn/")
 TEXTDIR = f"{GNSHARE}/web/ProbeSetFreeze_DataMatrix"
 
 ROUND_TO = 10
+
+MULTIPROCESSOR_PROCS = 6  # Number of processes to spawn