about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/computations/correlations.py5
-rw-r--r--mypy.ini3
-rw-r--r--tests/integration/test_correlation.py2
-rw-r--r--tests/unit/computations/test_correlation.py11
4 files changed, 14 insertions, 7 deletions
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
index ea1a862..5a9ce62 100644
--- a/gn3/computations/correlations.py
+++ b/gn3/computations/correlations.py
@@ -162,7 +162,8 @@ def fast_compute_all_sample_correlation(this_trait,
         corr_results,
         key=lambda trait_name: -abs(list(trait_name.values())[0]["corr_coefficient"]))
 
-def __corr_compute__(trait_samples, target_trait, corr_method):
+def compute_one_sample_correlation(trait_samples, target_trait, corr_method):
+    """Compute sample correlation against a single trait."""
     trait_name = target_trait.get("trait_id")
     target_trait_data = target_trait["trait_sample_data"]
     try:
@@ -200,7 +201,7 @@ def compute_all_sample_correlation(this_trait,
             (
                 corr for corr in
                 pool.starmap(
-                    __corr_compute__,
+                    compute_one_sample_correlation,
                     ((this_trait_samples, trait, corr_method) for trait in target_dataset))
                 if corr is not None),
             key=lambda trait_name: -abs(
diff --git a/mypy.ini b/mypy.ini
index 8266756..4465656 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -41,3 +41,6 @@ ignore_missing_imports = True
 
 [mypy-sklearn.*]
 ignore_missing_imports = True
+
+[mypy-scripts.argparse_actions.*]
+ignore_missing_imports = True
diff --git a/tests/integration/test_correlation.py b/tests/integration/test_correlation.py
index d52ab01..c1d518d 100644
--- a/tests/integration/test_correlation.py
+++ b/tests/integration/test_correlation.py
@@ -12,7 +12,7 @@ class CorrelationIntegrationTest(TestCase):
         self.app = create_app().test_client()
 
     @pytest.mark.integration_test
-    @mock.patch("gn3.api.correlation.compute_all_sample_correlation")
+    @mock.patch("gn3.api.correlation.run_sample_corr_cmd")
     def test_sample_r_correlation(self, mock_compute_samples):
         """Test /api/correlation/sample_r/{method}"""
         this_trait_data = {
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py
index 267ced3..e8d4f75 100644
--- a/tests/unit/computations/test_correlation.py
+++ b/tests/unit/computations/test_correlation.py
@@ -9,7 +9,7 @@ from numpy.testing import assert_almost_equal
 
 from gn3.computations.correlations import normalize_values
 from gn3.computations.correlations import compute_sample_r_correlation
-from gn3.computations.correlations import compute_all_sample_correlation
+from gn3.computations.correlations import compute_one_sample_correlation
 from gn3.computations.correlations import filter_shared_sample_keys
 
 from gn3.computations.correlations import tissue_correlation_for_trait
@@ -164,7 +164,7 @@ class TestCorrelation(TestCase):
     @pytest.mark.unit_test
     @mock.patch("gn3.computations.correlations.compute_sample_r_correlation")
     @mock.patch("gn3.computations.correlations.filter_shared_sample_keys")
-    def test_compute_all_sample(self, filter_shared_samples, sample_r_corr):
+    def test_compute_one_sample(self, filter_shared_samples, sample_r_corr):
         """Given target dataset compute all sample r correlation"""
 
         filter_shared_samples.return_value = [iter(val) for val in [(
@@ -197,8 +197,11 @@ class TestCorrelation(TestCase):
                                               "p_value": 0.9,
                                               "num_overlap": 6}}]
 
-        self.assertEqual(compute_all_sample_correlation(
-            this_trait=this_trait_data, target_dataset=traits_dataset), sample_all_results)
+        self.assertEqual(
+            compute_one_sample_correlation(
+                this_trait_data["trait_sample_data"],
+                traits_dataset[0], "pearson"),
+            sample_all_results[0])
         sample_r_corr.assert_called_once_with(
             trait_name='1419792_at',
             corr_method="pearson", trait_vals=('1.23', '6.565', '6.456'),