diff options
Diffstat (limited to 'tests/unit')
-rw-r--r-- | tests/unit/computations/test_correlation.py | 48 | ||||
-rw-r--r-- | tests/unit/computations/test_heatmap.py | 143 | ||||
-rw-r--r-- | tests/unit/computations/test_slink.py | 311 | ||||
-rw-r--r-- | tests/unit/db/test_datasets.py | 133 | ||||
-rw-r--r-- | tests/unit/db/test_traits.py | 224 |
5 files changed, 849 insertions, 10 deletions
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py index b1bc6ef..fc52ec1 100644 --- a/tests/unit/computations/test_correlation.py +++ b/tests/unit/computations/test_correlation.py @@ -1,5 +1,4 @@ """Module contains the tests for correlation""" -import unittest from unittest import TestCase from unittest import mock @@ -16,9 +15,10 @@ from gn3.computations.correlations import fetch_lit_correlation_data from gn3.computations.correlations import query_formatter from gn3.computations.correlations import map_to_mouse_gene_id from gn3.computations.correlations import compute_all_lit_correlation -from gn3.computations.correlations import compute_all_tissue_correlation +from gn3.computations.correlations import compute_tissue_correlation from gn3.computations.correlations import map_shared_keys_to_values from gn3.computations.correlations import process_trait_symbol_dict +from gn3.computations.correlations2 import compute_correlation class QueryableMixin: @@ -172,7 +172,6 @@ class TestCorrelation(TestCase): self.assertEqual(results, (filtered_this_samplelist, filtered_target_samplelist)) - @unittest.skip("Test needs to be refactored ") @mock.patch("gn3.computations.correlations.compute_sample_r_correlation") @mock.patch("gn3.computations.correlations.filter_shared_sample_keys") def test_compute_all_sample(self, filter_shared_samples, sample_r_corr): @@ -180,7 +179,7 @@ class TestCorrelation(TestCase): filter_shared_samples.return_value = (["1.23", "6.565", "6.456"], [ "6.266", "6.565", "6.456"]) - sample_r_corr.return_value = ([-1.0, 0.9, 6]) + sample_r_corr.return_value = (["1419792_at", -1.0, 0.9, 6]) this_trait_data = { "trait_id": "1455376_at", @@ -203,13 +202,14 @@ class TestCorrelation(TestCase): } ] - sample_all_results = [{"1419792_at": {"corr_coeffient": -1.0, + sample_all_results = [{"1419792_at": {"corr_coefficient": -1.0, "p_value": 0.9, "num_overlap": 6}}] self.assertEqual(compute_all_sample_correlation( this_trait=this_trait_data, target_dataset=traits_dataset), sample_all_results) sample_r_corr.assert_called_once_with( + trait_name='1419792_at', corr_method="pearson", trait_vals=['1.23', '6.565', '6.456'], target_samples_vals=['6.266', '6.565', '6.456']) filter_shared_samples.assert_called_once_with( @@ -406,17 +406,17 @@ class TestCorrelation(TestCase): target_tissue_data = {"trait_symbol_dict": target_trait_symbol, "symbol_tissue_vals_dict": target_symbol_tissue_vals} - mock_tissue_corr.side_effect = [{"tissue_corr": -0.5, "tissue_p_val": 0.9, - "tissue_number": 3}, - {"tissue_corr": 1.11, "tissue_p_val": 0.2, - "tissue_number": 3}] + mock_tissue_corr.side_effect = [{"1418702_a_at": {"tissue_corr": -0.5, "tissue_p_val": 0.9, + "tissue_number": 3}}, + {"1412_at": {"tissue_corr": 1.11, "tissue_p_val": 0.2, + "tissue_number": 3}}] expected_results = [{"1412_at": {"tissue_corr": 1.11, "tissue_p_val": 0.2, "tissue_number": 3}}, {"1418702_a_at": {"tissue_corr": -0.5, "tissue_p_val": 0.9, "tissue_number": 3}}] - results = compute_all_tissue_correlation( + results = compute_tissue_correlation( primary_tissue_dict=primary_tissue_dict, target_tissues_data=target_tissue_data, corr_method="pearson") @@ -464,3 +464,31 @@ class TestCorrelation(TestCase): trait_symbol_dict, tissue_values_dict) self.assertEqual(results, [expected_results]) + + def test_compute_correlation(self): + """Test that the new correlation function works the same as the original + from genenetwork1.""" + for dbdata, userdata, expected in [ + [[None, None, None, None, None, None, None, None, None, None], + [None, None, None, None, None, None, None, None, None, None], + (0.0, 0)], + [[None, None, None, None, None, None, None, None, None, 0], + [None, None, None, None, None, None, None, None, None, None], + (0.0, 0)], + [[None, None, None, None, None, None, None, None, None, 0], + [None, None, None, None, None, None, None, None, None, 0], + (0.0, 1)], + [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + (0, 10)], + [[9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87], + [9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87], + (0.9999999999999998, 10)], + [[9.3, 2.2, 5.4, 7.2, 6.4, 7.6, 3.8, 1.8, 8.4, 0.2], + [0.6, 3.97, 5.82, 8.21, 1.65, 4.55, 6.72, 9.5, 7.33, 2.34], + (-0.12720361919462056, 10)], + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [None, None, None, None, 2, None, None, 3, None, None], + (0.0, 2)]]: + with self.subTest(dbdata=dbdata, userdata=userdata): + self.assertEqual(compute_correlation( + dbdata, userdata), expected) diff --git a/tests/unit/computations/test_heatmap.py b/tests/unit/computations/test_heatmap.py new file mode 100644 index 0000000..650cb45 --- /dev/null +++ b/tests/unit/computations/test_heatmap.py @@ -0,0 +1,143 @@ +"""Module contains tests for gn3.computations.heatmap""" +from unittest import TestCase +from gn3.computations.heatmap import cluster_traits, export_trait_data + +strainlist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"] +trait_data = { + "mysqlid": 36688172, + "data": { + "B6cC3-1": {"strain_name": "B6cC3-1", "value": 7.51879, "variance": None, "ndata": None}, + "BXD1": {"strain_name": "BXD1", "value": 7.77141, "variance": None, "ndata": None}, + "BXD12": {"strain_name": "BXD12", "value": 8.39265, "variance": None, "ndata": None}, + "BXD16": {"strain_name": "BXD16", "value": 8.17443, "variance": None, "ndata": None}, + "BXD19": {"strain_name": "BXD19", "value": 8.30401, "variance": None, "ndata": None}, + "BXD2": {"strain_name": "BXD2", "value": 7.80944, "variance": None, "ndata": None}, + "BXD21": {"strain_name": "BXD21", "value": 8.93809, "variance": None, "ndata": None}, + "BXD24": {"strain_name": "BXD24", "value": 7.99415, "variance": None, "ndata": None}, + "BXD27": {"strain_name": "BXD27", "value": 8.12177, "variance": None, "ndata": None}, + "BXD28": {"strain_name": "BXD28", "value": 7.67688, "variance": None, "ndata": None}, + "BXD32": {"strain_name": "BXD32", "value": 7.79062, "variance": None, "ndata": None}, + "BXD39": {"strain_name": "BXD39", "value": 8.27641, "variance": None, "ndata": None}, + "BXD40": {"strain_name": "BXD40", "value": 8.18012, "variance": None, "ndata": None}, + "BXD42": {"strain_name": "BXD42", "value": 7.82433, "variance": None, "ndata": None}, + "BXD6": {"strain_name": "BXD6", "value": 8.09718, "variance": None, "ndata": None}, + "BXH14": {"strain_name": "BXH14", "value": 7.97475, "variance": None, "ndata": None}, + "BXH19": {"strain_name": "BXH19", "value": 7.67223, "variance": None, "ndata": None}, + "BXH2": {"strain_name": "BXH2", "value": 7.93622, "variance": None, "ndata": None}, + "BXH22": {"strain_name": "BXH22", "value": 7.43692, "variance": None, "ndata": None}, + "BXH4": {"strain_name": "BXH4", "value": 7.96336, "variance": None, "ndata": None}, + "BXH6": {"strain_name": "BXH6", "value": 7.75132, "variance": None, "ndata": None}, + "BXH7": {"strain_name": "BXH7", "value": 8.12927, "variance": None, "ndata": None}, + "BXH8": {"strain_name": "BXH8", "value": 6.77338, "variance": None, "ndata": None}, + "BXH9": {"strain_name": "BXH9", "value": 8.03836, "variance": None, "ndata": None}, + "C3H/HeJ": {"strain_name": "C3H/HeJ", "value": 7.42795, "variance": None, "ndata": None}, + "C57BL/6J": {"strain_name": "C57BL/6J", "value": 7.50606, "variance": None, "ndata": None}, + "DBA/2J": {"strain_name": "DBA/2J", "value": 7.72588, "variance": None, "ndata": None}}} + +class TestHeatmap(TestCase): + """Class for testing heatmap computation functions""" + + def test_export_trait_data_dtype(self): + """ + Test `export_trait_data` with different values for the `dtype` keyword + argument + """ + for dtype, expected in [ + ["val", (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["var", (None, None, None, None, None, None)], + ["N", (None, None, None, None, None, None)], + ["all", (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)]]: + with self.subTest(dtype=dtype): + self.assertEqual( + export_trait_data(trait_data, strainlist, dtype=dtype), + expected) + + def test_export_trait_data_dtype_all_flags(self): + """ + Test `export_trait_data` with different values for the `dtype` keyword + argument and the different flags set up + """ + for dtype, vflag, nflag, expected in [ + ["val", False, False, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["val", False, True, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["val", True, False, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["val", True, True, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["var", False, False, (None, None, None, None, None, None)], + ["var", False, True, (None, None, None, None, None, None)], + ["var", True, False, (None, None, None, None, None, None)], + ["var", True, True, (None, None, None, None, None, None)], + ["N", False, False, (None, None, None, None, None, None)], + ["N", False, True, (None, None, None, None, None, None)], + ["N", True, False, (None, None, None, None, None, None)], + ["N", True, True, (None, None, None, None, None, None)], + ["all", False, False, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["all", False, True, + (7.51879, None, 7.77141, None, 8.39265, None, 8.17443, None, + 8.30401, None, 7.80944, None)], + ["all", True, False, + (7.51879, None, 7.77141, None, 8.39265, None, 8.17443, None, + 8.30401, None, 7.80944, None)], + ["all", True, True, + (7.51879, None, None, 7.77141, None, None, 8.39265, None, None, + 8.17443, None, None, 8.30401, None, None, 7.80944, None, None)] + ]: + with self.subTest(dtype=dtype, vflag=vflag, nflag=nflag): + self.assertEqual( + export_trait_data( + trait_data, strainlist, dtype=dtype, var_exists=vflag, + n_exists=nflag), + expected) + + def test_cluster_traits(self): + """ + Test that the clustering is working as expected. + """ + traits_data_list = [ + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944), + (6.1427, 6.50588, 7.73705, 6.68328, 7.49293, 7.27398), + (8.4211, 8.30581, 9.24076, 8.51173, 9.18455, 8.36077), + (10.0904, 10.6509, 9.36716, 9.91202, 8.57444, 10.5731), + (10.188, 9.76652, 9.54813, 9.05074, 9.52319, 9.10505), + (6.74676, 7.01029, 7.54169, 6.48574, 7.01427, 7.26815), + (6.39359, 6.85321, 5.78337, 7.11141, 6.22101, 6.16544), + (6.84118, 7.08432, 7.59844, 7.08229, 7.26774, 7.24991), + (9.45215, 10.6943, 8.64719, 10.1592, 7.75044, 8.78615), + (7.04737, 6.87185, 7.58586, 6.92456, 6.84243, 7.36913)] + self.assertEqual( + cluster_traits(traits_data_list), + ((0.0, 0.20337048635536847, 0.16381088984330505, 1.7388553629398245, + 1.5025235756329178, 0.6952839500255574, 1.271661230252733, + 0.2100487290977544, 1.4699690641062024, 0.7934461515867415), + (0.20337048635536847, 0.0, 0.2198321044997198, 1.5753041735592204, + 1.4815755944537086, 0.26087293140686374, 1.6939790104301427, + 0.06024619831474998, 1.7430082449189215, 0.4497104244247795), + (0.16381088984330505, 0.2198321044997198, 0.0, 1.9073926868549234, + 1.0396738891139845, 0.5278328671176757, 1.6275069061182947, + 0.2636503792482082, 1.739617877037615, 0.7127042590637039), + (1.7388553629398245, 1.5753041735592204, 1.9073926868549234, 0.0, + 0.9936846292920328, 1.1169999189889366, 0.6007483980555253, + 1.430209221053372, 0.25879514152086425, 0.9313185954797953), + (1.5025235756329178, 1.4815755944537086, 1.0396738891139845, + 0.9936846292920328, 0.0, 1.027827186339337, 1.1441743109173244, + 1.4122477962364253, 0.8968250491499363, 1.1683723389247052), + (0.6952839500255574, 0.26087293140686374, 0.5278328671176757, + 1.1169999189889366, 1.027827186339337, 0.0, 1.8420471110023269, + 0.19179284676938602, 1.4875072385631605, 0.23451785425383564), + (1.271661230252733, 1.6939790104301427, 1.6275069061182947, + 0.6007483980555253, 1.1441743109173244, 1.8420471110023269, 0.0, + 1.6540234785929928, 0.2140799896286565, 1.7413442197913358), + (0.2100487290977544, 0.06024619831474998, 0.2636503792482082, + 1.430209221053372, 1.4122477962364253, 0.19179284676938602, + 1.6540234785929928, 0.0, 1.5225640692832796, 0.33370067057028485), + (1.4699690641062024, 1.7430082449189215, 1.739617877037615, + 0.25879514152086425, 0.8968250491499363, 1.4875072385631605, + 0.2140799896286565, 1.5225640692832796, 0.0, 1.3256191648260216), + (0.7934461515867415, 0.4497104244247795, 0.7127042590637039, + 0.9313185954797953, 1.1683723389247052, 0.23451785425383564, + 1.7413442197913358, 0.33370067057028485, 1.3256191648260216, + 0.0))) diff --git a/tests/unit/computations/test_slink.py b/tests/unit/computations/test_slink.py new file mode 100644 index 0000000..995393b --- /dev/null +++ b/tests/unit/computations/test_slink.py @@ -0,0 +1,311 @@ +"""Module contains tests for slink""" +from unittest import TestCase + +from gn3.computations.slink import slink +from gn3.computations.slink import nearest +from gn3.computations.slink import LengthError +from gn3.computations.slink import MirrorError + +class TestSlink(TestCase): + """Class for testing slink functions""" + + def test_nearest_expects_list_of_lists(self): + """Test that function only accepts a list of lists.""" + # This might be better handled with type-hints and mypy + for item in [9, "some string", 5.432, + [1, 2, 3], ["test", 7.4]]: + with self.subTest(item=item): + with self.assertRaises(ValueError, msg="Expected list or tuple"): + nearest(item, 1, 1) + + def test_nearest_does_not_allow_empty_lists(self): + """Test that function does not accept an empty list, or any of the child + lists to be empty.""" + for lst in [[], + [[], []], + [[], [], []], + [[0, 1, 2], [], [1, 2, 0]]]: + with self.subTest(lst=lst): + with self.assertRaises(ValueError): + nearest(lst, 1, 1) + + def test_nearest_expects_children_are_same_length_as_parent(self): + """Test that children lists are same length as parent list.""" + for lst in [[[0, 1]], + [[0, 1, 2], [3, 4, 5]], + [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9, 0]], + [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [1, 2, 3, 4, 5], [2, 3], + [3, 4, 5, 6, 7]]]: + with self.subTest(lst=lst): + with self.assertRaises(LengthError): + nearest(lst, 1, 1) + + def test_nearest_expects_member_is_zero_distance_from_itself(self): + """Test that distance of a member from itself is zero""" + for lst in [[[1]], + [[1, 2], [3, 4]], + [1, 0, 0], [0, 0, 5], [0, 3, 4], + [0, 0, 0, 0], [0, 0, 3, 3], [0, 1, 2, 3], [0, 3, 2, 0]]: + with self.subTest(lst=lst): + with self.assertRaises(ValueError): + nearest(lst, 1, 1) + + def test_nearest_expects_distance_atob_is_equal_to_distance_btoa(self): + """Test that the distance from member A to member B is the same as that + from member B to member A.""" + for lst in [[[0, 1], [2, 0]], + [[0, 1, 2], [1, 0, 3], [9, 7, 0]], + [[0, 1, 2, 3], [7, 0, 2, 3], [2, 3, 0, 1], [8, 9, 5, 0]]]: + with self.subTest(lst=lst): + with self.assertRaises(MirrorError): + nearest(lst, 1, 1) + + def test_nearest_expects_zero_or_positive_distances(self): + """Test that all distances are either zero, or greater than zero.""" + # Based on: + # https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/slink.py#L87-L89 + for lst in [[[0, -1, 2, 3], [-1, 0, 3, 4], [2, 3, 0, 5], [3, 4, 5, 0]], + [[0, 1, -2, 3], [1, 0, 3, 4], [-2, 3, 0, 5], [3, 4, 5, 0]], + [[0, 1, 2, 3], [1, 0, -3, 4], [2, -3, 0, 5], [3, 4, 5, 0]], + [[0, 1, 2, -3], [1, 0, 3, 4], [2, 3, 0, 5], [-3, 4, 5, 0]], + [[0, 1, 2, 3], [1, 0, 3, -4], [2, 3, 0, 5], [3, -4, 5, 0]], + [[0, 1, 2, 3], [1, 0, 3, 4], [2, 3, 0, -5], [3, 4, -5, 0]]]: + with self.subTest(lst=lst): + with self.assertRaises(ValueError, msg="Distances should be positive."): + nearest(lst, 1, 1) + + def test_nearest_returns_shortest_distance_given_coordinates_to_both_group_members(self): + """Test that the shortest distance is returned.""" + # This test is named wrong - at least I think it is, from the expected results + # This tests distance when both `i`, and `j` are integers + # We still need to add tests for when (either one/both) (is/are) not (an) integer(s) + # https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/slink.py#L39-L40 + for lst, i, j, expected in [ + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 0, 0, 0], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 0, 1, 9], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 0, 2, 3], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 0, 3, 6], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 0, 4, 11], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 1, 0, 9], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 1, 1, 0], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 1, 2, 7], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 1, 3, 5], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 1, 4, 10], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 2, 0, 3], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 2, 1, 7], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 2, 2, 0], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 2, 3, 9], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 2, 4, 2], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 3, 0, 6], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 3, 1, 5], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 3, 2, 9], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 3, 3, 0], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 3, 4, 8], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 4, 0, 11], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 4, 1, 10], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 4, 2, 2], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 4, 3, 8], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + 4, 4, 0], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 0, 0, 0], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 0, 1, 9], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 0, 2, 5.5], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 0, 3, 6], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 0, 4, 11], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 1, 0, 9], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 1, 1, 0], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 1, 2, 7], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 1, 3, 5], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 1, 4, 10], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 2, 0, 5.5], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 2, 1, 7], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 2, 2, 0], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 2, 3, 9], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 2, 4, 2], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 3, 0, 6], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 3, 1, 5], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 3, 2, 9], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 3, 3, 0], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 3, 4, 3], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 4, 0, 11], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 4, 1, 10], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 4, 2, 2], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 4, 3, 3], + [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2], + [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]], + 4, 4, 0]]: + with self.subTest(lst=lst): + self.assertEqual(nearest(lst, i, j), expected) + + def test_nearest_gives_shortest_distance_between_list_of_members_and_member(self): + """Test that the shortest distance is returned.""" + for members_distances, members_list, member_coordinate, expected_distance in [ + [[[0, 9, 3], [9, 0, 7], [3, 7, 0]], (0, 2, 3), 1, 7], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], [0, 1, 2, 3, 4], 3, 0], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], [0, 1, 2, 4], 3, 5], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], [0, 2, 4], 3, 6], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], [2, 4], 3, 9]]: + with self.subTest( + members_distances=members_distances, + members_list=members_list, + member_coordinate=member_coordinate, + expected_distance=expected_distance): + self.assertEqual( + nearest( + members_distances, members_list, member_coordinate), + expected_distance) + self.assertEqual( + nearest( + members_distances, member_coordinate, members_list), + expected_distance) + + def test_nearest_returns_shortest_distance_given_two_lists_of_members(self): + """Test that the shortest distance is returned.""" + for members_distances, members_list, member_list2, expected_distance in [ + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], 0], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + [0, 1], [3, 4], 6], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + [0, 1], [2, 3, 4], 3], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], + [0, 2], [3, 4], 6]]: + with self.subTest( + members_distances=members_distances, + members_list=members_list, + member_list2=member_list2, + expected_distance=expected_distance): + self.assertEqual( + nearest( + members_distances, members_list, member_list2), + expected_distance) + self.assertEqual( + nearest( + members_distances, member_list2, members_list), + expected_distance) + + def test_slink_wrong_data_returns_empty_list(self): + """Test that empty list is returned for wrong data.""" + for data in [1, "test", [], 2.945, nearest, [0]]: + with self.subTest(data=data): + self.assertEqual(slink(data), []) + + def test_slink_with_data(self): + """Test slink with example data, and expected results for each data + sample.""" + for data, expected in [ + [[[0, 9], [9, 0]], [0, 1, 9]], + [[[0, 9, 3], [9, 0, 7], [3, 7, 0]], [(0, 2, 3), 1, 7]], + [[[0, 9, 3, 6], [9, 0, 7, 5], [3, 7, 0, 9], [6, 5, 9, 0]], + [(0, 2, 3), (1, 3, 5), 6]], + [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2], + [6, 5, 9, 0, 8], + [11, 10, 2, 8, 0]], + [(0, (2, 4, 2), 3), (1, 3, 5), 6]]]: + with self.subTest(data=data): + self.assertEqual(slink(data), expected) diff --git a/tests/unit/db/test_datasets.py b/tests/unit/db/test_datasets.py new file mode 100644 index 0000000..38de0e2 --- /dev/null +++ b/tests/unit/db/test_datasets.py @@ -0,0 +1,133 @@ +"""Tests for gn3/db/datasets.py""" + +from unittest import mock, TestCase +from gn3.db.datasets import ( + retrieve_dataset_name, + retrieve_riset_fields, + retrieve_geno_riset_fields, + retrieve_publish_riset_fields, + retrieve_probeset_riset_fields) + +class TestDatasetsDBFunctions(TestCase): + """Test cases for datasets functions.""" + + def test_retrieve_dataset_name(self): + """Test that the function is called correctly.""" + for trait_type, thresh, trait_name, dataset_name, columns, table in [ + ["ProbeSet", 9, "probesetTraitName", "probesetDatasetName", + "Id, Name, FullName, ShortName, DataScale", "ProbeSetFreeze"], + ["Geno", 3, "genoTraitName", "genoDatasetName", + "Id, Name, FullName, ShortName", "GenoFreeze"], + ["Publish", 6, "publishTraitName", "publishDatasetName", + "Id, Name, FullName, ShortName", "PublishFreeze"], + ["Temp", 4, "tempTraitName", "tempTraitName", + "Id, Name, FullName, ShortName", "TempFreeze"]]: + db_mock = mock.MagicMock() + with self.subTest(trait_type=trait_type): + with db_mock.cursor() as cursor: + cursor.fetchone.return_value = {} + self.assertEqual( + retrieve_dataset_name( + trait_type, thresh, trait_name, dataset_name, db_mock), + {}) + cursor.execute.assert_called_once_with( + "SELECT {cols} " + "FROM {table} " + "WHERE public > %(threshold)s AND " + "(Name = %(name)s " + "OR FullName = %(name)s " + "OR ShortName = %(name)s)".format( + table=table, cols=columns), + {"threshold": thresh, "name": dataset_name}) + + def test_retrieve_probeset_riset_fields(self): + """ + Test that the `riset` and `riset_id` fields are retrieved appropriately + for the 'ProbeSet' trait type. + """ + for trait_name, expected in [ + ["testProbeSetName", {}]]: + db_mock = mock.MagicMock() + with self.subTest(trait_name=trait_name, expected=expected): + with db_mock.cursor() as cursor: + cursor.execute.return_value = () + self.assertEqual( + retrieve_probeset_riset_fields(trait_name, db_mock), + expected) + cursor.execute.assert_called_once_with( + ( + "SELECT InbredSet.Name, InbredSet.Id" + " FROM InbredSet, ProbeSetFreeze, ProbeFreeze" + " WHERE ProbeFreeze.InbredSetId = InbredSet.Id" + " AND ProbeFreeze.Id = ProbeSetFreeze.ProbeFreezeId" + " AND ProbeSetFreeze.Name = %(name)s"), + {"name": trait_name}) + + def test_retrieve_riset_fields(self): + """ + Test that the riset fields are set up correctly for the different trait + types. + """ + for trait_type, trait_name, dataset_info, expected in [ + ["Publish", "pubTraitName01", {"dataset_name": "pubDBName01"}, + {"dataset_name": "pubDBName01", "riset": ""}], + ["ProbeSet", "prbTraitName01", {"dataset_name": "prbDBName01"}, + {"dataset_name": "prbDBName01", "riset": ""}], + ["Geno", "genoTraitName01", {"dataset_name": "genoDBName01"}, + {"dataset_name": "genoDBName01", "riset": ""}], + ["Temp", "tempTraitName01", {}, {"riset": ""}], + ]: + db_mock = mock.MagicMock() + with self.subTest( + trait_type=trait_type, trait_name=trait_name, + dataset_info=dataset_info): + with db_mock.cursor() as cursor: + cursor.execute.return_value = ("riset_name", 0) + self.assertEqual( + retrieve_riset_fields( + trait_type, trait_name, dataset_info, db_mock), + expected) + + def test_retrieve_publish_riset_fields(self): + """ + Test that the `riset` and `riset_id` fields are retrieved appropriately + for the 'Publish' trait type. + """ + for trait_name, expected in [ + ["testPublishName", {}]]: + db_mock = mock.MagicMock() + with self.subTest(trait_name=trait_name, expected=expected): + with db_mock.cursor() as cursor: + cursor.execute.return_value = () + self.assertEqual( + retrieve_publish_riset_fields(trait_name, db_mock), + expected) + cursor.execute.assert_called_once_with( + ( + "SELECT InbredSet.Name, InbredSet.Id" + " FROM InbredSet, PublishFreeze" + " WHERE PublishFreeze.InbredSetId = InbredSet.Id" + " AND PublishFreeze.Name = %(name)s"), + {"name": trait_name}) + + def test_retrieve_geno_riset_fields(self): + """ + Test that the `riset` and `riset_id` fields are retrieved appropriately + for the 'Geno' trait type. + """ + for trait_name, expected in [ + ["testGenoName", {}]]: + db_mock = mock.MagicMock() + with self.subTest(trait_name=trait_name, expected=expected): + with db_mock.cursor() as cursor: + cursor.execute.return_value = () + self.assertEqual( + retrieve_geno_riset_fields(trait_name, db_mock), + expected) + cursor.execute.assert_called_once_with( + ( + "SELECT InbredSet.Name, InbredSet.Id" + " FROM InbredSet, GenoFreeze" + " WHERE GenoFreeze.InbredSetId = InbredSet.Id" + " AND GenoFreeze.Name = %(name)s"), + {"name": trait_name}) diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py new file mode 100644 index 0000000..ee98893 --- /dev/null +++ b/tests/unit/db/test_traits.py @@ -0,0 +1,224 @@ +"""Tests for gn3/db/traits.py""" +from unittest import mock, TestCase +from gn3.db.traits import ( + build_trait_name, + set_haveinfo_field, + update_sample_data, + retrieve_trait_info, + set_confidential_field, + set_homologene_id_field, + retrieve_geno_trait_info, + retrieve_temp_trait_info, + retrieve_publish_trait_info, + retrieve_probeset_trait_info) + +class TestTraitsDBFunctions(TestCase): + "Test cases for traits functions" + + def test_retrieve_publish_trait_info(self): + """Test retrieval of type `Publish` traits.""" + db_mock = mock.MagicMock() + with db_mock.cursor() as cursor: + cursor.fetchone.return_value = tuple() + trait_source = { + "trait_name": "PublishTraitName", "trait_dataset_id": 1} + self.assertEqual( + retrieve_publish_trait_info(trait_source, db_mock), {}) + cursor.execute.assert_called_once_with( + ("SELECT " + "PublishXRef.Id, Publication.PubMed_ID," + " Phenotype.Pre_publication_description," + " Phenotype.Post_publication_description," + " Phenotype.Original_description," + " Phenotype.Pre_publication_abbreviation," + " Phenotype.Post_publication_abbreviation," + " Phenotype.Lab_code, Phenotype.Submitter, Phenotype.Owner," + " Phenotype.Authorized_Users," + " CAST(Publication.Authors AS BINARY)," + " Publication.Title, Publication.Abstract," + " Publication.Journal," + " Publication.Volume, Publication.Pages, Publication.Month," + " Publication.Year, PublishXRef.Sequence, Phenotype.Units," + " PublishXRef.comments" + " FROM" + " PublishXRef, Publication, Phenotype, PublishFreeze" + " WHERE" + " PublishXRef.Id = %(trait_name)s" + " AND Phenotype.Id = PublishXRef.PhenotypeId" + " AND Publication.Id = PublishXRef.PublicationId" + " AND PublishXRef.InbredSetId = PublishFreeze.InbredSetId" + " AND PublishFreeze.Id =%(trait_dataset_id)s"), + trait_source) + + def test_retrieve_probeset_trait_info(self): + """Test retrieval of type `Probeset` traits.""" + db_mock = mock.MagicMock() + with db_mock.cursor() as cursor: + cursor.fetchone.return_value = tuple() + trait_source = { + "trait_name": "ProbeSetTraitName", + "trait_dataset_name": "ProbeSetDatasetTraitName"} + self.assertEqual( + retrieve_probeset_trait_info(trait_source, db_mock), {}) + cursor.execute.assert_called_once_with( + ( + "SELECT " + "ProbeSet.name, ProbeSet.symbol, ProbeSet.description, " + "ProbeSet.probe_target_description, ProbeSet.chr, " + "ProbeSet.mb, ProbeSet.alias, ProbeSet.geneid, " + "ProbeSet.genbankid, ProbeSet.unigeneid, ProbeSet.omim, " + "ProbeSet.refseq_transcriptid, ProbeSet.blatseq, " + "ProbeSet.targetseq, ProbeSet.chipid, ProbeSet.comments, " + "ProbeSet.strand_probe, ProbeSet.strand_gene, " + "ProbeSet.probe_set_target_region, ProbeSet.proteinid, " + "ProbeSet.probe_set_specificity, " + "ProbeSet.probe_set_blat_score, " + "ProbeSet.probe_set_blat_mb_start, " + "ProbeSet.probe_set_blat_mb_end, " + "ProbeSet.probe_set_strand, ProbeSet.probe_set_note_by_rw, " + "ProbeSet.flag " + "FROM " + "ProbeSet, ProbeSetFreeze, ProbeSetXRef " + "WHERE " + "ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id " + "AND ProbeSetXRef.ProbeSetId = ProbeSet.Id " + "AND ProbeSetFreeze.Name = %(trait_dataset_name)s " + "AND ProbeSet.Name = %(trait_name)s"), trait_source) + + def test_retrieve_geno_trait_info(self): + """Test retrieval of type `Geno` traits.""" + db_mock = mock.MagicMock() + with db_mock.cursor() as cursor: + cursor.fetchone.return_value = tuple() + trait_source = { + "trait_name": "GenoTraitName", + "trait_dataset_name": "GenoDatasetTraitName"} + self.assertEqual( + retrieve_geno_trait_info(trait_source, db_mock), {}) + cursor.execute.assert_called_once_with( + ( + "SELECT " + "Geno.name, Geno.chr, Geno.mb, Geno.source2, Geno.sequence " + "FROM " + "Geno, GenoFreeze, GenoXRef " + "WHERE " + "GenoXRef.GenoFreezeId = GenoFreeze.Id " + "AND GenoXRef.GenoId = Geno.Id " + "AND GenoFreeze.Name = %(trait_dataset_name)s " + "AND Geno.Name = %(trait_name)s"), + trait_source) + + def test_retrieve_temp_trait_info(self): + """Test retrieval of type `Temp` traits.""" + db_mock = mock.MagicMock() + with db_mock.cursor() as cursor: + cursor.fetchone.return_value = tuple() + trait_source = {"trait_name": "TempTraitName"} + self.assertEqual( + retrieve_temp_trait_info(trait_source, db_mock), {}) + cursor.execute.assert_called_once_with( + "SELECT name, description FROM Temp WHERE Name = %(trait_name)s", + trait_source) + + def test_build_trait_name_with_good_fullnames(self): + """ + Check that the name is built correctly. + """ + for fullname, expected in [ + ["testdb::testname", + {"db": {"dataset_name": "testdb", "dataset_type": "ProbeSet"}, + "trait_name": "testname", "cellid": "", + "trait_fullname": "testdb::testname"}], + ["testdb::testname::testcell", + {"db": {"dataset_name": "testdb", "dataset_type": "ProbeSet"}, + "trait_name": "testname", "cellid": "testcell", + "trait_fullname": "testdb::testname::testcell"}]]: + with self.subTest(fullname=fullname): + self.assertEqual(build_trait_name(fullname), expected) + + def test_build_trait_name_with_bad_fullnames(self): + """ + Check that an exception is raised if the full name format is wrong. + """ + for fullname in ["", "test", "test:test"]: + with self.subTest(fullname=fullname): + with self.assertRaises(AssertionError, msg="Name format error"): + build_trait_name(fullname) + + def test_retrieve_trait_info(self): + """Test that information on traits is retrieved as appropriate.""" + for threshold, trait_fullname, expected in [ + [9, "pubDb::PublishTraitName::pubCell", {"haveinfo": 0}], + [5, "prbDb::ProbeSetTraitName::prbCell", {"haveinfo": 0}], + [12, "genDb::GenoTraitName", {"haveinfo": 0}], + [6, "tmpDb::TempTraitName", {"haveinfo": 0}]]: + db_mock = mock.MagicMock() + with self.subTest(trait_fullname=trait_fullname): + with db_mock.cursor() as cursor: + cursor.fetchone.return_value = tuple() + self.assertEqual( + retrieve_trait_info( + threshold, trait_fullname, db_mock), + expected) + + def test_update_sample_data(self): + """Test that the SQL queries when calling update_sample_data are called with + the right calls. + + """ + db_mock = mock.MagicMock() + + STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s" + PUBLISH_DATA_SQL: str = ("UPDATE PublishData SET value = %s " + "WHERE StrainId = %s AND Id = %s") + PUBLISH_SE_SQL: str = ("UPDATE PublishSE SET error = %s " + "WHERE StrainId = %s AND DataId = %s") + N_STRAIN_SQL: str = ("UPDATE NStrain SET count = %s " + "WHERE StrainId = %s AND DataId = %s") + + with db_mock.cursor() as cursor: + type(cursor).rowcount = 1 + self.assertEqual(update_sample_data( + conn=db_mock, strain_name="BXD11", + strain_id=10, publish_data_id=8967049, + value=18.7, error=2.3, count=2), + (1, 1, 1, 1)) + cursor.execute.assert_has_calls( + [mock.call(STRAIN_ID_SQL, ('BXD11', 10)), + mock.call(PUBLISH_DATA_SQL, (18.7, 10, 8967049)), + mock.call(PUBLISH_SE_SQL, (2.3, 10, 8967049)), + mock.call(N_STRAIN_SQL, (2, 10, 8967049))] + ) + + def test_set_haveinfo_field(self): + """Test that the `haveinfo` field is set up correctly""" + for trait_info, expected in [ + [{}, {"haveinfo": 0}], + [{"k1": "v1"}, {"k1": "v1", "haveinfo": 1}]]: + with self.subTest(trait_info=trait_info, expected=expected): + self.assertEqual(set_haveinfo_field(trait_info), expected) + + def test_set_homologene_id_field(self): + """Test that the `homologene_id` field is set up correctly""" + for trait_type, trait_info, expected in [ + ["Publish", {}, {"homologeneid": None}], + ["ProbeSet", {}, {"homologeneid": None}], + ["Geno", {}, {"homologeneid": None}], + ["Temp", {}, {"homologeneid": None}]]: + db_mock = mock.MagicMock() + with self.subTest(trait_info=trait_info, expected=expected): + with db_mock.cursor() as cursor: + cursor.fetchone.return_value = () + self.assertEqual( + set_homologene_id_field(trait_type, trait_info, db_mock), expected) + + def test_set_confidential_field(self): + """Test that the `confidential` field is set up correctly""" + for trait_type, trait_info, expected in [ + ["Publish", {}, {"confidential": 0}], + ["ProbeSet", {}, {}], + ["Geno", {}, {}], + ["Temp", {}, {}]]: + with self.subTest(trait_info=trait_info, expected=expected): + self.assertEqual( + set_confidential_field(trait_type, trait_info), expected) |