about summary refs log tree commit diff
path: root/gn3
diff options
context:
space:
mode:
Diffstat (limited to 'gn3')
-rw-r--r--gn3/computations/partial_correlations.py74
1 files changed, 53 insertions, 21 deletions
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 2921852..f82031a 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -555,26 +555,14 @@ def trait_for_output(trait):
     }
     return {key: val for key, val in trait.items() if val is not None}
 
-def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
-        conn: Any, primary_trait_name: str,
-        control_trait_names: Tuple[str, ...], method: str,
-        criteria: int, target_db_name: str) -> dict:
-    """
-    This is the 'ochestration' function for the partial-correlation feature.
-
-    This function will dispatch the functions doing data fetches from the
-    database (and various other places) and feed that data to the functions
-    doing the conversions and computations. It will then return the results of
-    all of that work.
-
-    This function is doing way too much. Look into splitting out the
-    functionality into smaller functions that do fewer things.
-    """
+def check_for_common_errors(conn, primary_trait_name, control_trait_names):
+    """Check for common errors"""
     threshold = 0
     corr_min_informative = 4
+    non_error_result = {"status": "success"}
 
-    all_traits = traits_info(
-        conn, threshold, (primary_trait_name,) + control_trait_names)
+    all_traits = tuple(item for item in traits_info(
+        conn, threshold, (primary_trait_name,) + control_trait_names))
     all_traits_data = traits_data(conn, all_traits)
 
     primary_trait = tuple(
@@ -585,6 +573,8 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
             "status": "not-found",
             "message": f"Could not find primary trait {primary_trait['trait_fullname']}"
         }
+    non_error_result["primary_trait"] = primary_trait
+
     cntrl_traits = tuple(
         trait for trait in all_traits
         if trait["trait_fullname"] != primary_trait_name)
@@ -599,15 +589,25 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
                  "- continuing without it."),
                 category=UserWarning)
 
-    group = primary_trait["db"]["group"]
+    non_error_result["control_traits"] = cntrl_traits
+
+    non_error_result["group"] = group = primary_trait["db"]["group"]
     primary_trait_data = all_traits_data[primary_trait["trait_name"]]
+    non_error_result[""] = primary_trait_data
+
     primary_samples, primary_values, _primary_variances = export_informative(
         primary_trait_data)
+    non_error_result["primary_samples"] = primary_samples
+    non_error_result["primary_values"] = primary_values
+    non_error_result["primary_variances"] = _primary_variances
 
     cntrl_traits_data = tuple(
         data for trait_name, data in all_traits_data.items()
         if trait_name != primary_trait["trait_name"])
+    non_error_result["control_traits_data"] = cntrl_traits_data
+
     species = species_name(conn, group)
+    non_error_result["species"] = species
 
     (cntrl_samples,
      cntrl_values,
@@ -632,6 +632,11 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
                 f"{group} dataset. No calculation of correlation has been "
                 "attempted."),
             "error_type": "Inadequate Samples"}
+    non_error_result["common_primary_control_samples"] = common_primary_control_samples
+    non_error_result["fixed_primary_values"] = fixed_primary_vals
+    non_error_result["fixed_control_values"] = fixed_control_vals
+    non_error_result["primary_variances"] = _primary_variances
+    non_error_result["control_variances"] = _cntrl_variances
 
     identical_traits_names = find_identical_traits(
         primary_trait_name, primary_values, control_trait_names, cntrl_values)
@@ -646,7 +651,32 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
                 "partial correlation cannot be computed. Please re-select your "
                 "traits."),
             "error_type": "Identical Traits"}
+    non_error_result["identical_traits_names"] = identical_traits_names
+
+    return non_error_result
+
+def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911]
+        conn: Any, primary_trait_name: str,
+        control_trait_names: Tuple[str, ...], method: str,
+        criteria: int, target_db_name: str) -> dict:
+    """
+    This is the 'ochestration' function for the partial-correlation feature.
+
+    This function will dispatch the functions doing data fetches from the
+    database (and various other places) and feed that data to the functions
+    doing the conversions and computations. It will then return the results of
+    all of that work.
+
+    This function is doing way too much. Look into splitting out the
+    functionality into smaller functions that do fewer things.
+    """
+
+    check_res = check_for_common_errors(
+        conn, primary_trait_name, control_trait_names)
+    if check_res.get("status") == "error":
+        return error_check_results
 
+    primary_trait = check_res["primary_trait"]
     input_trait_geneid = primary_trait.get("geneid", 0)
     input_trait_symbol = primary_trait.get("symbol", "")
     input_trait_mouse_geneid = translate_to_mouse_gene_id(
@@ -713,10 +743,12 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
 
     database_filename = get_filename(conn, target_db_name, TEXTDIR)
     _total_traits, all_correlations = partial_corrs(
-        conn, common_primary_control_samples, fixed_primary_vals,
-        fixed_control_vals, len(fixed_primary_vals), species,
+        conn, check_res["common_primary_control_samples"],
+        check_res["fixed_primary_values"], check_res["fixed_control_vals"],
+        len(check_res["fixed_primary_vals"]), check_res["species"],
         input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
-        method, {**target_dataset, "dataset_type": target_dataset["type"]}, database_filename)
+        method, {**target_dataset, "dataset_type": target_dataset["type"]},
+        database_filename)
 
 
     def __make_sorter__(method):