-rw-r--r--  gn3/computations/partial_correlations.py  102
-rw-r--r--  gn3/settings.py                              2
2 files changed, 53 insertions, 51 deletions
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 157024b..52c833f 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -6,6 +6,7 @@ GeneNetwork1.
"""
import math
+import multiprocessing as mp
from functools import reduce, partial
from typing import Any, Tuple, Union, Sequence
@@ -14,12 +15,12 @@ import pandas
import pingouin
from scipy.stats import pearsonr, spearmanr
-from gn3.settings import TEXTDIR
from gn3.random import random_string
from gn3.function_helpers import compose
from gn3.data_helpers import parse_csv_line
from gn3.db.traits import export_informative
from gn3.db.datasets import retrieve_trait_dataset
+from gn3.settings import TEXTDIR, MULTIPROCESSOR_PROCS
from gn3.db.partial_correlations import traits_info, traits_data
from gn3.db.species import species_name, translate_to_mouse_gene_id
from gn3.db.correlations import (
@@ -281,6 +282,47 @@ def build_data_frame(
return interm_df.rename(columns={"z0": "z"})
return interm_df
+def compute_trait_info(primary_vals, control_vals, target, method):
+ targ_vals = target[0]
+ targ_name = target[1]
+ primary = [
+ prim for targ, prim in zip(targ_vals, primary_vals)
+ if targ is not None]
+
+ if len(primary) < 3:
+ return None
+
+ def __remove_controls_for_target_nones(cont_targ):
+ return tuple(cont for cont, targ in cont_targ if targ is not None)
+
+ conts_targs = tuple(tuple(
+ zip(control, targ_vals)) for control in control_vals)
+ datafrm = build_data_frame(
+ primary,
+ [targ for targ in targ_vals if targ is not None],
+ [__remove_controls_for_target_nones(cont_targ)
+ for cont_targ in conts_targs])
+ covariates = "z" if datafrm.shape[1] == 3 else [
+ col for col in datafrm.columns if col not in ("x", "y")]
+ ppc = pingouin.partial_corr(
+ data=datafrm, x="x", y="y", covar=covariates, method=(
+ "pearson" if "pearson" in method.lower() else "spearman"))
+ pc_coeff = ppc["r"][0]
+
+ zero_order_corr = pingouin.corr(
+ datafrm["x"], datafrm["y"], method=(
+ "pearson" if "pearson" in method.lower() else "spearman"))
+
+ if math.isnan(pc_coeff):
+ return (
+ targ_name, len(primary), pc_coeff, 1, zero_order_corr["r"][0],
+ zero_order_corr["p-val"][0])
+ return (
+ targ_name, len(primary), pc_coeff,
+ (ppc["p-val"][0] if not math.isnan(ppc["p-val"][0]) else (
+ 0 if (abs(pc_coeff - 1) < 0.0000001) else 1)),
+ zero_order_corr["r"][0], zero_order_corr["p-val"][0])
+
def compute_partial(
primary_vals, control_vals, target_vals, target_names,
method: str) -> Tuple[
@@ -296,57 +338,15 @@ def compute_partial(
This implementation reworks the child function `compute_partial` which will
then be used in the place of `determinPartialsByR`.
-
- TODO: moving forward, we might need to use the multiprocessing library to
- speed up the computations, in case they are found to be slow.
"""
- # replace the R code with `pingouin.partial_corr`
- def __compute_trait_info__(target):
- targ_vals = target[0]
- targ_name = target[1]
- primary = [
- prim for targ, prim in zip(targ_vals, primary_vals)
- if targ is not None]
-
- if len(primary) < 3:
- return None
-
- def __remove_controls_for_target_nones(cont_targ):
- return tuple(cont for cont, targ in cont_targ if targ is not None)
-
- conts_targs = tuple(tuple(
- zip(control, targ_vals)) for control in control_vals)
- datafrm = build_data_frame(
- primary,
- [targ for targ in targ_vals if targ is not None],
- [__remove_controls_for_target_nones(cont_targ)
- for cont_targ in conts_targs])
- covariates = "z" if datafrm.shape[1] == 3 else [
- col for col in datafrm.columns if col not in ("x", "y")]
- ppc = pingouin.partial_corr(
- data=datafrm, x="x", y="y", covar=covariates, method=(
- "pearson" if "pearson" in method.lower() else "spearman"))
- pc_coeff = ppc["r"][0]
-
- zero_order_corr = pingouin.corr(
- datafrm["x"], datafrm["y"], method=(
- "pearson" if "pearson" in method.lower() else "spearman"))
-
- if math.isnan(pc_coeff):
- return (
- targ_name, len(primary), pc_coeff, 1, zero_order_corr["r"][0],
- zero_order_corr["p-val"][0])
- return (
- targ_name, len(primary), pc_coeff,
- (ppc["p-val"][0] if not math.isnan(ppc["p-val"][0]) else (
- 0 if (abs(pc_coeff - 1) < 0.0000001) else 1)),
- zero_order_corr["r"][0], zero_order_corr["p-val"][0])
-
- return tuple(
- result for result in (
- __compute_trait_info__(target)
- for target in zip(target_vals, target_names))
- if result is not None)
+ with mp.Pool(MULTIPROCESSOR_PROCS or (mp.cpu_count() - 1)) as pool:
+ return tuple(
+ result for result in
+ pool.starmap(
+ compute_trait_info,
+ ((primary_vals, control_vals, (tvals, tname), method)
+ for tvals, tname in zip(target_vals, target_names)))
+ if result is not None)
def partial_correlations_normal(# pylint: disable=R0913
primary_vals, control_vals, input_trait_gene_id, trait_database,
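
The refactor above hoists the nested `__compute_trait_info__` closure out to the module-level `compute_trait_info` so that `multiprocessing` can use it: `Pool.starmap` pickles the worker function and its argument tuples when dispatching tasks to worker processes, and nested functions cannot be pickled. A minimal sketch of the same pattern, with a hypothetical toy worker standing in for `compute_trait_info`:

import multiprocessing as mp

def toy_worker(primary, target):
    """Toy stand-in for compute_trait_info: return None for targets that
    cannot be computed, a result tuple otherwise."""
    if target is None:
        return None
    return (target, sum(primary))

def run_parallel(primary, targets, procs=None):
    # Mirror of the compute_partial pattern: pool size from a setting,
    # falling back to cpu_count() - 1, then starmap over per-target args.
    with mp.Pool(procs or (mp.cpu_count() - 1)) as pool:
        return tuple(
            result for result in pool.starmap(
                toy_worker, ((primary, targ) for targ in targets))
            if result is not None)

if __name__ == "__main__":
    print(run_parallel([1, 2, 3], ["a", None, "b"], procs=2))
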
diff --git a/gn3/settings.py b/gn3/settings.py
index c945fbf..87e8f4b 100644
--- a/gn3/settings.py
+++ b/gn3/settings.py
@@ -57,3 +57,5 @@ GNSHARE = os.environ.get("GNSHARE", "/gnshare/gn/")
TEXTDIR = f"{GNSHARE}/web/ProbeSetFreeze_DataMatrix"
ROUND_TO = 10
+
+MULTIPROCESSOR_PROCS = 6 # Number of processes to spawn
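
The new `MULTIPROCESSOR_PROCS` setting is consumed in `compute_partial` as `MULTIPROCESSOR_PROCS or (mp.cpu_count() - 1)`, so a falsy value (0 or None) falls back to one process fewer than the host's CPU count. A small illustration of that fallback, using a hypothetical `pool_size` helper that is not part of the change:

import multiprocessing as mp

def pool_size(configured_procs):
    # Hypothetical helper: an explicit, non-zero setting wins; anything
    # falsy (0 or None) defers to cpu_count() - 1.
    return configured_procs or (mp.cpu_count() - 1)

assert pool_size(6) == 6
assert pool_size(0) == mp.cpu_count() - 1
assert pool_size(None) == mp.cpu_count() - 1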