about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- gn3/computations/partial_correlations.py 81
1 files changed, 39 insertions, 42 deletions
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index e6056d5..13def5e 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -20,7 +20,7 @@ from gn3.function_helpers import  compose
 from gn3.data_helpers import parse_csv_line
 from gn3.db.traits import export_informative
 from gn3.db.datasets import retrieve_trait_dataset
-from gn3.db.traits import retrieve_trait_info, retrieve_trait_data
+from gn3.db.partial_correlations import traits_info, traits_data
 from gn3.db.species import species_name, translate_to_mouse_gene_id
 from gn3.db.correlations import (
     get_filename,
@@ -608,18 +608,24 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
     threshold = 0
     corr_min_informative = 4
 
-    primary_trait = retrieve_trait_info(threshold, primary_trait_name, conn)
-    group = primary_trait["group"]
-    primary_trait_data = retrieve_trait_data(primary_trait, conn)
+    all_traits = traits_info(
+        conn, threshold, (primary_trait_name,) + control_trait_names)
+    all_traits_data = traits_data(conn, all_traits)
+
+    primary_trait = tuple(
+        trait for trait in all_traits
+        if trait["trait_fullname"] == primary_trait_name)[0]
+    group = primary_trait["db"]["group"]
+    primary_trait_data = all_traits_data[primary_trait["trait_name"]]
     primary_samples, primary_values, _primary_variances = export_informative(
         primary_trait_data)
 
     cntrl_traits = tuple(
-        retrieve_trait_info(threshold, trait_full_name, conn)
-        for trait_full_name in control_trait_names)
+        trait for trait in all_traits
+        if trait["trait_fullname"] != primary_trait_name)
     cntrl_traits_data = tuple(
-        retrieve_trait_data(cntrl_trait, conn)
-        for cntrl_trait in cntrl_traits)
+        data for trait_name, data in all_traits_data.items()
+        if trait_name != primary_trait["trait_name"])
     species = species_name(conn, group)
 
     (cntrl_samples,
@@ -660,8 +666,8 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
                 "traits."),
             "error_type": "Identical Traits"}
 
-    input_trait_geneid = primary_trait.get("geneid")
-    input_trait_symbol = primary_trait.get("symbol")
+    input_trait_geneid = primary_trait.get("geneid", 0)
+    input_trait_symbol = primary_trait.get("symbol", "")
     input_trait_mouse_geneid = translate_to_mouse_gene_id(
         species, input_trait_geneid, conn)
 
@@ -682,7 +688,7 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
             "error_type": "Correlation Type"}
 
     if (method.lower() == "sgo literature correlation" and (
-            input_trait_geneid is None or
+            bool(input_trait_geneid) is False or
             check_for_literature_info(conn, input_trait_mouse_geneid))):
         return {
             "status": "error",
@@ -695,7 +701,7 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
             method.lower() in (
                 "tissue correlation, pearson's r",
                 "tissue correlation, spearman's rho")
-            and input_trait_symbol is None):
+            and bool(input_trait_symbol) is False):
         return {
             "status": "error",
             "message": (
@@ -733,33 +739,19 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
 
 
     def __make_sorter__(method):
-        def __compare_lit_or_tiss_correlation_values_(row):
-            # Index  Content
-            # 0      trait name
-            # 1      N
-            # 2      partial correlation coefficient
-            # 3      p value of partial correlation
-            # 6      literature/tissue correlation value
-            return (row[6], row[3])
-
-        def __compare_partial_correlation_p_values__(row):
-            # Index  Content
-            # 0      trait name
-            # 1      partial correlation coefficient
-            # 2      N
-            # 3      p value of partial correlation
+        def __sort_6__(row):
+            return row[6]
+
+        def __sort_3__(row):
             return row[3]
 
         if "literature" in method.lower():
-            return __compare_lit_or_tiss_correlation_values_
+            return __sort_6__
 
         if "tissue" in method.lower():
-            return __compare_lit_or_tiss_correlation_values_
-
-        return __compare_partial_correlation_p_values__
+            return __sort_6__
 
-    sorted_correlations = sorted(
-        all_correlations, key=__make_sorter__(method))
+        return __sort_3__
 
     add_lit_corr_and_tiss_corr = compose(
         partial(literature_correlation_by_list, conn, species),
@@ -767,12 +759,11 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
             tissue_correlation_by_list, conn, input_trait_symbol,
             tissue_probeset_freeze_id, method))
 
-    trait_list = add_lit_corr_and_tiss_corr(tuple(
-        {
-            **retrieve_trait_info(
-                threshold,
-                f"{target_dataset['dataset_name']}::{item[0]}",
-                conn),
+    selected_results = sorted(
+        all_correlations,
+        key=__make_sorter__(method))[:min(criteria, len(all_correlations))]
+    traits_list_corr_info = {
+        "{target_dataset['dataset_name']}::{item[0]}": {
             "noverlap": item[1],
             "partial_corr": item[2],
             "partial_corr_p_value": item[3],
@@ -785,9 +776,15 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
                if len(item) == 8 else {}),
             **({"l_corr": item[6]}
                if len(item) == 7 else {})
-        }
-        for item in
-        sorted_correlations[:min(criteria, len(all_correlations))]))
+        } for item in selected_results}
+
+    trait_list = add_lit_corr_and_tiss_corr(tuple(
+        {**trait, **traits_list_corr_info.get(trait["trait_fullname"], {})}
+        for trait in traits_info(
+            conn, threshold,
+            tuple(
+                f"{target_dataset['dataset_name']}::{item[0]}"
+                for item in selected_results))))
 
     return {
         "status": "success",