about summary refs log tree commit diff
path: root/gn3/computations/partial_correlations.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/computations/partial_correlations.py')
-rw-r--r--gn3/computations/partial_correlations.py41
1 files changed, 23 insertions, 18 deletions
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 6eee299..8674910 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -16,7 +16,6 @@ import pandas
 import pingouin
 from scipy.stats import pearsonr, spearmanr
 
-from gn3.settings import TEXTDIR
 from gn3.chancy import random_string
 from gn3.function_helpers import  compose
 from gn3.data_helpers import parse_csv_line
@@ -99,7 +98,7 @@ def fix_samples(
         primary_samples,
         tuple(primary_trait_data["data"][sample]["value"]
               for sample in primary_samples),
-        control_vals_vars[0],
+        (control_vals_vars[0],),
         tuple(primary_trait_data["data"][sample]["variance"]
               for sample in primary_samples),
         control_vals_vars[1])
@@ -209,7 +208,7 @@ def good_dataset_samples_indexes(
         samples_from_file.index(good) for good in
         set(samples).intersection(set(samples_from_file))))
 
-def partial_correlations_fast(# pylint: disable=[R0913, R0914]
+def partial_correlations_fast(# pylint: disable=[R0913, R0914, too-many-positional-arguments]
         samples, primary_vals, control_vals, database_filename,
         fetched_correlations, method: str, correlation_type: str) -> Generator:
     """
@@ -334,7 +333,7 @@ def compute_partial(
     This implementation reworks the child function `compute_partial` which will
     then be used in the place of `determinPartialsByR`.
     """
-    with Pool(processes=(cpu_count() - 1)) as pool:
+    with Pool(processes=cpu_count() - 1) as pool:
         return (
             result for result in (
                 pool.starmap(
@@ -345,7 +344,7 @@ def compute_partial(
                      for target in targets)))
         if result is not None)
 
-def partial_correlations_normal(# pylint: disable=R0913
+def partial_correlations_normal(# pylint: disable=[R0913, too-many-positional-arguments]
         primary_vals, control_vals, input_trait_gene_id, trait_database,
         data_start_pos: int, db_type: str, method: str) -> Generator:
     """
@@ -381,7 +380,7 @@ def partial_correlations_normal(# pylint: disable=R0913
 
     return all_correlations
 
-def partial_corrs(# pylint: disable=[R0913]
+def partial_corrs(# pylint: disable=[R0913, too-many-positional-arguments]
         conn, samples, primary_vals, control_vals, return_number, species,
         input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
         method, dataset, database_filename):
@@ -667,10 +666,15 @@ def check_for_common_errors(# pylint: disable=[R0914]
 
     return non_error_result
 
-def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911]
-        conn: Any, primary_trait_name: str,
-        control_trait_names: Tuple[str, ...], method: str,
-        criteria: int, target_db_name: str) -> dict:
+def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911 too-many-positional-arguments]
+        conn: Any,
+        primary_trait_name: str,
+        control_trait_names: Tuple[str, ...],
+        method: str,
+        criteria: int,
+        target_db_name: str,
+        textdir: str
+) -> dict:
     """
     This is the 'ochestration' function for the partial-correlation feature.
 
@@ -755,7 +759,7 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911]
         threshold,
         conn)
 
-    database_filename = get_filename(conn, target_db_name, TEXTDIR)
+    database_filename = get_filename(conn, target_db_name, textdir)
     all_correlations = partial_corrs(
         conn, check_res["common_primary_control_samples"],
         check_res["fixed_primary_values"], check_res["fixed_control_values"],
@@ -837,7 +841,7 @@ def partial_correlations_with_target_traits(
         return check_res
 
     target_traits = {
-        trait["name"]: trait
+        trait["trait_name"]: trait
         for trait in traits_info(conn, threshold, target_trait_names)}
     target_traits_data = traits_data(conn, tuple(target_traits.values()))
 
@@ -854,12 +858,13 @@ def partial_correlations_with_target_traits(
         __merge(
             target_traits[target_name],
             compute_trait_info(
-            check_res["primary_values"], check_res["fixed_control_values"],
-            (export_trait_data(
-                target_data,
-                samplelist=check_res["common_primary_control_samples"]),
-             target_name),
-            method))
+                check_res["primary_values"],
+                check_res["fixed_control_values"],
+                (export_trait_data(
+                    target_data,
+                    samplelist=check_res["common_primary_control_samples"]),
+                 target_name),
+                method))
         for target_name, target_data in target_traits_data.items())
 
     return {