about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/api/correlation.py4
-rw-r--r--gn3/computations/correlations.py26
-rw-r--r--tests/integration/test_correlation.py2
-rw-r--r--tests/unit/computations/test_correlation.py14
4 files changed, 25 insertions, 21 deletions
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index e7e89cf..a3e366e 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -5,7 +5,7 @@ from flask import request
 
 from gn3.computations.correlations import compute_all_sample_correlation
 from gn3.computations.correlations import compute_all_lit_correlation
-from gn3.computations.correlations import compute_all_tissue_correlation
+from gn3.computations.correlations import compute_tissue_correlation
 from gn3.computations.correlations import map_shared_keys_to_values
 from gn3.db_utils import database_connector
 
@@ -78,7 +78,7 @@ def compute_tissue_corr(corr_method="pearson"):
     primary_tissue_dict = tissue_input_data["primary_tissue"]
     target_tissues_dict = tissue_input_data["target_tissues_dict"]
 
-    results = compute_all_tissue_correlation(primary_tissue_dict=primary_tissue_dict,
+    results = compute_tissue_correlation(primary_tissue_dict=primary_tissue_dict,
                                              target_tissues_data=target_tissues_dict,
                                              corr_method=corr_method)
 
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
index 56f483c..1fd3213 100644
--- a/gn3/computations/correlations.py
+++ b/gn3/computations/correlations.py
@@ -124,11 +124,12 @@ def filter_shared_sample_keys(this_samplelist,
     return (this_vals, target_vals)
 
 
-def compute_all_sample_correlation(this_trait,
-                                   target_dataset,
-                                   corr_method="pearson") -> List:
+def speed_compute_all_sample_correlation(this_trait,
+                                         target_dataset,
+                                         corr_method="pearson") -> List:
     """Given a trait data sample-list and target__datasets compute all sample
     correlation
+    this functions uses multiprocessing if not use the normal fun
 
     """
     # xtodo fix trait_name currently returning single one
@@ -160,9 +161,9 @@ def compute_all_sample_correlation(this_trait,
         key=lambda trait_name: -abs(list(trait_name.values())[0]["corr_coefficient"]))
 
 
-def benchmark_compute_all_sample(this_trait,
-                                 target_dataset,
-                                 corr_method="pearson") -> List:
+def compute_all_sample_correlation(this_trait,
+                                   target_dataset,
+                                   corr_method="pearson") -> List:
     """Temp function to benchmark with compute_all_sample_r alternative to
     compute_all_sample_r where we use multiprocessing
 
@@ -174,6 +175,7 @@ def benchmark_compute_all_sample(this_trait,
         target_trait_data = target_trait["trait_sample_data"]
         this_vals, target_vals = filter_shared_sample_keys(
             this_trait_samples, target_trait_data)
+
         sample_correlation = compute_sample_r_correlation(
             trait_name=trait_name,
             corr_method=corr_method,
@@ -190,7 +192,9 @@ def benchmark_compute_all_sample(this_trait,
             "num_overlap": num_overlap
         }
         corr_results.append({trait_name: corr_result})
-    return corr_results
+    return sorted(
+        corr_results,
+        key=lambda trait_name: -abs(list(trait_name.values())[0]["corr_coefficient"]))
 
 
 def tissue_correlation_for_trait(
@@ -336,7 +340,7 @@ def compute_all_lit_correlation(conn, trait_lists: List,
     return sorted_lit_results
 
 
-def compute_all_tissue_correlation(primary_tissue_dict: dict,
+def compute_tissue_correlation(primary_tissue_dict: dict,
                                    target_tissues_data: dict,
                                    corr_method: str):
     """Function acts as an abstraction for tissue_correlation_for_trait\
@@ -382,9 +386,9 @@ def process_trait_symbol_dict(trait_symbol_dict, symbol_tissue_vals_dict) -> Lis
     return traits_tissue_vals
 
 
-def compute_tissue_correlation(primary_tissue_dict: dict,
-                               target_tissues_data: dict,
-                               corr_method: str):
+def speed_compute_tissue_correlation(primary_tissue_dict: dict,
+                                     target_tissues_data: dict,
+                                     corr_method: str):
     """Experimental function that uses multiprocessing for computing tissue
     correlation
 
diff --git a/tests/integration/test_correlation.py b/tests/integration/test_correlation.py
index e67f58d..bdd9bce 100644
--- a/tests/integration/test_correlation.py
+++ b/tests/integration/test_correlation.py
@@ -80,7 +80,7 @@ class CorrelationIntegrationTest(TestCase):
         self.assertEqual(mock_compute_corr.call_count, 1)
         self.assertEqual(response.status_code, 200)
 
-    @mock.patch("gn3.api.correlation.compute_all_tissue_correlation")
+    @mock.patch("gn3.api.correlation.compute_tissue_correlation")
     def test_tissue_correlation(self, mock_tissue_corr):
         """Test api/correlation/tissue_corr/{corr_method}"""
         mock_tissue_corr.return_value = {}
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py
index 9450094..f2d65bd 100644
--- a/tests/unit/computations/test_correlation.py
+++ b/tests/unit/computations/test_correlation.py
@@ -1,5 +1,4 @@
 """Module contains the tests for correlation"""
-import unittest
 from unittest import TestCase
 from unittest import mock
 
@@ -16,7 +15,7 @@ from gn3.computations.correlations import fetch_lit_correlation_data
 from gn3.computations.correlations import query_formatter
 from gn3.computations.correlations import map_to_mouse_gene_id
 from gn3.computations.correlations import compute_all_lit_correlation
-from gn3.computations.correlations import compute_all_tissue_correlation
+from gn3.computations.correlations import compute_tissue_correlation
 from gn3.computations.correlations import map_shared_keys_to_values
 from gn3.computations.correlations import process_trait_symbol_dict
 from gn3.computations.correlations2 import compute_correlation
@@ -173,7 +172,6 @@ class TestCorrelation(TestCase):
         self.assertEqual(results, (filtered_this_samplelist,
                                    filtered_target_samplelist))
 
-    @unittest.skip("Test needs to be refactored ")
     @mock.patch("gn3.computations.correlations.compute_sample_r_correlation")
     @mock.patch("gn3.computations.correlations.filter_shared_sample_keys")
     def test_compute_all_sample(self, filter_shared_samples, sample_r_corr):
@@ -181,7 +179,7 @@ class TestCorrelation(TestCase):
 
         filter_shared_samples.return_value = (["1.23", "6.565", "6.456"], [
             "6.266", "6.565", "6.456"])
-        sample_r_corr.return_value = ([-1.0, 0.9, 6])
+        sample_r_corr.return_value = (["1419792_at", -1.0, 0.9, 6])
 
         this_trait_data = {
             "trait_id": "1455376_at",
@@ -204,13 +202,14 @@ class TestCorrelation(TestCase):
             }
         ]
 
-        sample_all_results = [{"1419792_at": {"corr_coeffient": -1.0,
+        sample_all_results = [{"1419792_at": {"corr_coefficient": -1.0,
                                               "p_value": 0.9,
                                               "num_overlap": 6}}]
 
         self.assertEqual(compute_all_sample_correlation(
             this_trait=this_trait_data, target_dataset=traits_dataset), sample_all_results)
         sample_r_corr.assert_called_once_with(
+            trait_name='1419792_at',
             corr_method="pearson", trait_vals=['1.23', '6.565', '6.456'],
             target_samples_vals=['6.266', '6.565', '6.456'])
         filter_shared_samples.assert_called_once_with(
@@ -417,7 +416,7 @@ class TestCorrelation(TestCase):
                             {"1418702_a_at":
                              {"tissue_corr": -0.5, "tissue_p_val": 0.9, "tissue_number": 3}}]
 
-        results = compute_all_tissue_correlation(
+        results = compute_tissue_correlation(
             primary_tissue_dict=primary_tissue_dict,
             target_tissues_data=target_tissue_data,
             corr_method="pearson")
@@ -491,4 +490,5 @@ class TestCorrelation(TestCase):
                  [None, None, None, None, 2, None, None, 3, None, None],
                  (0.0, 2)]]:
             with self.subTest(dbdata=dbdata, userdata=userdata):
-                self.assertEqual(compute_correlation(dbdata, userdata), expected)
+                self.assertEqual(compute_correlation(
+                    dbdata, userdata), expected)