about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-05-05 15:19:08 +0300
committerFrederick Muriuki Muriithi2022-05-05 15:19:08 +0300
commitbe9d1d4aad720274f6d75345123fae8d6a96bc12 (patch)
treea0e314b06c4aee746cac71e73c7ff9bc9d2ce0af
parent3f0b4bf1085f4b28d50318e695da3f2bd739061f (diff)
downloadgenenetwork3-be9d1d4aad720274f6d75345123fae8d6a96bc12.tar.gz
Extract common error checking. Rename function.
* Extract the common error checking code into a separate function
* Rename the function to make its use clearer
-rw-r--r--gn3/computations/partial_correlations.py74
-rwxr-xr-xscripts/partial_correlations.py4
-rw-r--r--tests/integration/test_partial_correlations.py4
3 files changed, 57 insertions, 25 deletions
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 2921852..f82031a 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -555,26 +555,14 @@ def trait_for_output(trait):
     }
     return {key: val for key, val in trait.items() if val is not None}
 
-def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
-        conn: Any, primary_trait_name: str,
-        control_trait_names: Tuple[str, ...], method: str,
-        criteria: int, target_db_name: str) -> dict:
-    """
-    This is the 'ochestration' function for the partial-correlation feature.
-
-    This function will dispatch the functions doing data fetches from the
-    database (and various other places) and feed that data to the functions
-    doing the conversions and computations. It will then return the results of
-    all of that work.
-
-    This function is doing way too much. Look into splitting out the
-    functionality into smaller functions that do fewer things.
-    """
+def check_for_common_errors(conn, primary_trait_name, control_trait_names):
+    """Check for common errors"""
     threshold = 0
     corr_min_informative = 4
+    non_error_result = {"status": "success"}
 
-    all_traits = traits_info(
-        conn, threshold, (primary_trait_name,) + control_trait_names)
+    all_traits = tuple(item for item in traits_info(
+        conn, threshold, (primary_trait_name,) + control_trait_names))
     all_traits_data = traits_data(conn, all_traits)
 
     primary_trait = tuple(
@@ -585,6 +573,8 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
             "status": "not-found",
             "message": f"Could not find primary trait {primary_trait['trait_fullname']}"
         }
+    non_error_result["primary_trait"] = primary_trait
+
     cntrl_traits = tuple(
         trait for trait in all_traits
         if trait["trait_fullname"] != primary_trait_name)
@@ -599,15 +589,25 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
                  "- continuing without it."),
                 category=UserWarning)
 
-    group = primary_trait["db"]["group"]
+    non_error_result["control_traits"] = cntrl_traits
+
+    non_error_result["group"] = group = primary_trait["db"]["group"]
     primary_trait_data = all_traits_data[primary_trait["trait_name"]]
+    non_error_result[""] = primary_trait_data
+
     primary_samples, primary_values, _primary_variances = export_informative(
         primary_trait_data)
+    non_error_result["primary_samples"] = primary_samples
+    non_error_result["primary_values"] = primary_values
+    non_error_result["primary_variances"] = _primary_variances
 
     cntrl_traits_data = tuple(
         data for trait_name, data in all_traits_data.items()
         if trait_name != primary_trait["trait_name"])
+    non_error_result["control_traits_data"] = cntrl_traits_data
+
     species = species_name(conn, group)
+    non_error_result["species"] = species
 
     (cntrl_samples,
      cntrl_values,
@@ -632,6 +632,11 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
                 f"{group} dataset. No calculation of correlation has been "
                 "attempted."),
             "error_type": "Inadequate Samples"}
+    non_error_result["common_primary_control_samples"] = common_primary_control_samples
+    non_error_result["fixed_primary_values"] = fixed_primary_vals
+    non_error_result["fixed_control_values"] = fixed_control_vals
+    non_error_result["primary_variances"] = _primary_variances
+    non_error_result["control_variances"] = _cntrl_variances
 
     identical_traits_names = find_identical_traits(
         primary_trait_name, primary_values, control_trait_names, cntrl_values)
