about summary refs log tree commit diff
path: root/gn3/computations
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/computations')
-rw-r--r-- gn3/computations/partial_correlations.py 102
1 files changed, 51 insertions, 51 deletions
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 157024b..52c833f 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -6,6 +6,7 @@ GeneNetwork1.
 """
 
 import math
+import multiprocessing as mp
 from functools import reduce, partial
 from typing import Any, Tuple, Union, Sequence
 
@@ -14,12 +15,12 @@ import pandas
 import pingouin
 from scipy.stats import pearsonr, spearmanr
 
-from gn3.settings import TEXTDIR
 from gn3.random import random_string
 from gn3.function_helpers import  compose
 from gn3.data_helpers import parse_csv_line
 from gn3.db.traits import export_informative
 from gn3.db.datasets import retrieve_trait_dataset
+from gn3.settings import TEXTDIR, MULTIPROCESSOR_PROCS
 from gn3.db.partial_correlations import traits_info, traits_data
 from gn3.db.species import species_name, translate_to_mouse_gene_id
 from gn3.db.correlations import (
@@ -281,6 +282,47 @@ def build_data_frame(
         return interm_df.rename(columns={"z0": "z"})
     return interm_df
 
+def compute_trait_info(primary_vals, control_vals, target, method):
+    """Compute partial and zero-order correlations for one target trait.
+
+    `target` is a `(values, name)` pair.  Positions whose target value is
+    `None` are dropped; returns `None` if fewer than three remain, else
+    `(name, n, partial_r, partial_p, zero_order_r, zero_order_p)`.
+    """
+    targ_vals, targ_name = target[0], target[1]
+    primary = [
+        prim for targ, prim in zip(targ_vals, primary_vals)
+        if targ is not None]
+
+    if len(primary) < 3:
+        return None
+
+    corr_method = "pearson" if "pearson" in method.lower() else "spearman"
+    datafrm = build_data_frame(
+        primary,
+        [targ for targ in targ_vals if targ is not None],
+        # Drop each control's values wherever the target value is missing.
+        [tuple(cont for cont, targ in zip(control, targ_vals)
+               if targ is not None)
+         for control in control_vals])
+    covariates = "z" if datafrm.shape[1] == 3 else [
+        col for col in datafrm.columns if col not in ("x", "y")]
+    ppc = pingouin.partial_corr(
+        data=datafrm, x="x", y="y", covar=covariates, method=corr_method)
+    zero_order = pingouin.corr(datafrm["x"], datafrm["y"], method=corr_method)
+
+    # `.iloc[0]` reads the single result row by position; plain `[0]` relies
+    # on pandas' deprecated integer-key fallback on a non-integer index.
+    pc_coeff, pc_pval = ppc["r"].iloc[0], ppc["p-val"].iloc[0]
+    zero_r, zero_p = zero_order["r"].iloc[0], zero_order["p-val"].iloc[0]
+
+    if math.isnan(pc_coeff):
+        return (targ_name, len(primary), pc_coeff, 1, zero_r, zero_p)
+    if math.isnan(pc_pval):
+        # No usable p-value: report 0 when the coefficient is ~1, else 1.
+        pc_pval = 0 if abs(pc_coeff - 1) < 0.0000001 else 1
+    return (targ_name, len(primary), pc_coeff, pc_pval, zero_r, zero_p)
+
 def compute_partial(
         primary_vals, control_vals, target_vals, target_names,
         method: str) -> Tuple[
@@ -296,57 +338,15 @@ def compute_partial(
 
     This implementation reworks the child function `compute_partial` which will
     then be used in the place of `determinPartialsByR`.
-
-    TODO: moving forward, we might need to use the multiprocessing library to
-    speed up the computations, in case they are found to be slow.
     """
-    # replace the R code with `pingouin.partial_corr`
-    def __compute_trait_info__(target):
-        targ_vals = target[0]
-        targ_name = target[1]
-        primary = [
-            prim for targ, prim in zip(targ_vals, primary_vals)
-            if targ is not None]
-
-        if len(primary) < 3:
-            return None
-
-        def __remove_controls_for_target_nones(cont_targ):
-            return tuple(cont for cont, targ in cont_targ if targ is not None)
-
-        conts_targs = tuple(tuple(
-            zip(control, targ_vals)) for control in control_vals)
-        datafrm = build_data_frame(
-            primary,
-            [targ for targ in targ_vals if targ is not None],
-            [__remove_controls_for_target_nones(cont_targ)
-             for cont_targ in conts_targs])
-        covariates = "z" if datafrm.shape[1] == 3 else [
-            col for col in datafrm.columns if col not in ("x", "y")]
-        ppc = pingouin.partial_corr(
-            data=datafrm, x="x", y="y", covar=covariates, method=(
-                "pearson" if "pearson" in method.lower() else "spearman"))
-        pc_coeff = ppc["r"][0]
-
-        zero_order_corr = pingouin.corr(
-            datafrm["x"], datafrm["y"], method=(
-                "pearson" if "pearson" in method.lower() else "spearman"))
-
-        if math.isnan(pc_coeff):
-            return (
-                targ_name, len(primary), pc_coeff, 1, zero_order_corr["r"][0],
-                zero_order_corr["p-val"][0])
-        return (
-            targ_name, len(primary), pc_coeff,
-            (ppc["p-val"][0] if not math.isnan(ppc["p-val"][0]) else (
-                0 if (abs(pc_coeff - 1) < 0.0000001) else 1)),
-            zero_order_corr["r"][0], zero_order_corr["p-val"][0])
-
-    return tuple(
-        result for result in (
-            __compute_trait_info__(target)
-            for target in zip(target_vals, target_names))
-        if result is not None)
+    # Pool(0) raises ValueError, so never size the pool below one worker.
+    with mp.Pool(MULTIPROCESSOR_PROCS or max(1, mp.cpu_count() - 1)) as pool:
+        return tuple(
+            result for result in pool.starmap(
+                compute_trait_info,
+                ((primary_vals, control_vals, (tvals, tname), method)
+                 for tvals, tname in zip(target_vals, target_names)))
+            if result is not None)
 
 def partial_correlations_normal(# pylint: disable=R0913
         primary_vals, control_vals, input_trait_gene_id, trait_database,