diff options
Diffstat (limited to 'tests/unit')
-rw-r--r-- | tests/unit/computations/test_correlation.py | 39 | ||||
-rw-r--r-- | tests/unit/test_heatmaps.py | 3 |
2 files changed, 25 insertions, 17 deletions
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py index 96d9c6d..d60dd62 100644 --- a/tests/unit/computations/test_correlation.py +++ b/tests/unit/computations/test_correlation.py @@ -1,13 +1,17 @@ """Module contains the tests for correlation""" from unittest import TestCase from unittest import mock +import unittest from collections import namedtuple +import math +from numpy.testing import assert_almost_equal from gn3.computations.correlations import normalize_values from gn3.computations.correlations import compute_sample_r_correlation from gn3.computations.correlations import compute_all_sample_correlation from gn3.computations.correlations import filter_shared_sample_keys + from gn3.computations.correlations import tissue_correlation_for_trait from gn3.computations.correlations import lit_correlation_for_trait from gn3.computations.correlations import fetch_lit_correlation_data @@ -93,10 +97,11 @@ class TestCorrelation(TestCase): results = normalize_values([2.3, None, None, 3.2, 4.1, 5], [3.4, 7.2, 1.3, None, 6.2, 4.1]) - expected_results = ([2.3, 4.1, 5], [3.4, 6.2, 4.1], 3) + expected_results = [(2.3, 4.1, 5), (3.4, 6.2, 4.1)] - self.assertEqual(results, expected_results) + self.assertEqual(list(zip(*list(results))), expected_results) + @unittest.skip("reason for skipping") @mock.patch("gn3.computations.correlations.compute_corr_coeff_p_value") @mock.patch("gn3.computations.correlations.normalize_values") def test_compute_sample_r_correlation(self, norm_vals, compute_corr): @@ -152,22 +157,23 @@ class TestCorrelation(TestCase): } - filtered_target_samplelist = ["1.23", "6.565", "6.456"] - filtered_this_samplelist = ["6.266", "6.565", "6.456"] + filtered_target_samplelist = ("1.23", "6.565", "6.456") + filtered_this_samplelist = ("6.266", "6.565", "6.456") results = filter_shared_sample_keys( this_samplelist=this_samplelist, target_samplelist=target_samplelist) - self.assertEqual(results, (filtered_this_samplelist, - filtered_target_samplelist)) + self.assertEqual(list(zip(*list(results))), [filtered_this_samplelist, + filtered_target_samplelist]) @mock.patch("gn3.computations.correlations.compute_sample_r_correlation") @mock.patch("gn3.computations.correlations.filter_shared_sample_keys") def test_compute_all_sample(self, filter_shared_samples, sample_r_corr): """Given target dataset compute all sample r correlation""" - filter_shared_samples.return_value = (["1.23", "6.565", "6.456"], [ - "6.266", "6.565", "6.456"]) + filter_shared_samples.return_value = [iter(val) for val in [( + "1.23", "6.266"), ("6.565", "6.565"), ("6.456", "6.456")]] + sample_r_corr.return_value = (["1419792_at", -1.0, 0.9, 6]) this_trait_data = { @@ -199,10 +205,8 @@ class TestCorrelation(TestCase): this_trait=this_trait_data, target_dataset=traits_dataset), sample_all_results) sample_r_corr.assert_called_once_with( trait_name='1419792_at', - corr_method="pearson", trait_vals=['1.23', '6.565', '6.456'], - target_samples_vals=['6.266', '6.565', '6.456']) - filter_shared_samples.assert_called_once_with( - this_trait_data.get("trait_sample_data"), traits_dataset[0].get("trait_sample_data")) + corr_method="pearson", trait_vals=('1.23', '6.565', '6.456'), + target_samples_vals=('6.266', '6.565', '6.456')) @mock.patch("gn3.computations.correlations.compute_corr_coeff_p_value") def test_tissue_correlation_for_trait(self, mock_compute_corr_coeff): @@ -468,10 +472,10 @@ class TestCorrelation(TestCase): [None, None, None, None, None, None, None, None, None, 0], (0.0, 1)], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - (0, 10)], + (math.nan, 10)], [[9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87], [9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87], - (0.9999999999999998, 10)], + (math.nan, 10)], [[9.3, 2.2, 5.4, 7.2, 6.4, 7.6, 3.8, 1.8, 8.4, 0.2], [0.6, 3.97, 5.82, 8.21, 1.65, 4.55, 6.72, 9.5, 7.33, 2.34], (-0.12720361919462056, 10)], @@ -479,5 +483,8 @@ class TestCorrelation(TestCase): [None, None, None, None, 2, None, None, 3, None, None], (0.0, 2)]]: with self.subTest(dbdata=dbdata, userdata=userdata): - self.assertEqual(compute_correlation( - dbdata, userdata), expected) + actual = compute_correlation(dbdata, userdata) + with self.subTest("correlation coefficient"): + assert_almost_equal(actual[0], expected[0]) + with self.subTest("overlap"): + self.assertEqual(actual[1], expected[1]) diff --git a/tests/unit/test_heatmaps.py b/tests/unit/test_heatmaps.py index 03fd4a6..e4c929d 100644 --- a/tests/unit/test_heatmaps.py +++ b/tests/unit/test_heatmaps.py @@ -1,5 +1,6 @@ """Module contains tests for gn3.heatmaps.heatmaps""" from unittest import TestCase +from numpy.testing import assert_allclose from gn3.heatmaps import ( cluster_traits, get_loci_names, @@ -39,7 +40,7 @@ class TestHeatmap(TestCase): (6.84118, 7.08432, 7.59844, 7.08229, 7.26774, 7.24991), (9.45215, 10.6943, 8.64719, 10.1592, 7.75044, 8.78615), (7.04737, 6.87185, 7.58586, 6.92456, 6.84243, 7.36913)] - self.assertEqual( + assert_allclose( cluster_traits(traits_data_list), ((0.0, 0.20337048635536847, 0.16381088984330505, 1.7388553629398245, 1.5025235756329178, 0.6952839500255574, 1.271661230252733, |