@@ -646,7 +651,32 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
                 "partial correlation cannot be computed. Please re-select your "
                 "traits."),
             "error_type": "Identical Traits"}
+    non_error_result["identical_traits_names"] = identical_traits_names
+
+    return non_error_result
+
+def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911]
+        conn: Any, primary_trait_name: str,
+        control_trait_names: Tuple[str, ...], method: str,
+        criteria: int, target_db_name: str) -> dict:
+    """
+    This is the 'ochestration' function for the partial-correlation feature.
+
+    This function will dispatch the functions doing data fetches from the
+    database (and various other places) and feed that data to the functions
+    doing the conversions and computations. It will then return the results of
+    all of that work.
+
+    This function is doing way too much. Look into splitting out the
+    functionality into smaller functions that do fewer things.
+    """
+
+    check_res = check_for_common_errors(
+        conn, primary_trait_name, control_trait_names)
+    if check_res.get("status") == "error":
+        return error_check_results
 
+    primary_trait = check_res["primary_trait"]
     input_trait_geneid = primary_trait.get("geneid", 0)
     input_trait_symbol = primary_trait.get("symbol", "")
     input_trait_mouse_geneid = translate_to_mouse_gene_id(
@@ -713,10 +743,12 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
 
     database_filename = get_filename(conn, target_db_name, TEXTDIR)
     _total_traits, all_correlations = partial_corrs(
-        conn, common_primary_control_samples, fixed_primary_vals,
-        fixed_control_vals, len(fixed_primary_vals), species,
+        conn, check_res["common_primary_control_samples"],
+        check_res["fixed_primary_values"], check_res["fixed_control_vals"],
+        len(check_res["fixed_primary_vals"]), check_res["species"],
         input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
-        method, {**target_dataset, "dataset_type": target_dataset["type"]}, database_filename)
+        method, {**target_dataset, "dataset_type": target_dataset["type"]},
+        database_filename)
 
 
     def __make_sorter__(method):
diff --git a/scripts/partial_correlations.py b/scripts/partial_correlations.py
index f203daa..52bde4c 100755
--- a/scripts/partial_correlations.py
+++ b/scripts/partial_correlations.py
@@ -5,7 +5,7 @@ from argparse import ArgumentParser
 
 from gn3.db_utils import database_connector
 from gn3.responses.pcorrs_responses import OutputEncoder
-from gn3.computations.partial_correlations import partial_correlations_entry
+from gn3.computations.partial_correlations import partial_correlations_with_target_db
 
 def process_cli_arguments():
     parser = ArgumentParser()
@@ -37,7 +37,7 @@ def cleanup_string(the_str):
 def run_partial_corrs(args):
     with database_connector() as conn:
         try:
-            return partial_correlations_entry(
+            return partial_correlations_with_target_db(
                 conn, cleanup_string(args.primary_trait),
                 tuple(cleanup_string(args.control_traits).split(",")),
                 cleanup_string(args.method), args.criteria,
diff --git a/tests/integration/test_partial_correlations.py b/tests/integration/test_partial_correlations.py
index d249b42..fc9f64f 100644
--- a/tests/integration/test_partial_correlations.py
+++ b/tests/integration/test_partial_correlations.py
@@ -3,7 +3,7 @@ from unittest import mock
 
 import pytest
 
-from gn3.computations.partial_correlations import partial_correlations_entry
+from gn3.computations.partial_correlations import partial_correlations_with_target_db
 
 @pytest.mark.integration_test
 @pytest.mark.parametrize(
@@ -220,5 +220,5 @@ def test_part_corr_api_with_mix_of_existing_and_non_existing_control_traits(
     """
     criteria = 10
     with pytest.warns(UserWarning):
-        partial_correlations_entry(
+        partial_correlations_with_target_db(
             db_conn, primary, controls, method, criteria, target)