about summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/integration/test_correlation.py2
-rw-r--r--tests/unit/computations/test_correlation.py48
-rw-r--r--tests/unit/computations/test_heatmap.py143
-rw-r--r--tests/unit/computations/test_slink.py311
-rw-r--r--tests/unit/db/test_datasets.py133
-rw-r--r--tests/unit/db/test_traits.py224
6 files changed, 850 insertions, 11 deletions
diff --git a/tests/integration/test_correlation.py b/tests/integration/test_correlation.py
index e67f58d..bdd9bce 100644
--- a/tests/integration/test_correlation.py
+++ b/tests/integration/test_correlation.py
@@ -80,7 +80,7 @@ class CorrelationIntegrationTest(TestCase):
         self.assertEqual(mock_compute_corr.call_count, 1)
         self.assertEqual(response.status_code, 200)
 
-    @mock.patch("gn3.api.correlation.compute_all_tissue_correlation")
+    @mock.patch("gn3.api.correlation.compute_tissue_correlation")
     def test_tissue_correlation(self, mock_tissue_corr):
         """Test api/correlation/tissue_corr/{corr_method}"""
         mock_tissue_corr.return_value = {}
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py
index b1bc6ef..fc52ec1 100644
--- a/tests/unit/computations/test_correlation.py
+++ b/tests/unit/computations/test_correlation.py
@@ -1,5 +1,4 @@
 """Module contains the tests for correlation"""
-import unittest
 from unittest import TestCase
 from unittest import mock
 
@@ -16,9 +15,10 @@ from gn3.computations.correlations import fetch_lit_correlation_data
 from gn3.computations.correlations import query_formatter
 from gn3.computations.correlations import map_to_mouse_gene_id
 from gn3.computations.correlations import compute_all_lit_correlation
-from gn3.computations.correlations import compute_all_tissue_correlation
+from gn3.computations.correlations import compute_tissue_correlation
 from gn3.computations.correlations import map_shared_keys_to_values
 from gn3.computations.correlations import process_trait_symbol_dict
+from gn3.computations.correlations2 import compute_correlation
 
 
 class QueryableMixin:
@@ -172,7 +172,6 @@ class TestCorrelation(TestCase):
         self.assertEqual(results, (filtered_this_samplelist,
                                    filtered_target_samplelist))
 
-    @unittest.skip("Test needs to be refactored ")
     @mock.patch("gn3.computations.correlations.compute_sample_r_correlation")
     @mock.patch("gn3.computations.correlations.filter_shared_sample_keys")
     def test_compute_all_sample(self, filter_shared_samples, sample_r_corr):
@@ -180,7 +179,7 @@ class TestCorrelation(TestCase):
 
         filter_shared_samples.return_value = (["1.23", "6.565", "6.456"], [
             "6.266", "6.565", "6.456"])
-        sample_r_corr.return_value = ([-1.0, 0.9, 6])
+        sample_r_corr.return_value = (["1419792_at", -1.0, 0.9, 6])
 
         this_trait_data = {
             "trait_id": "1455376_at",
@@ -203,13 +202,14 @@ class TestCorrelation(TestCase):
             }
         ]
 
-        sample_all_results = [{"1419792_at": {"corr_coeffient": -1.0,
+        sample_all_results = [{"1419792_at": {"corr_coefficient": -1.0,
                                               "p_value": 0.9,
                                               "num_overlap": 6}}]
 
         self.assertEqual(compute_all_sample_correlation(
             this_trait=this_trait_data, target_dataset=traits_dataset), sample_all_results)
         sample_r_corr.assert_called_once_with(
+            trait_name='1419792_at',
             corr_method="pearson", trait_vals=['1.23', '6.565', '6.456'],
             target_samples_vals=['6.266', '6.565', '6.456'])
         filter_shared_samples.assert_called_once_with(
@@ -406,17 +406,17 @@ class TestCorrelation(TestCase):
         target_tissue_data = {"trait_symbol_dict": target_trait_symbol,
                               "symbol_tissue_vals_dict": target_symbol_tissue_vals}
 
-        mock_tissue_corr.side_effect = [{"tissue_corr": -0.5, "tissue_p_val": 0.9,
-                                         "tissue_number": 3},
-                                        {"tissue_corr": 1.11, "tissue_p_val": 0.2,
-                                         "tissue_number": 3}]
+        mock_tissue_corr.side_effect = [{"1418702_a_at": {"tissue_corr": -0.5, "tissue_p_val": 0.9,
+                                                          "tissue_number": 3}},
+                                        {"1412_at": {"tissue_corr": 1.11, "tissue_p_val": 0.2,
+                                                     "tissue_number": 3}}]
 
         expected_results = [{"1412_at":
                              {"tissue_corr": 1.11, "tissue_p_val": 0.2, "tissue_number": 3}},
                             {"1418702_a_at":
                              {"tissue_corr": -0.5, "tissue_p_val": 0.9, "tissue_number": 3}}]
 
-        results = compute_all_tissue_correlation(
+        results = compute_tissue_correlation(
             primary_tissue_dict=primary_tissue_dict,
             target_tissues_data=target_tissue_data,
             corr_method="pearson")
@@ -464,3 +464,31 @@ class TestCorrelation(TestCase):
             trait_symbol_dict, tissue_values_dict)
 
         self.assertEqual(results, [expected_results])
+
+    def test_compute_correlation(self):
+        """Test that the new correlation function works the same as the original
+        from genenetwork1."""
+        for dbdata, userdata, expected in [
+                [[None, None, None, None, None, None, None, None, None, None],
+                 [None, None, None, None, None, None, None, None, None, None],
+                 (0.0, 0)],
+                [[None, None, None, None, None, None, None, None, None, 0],
+                 [None, None, None, None, None, None, None, None, None, None],
+                 (0.0, 0)],
+                [[None, None, None, None, None, None, None, None, None, 0],
+                 [None, None, None, None, None, None, None, None, None, 0],
+                 (0.0, 1)],
+                [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                 (0, 10)],
+                [[9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87],
+                 [9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87],
+                 (0.9999999999999998, 10)],
+                [[9.3, 2.2, 5.4, 7.2, 6.4, 7.6, 3.8, 1.8, 8.4, 0.2],
+                 [0.6, 3.97, 5.82, 8.21, 1.65, 4.55, 6.72, 9.5, 7.33, 2.34],
+                 (-0.12720361919462056, 10)],
+                [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+                 [None, None, None, None, 2, None, None, 3, None, None],
+                 (0.0, 2)]]:
+            with self.subTest(dbdata=dbdata, userdata=userdata):
+                self.assertEqual(compute_correlation(
+                    dbdata, userdata), expected)
diff --git a/tests/unit/computations/test_heatmap.py b/tests/unit/computations/test_heatmap.py
new file mode 100644
index 0000000..650cb45
--- /dev/null
+++ b/tests/unit/computations/test_heatmap.py
@@ -0,0 +1,143 @@
+"""Module contains tests for gn3.computations.heatmap"""
+from unittest import TestCase
+from gn3.computations.heatmap import cluster_traits, export_trait_data
+
+strainlist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"]
+trait_data = {
+    "mysqlid": 36688172,
+    "data": {
+        "B6cC3-1": {"strain_name": "B6cC3-1", "value": 7.51879, "variance": None, "ndata": None},
+        "BXD1": {"strain_name": "BXD1", "value": 7.77141, "variance": None, "ndata": None},
+        "BXD12": {"strain_name": "BXD12", "value": 8.39265, "variance": None, "ndata": None},
+        "BXD16": {"strain_name": "BXD16", "value": 8.17443, "variance": None, "ndata": None},
+        "BXD19": {"strain_name": "BXD19", "value": 8.30401, "variance": None, "ndata": None},
+        "BXD2": {"strain_name": "BXD2", "value": 7.80944, "variance": None, "ndata": None},
+        "BXD21": {"strain_name": "BXD21", "value": 8.93809, "variance": None, "ndata": None},
+        "BXD24": {"strain_name": "BXD24", "value": 7.99415, "variance": None, "ndata": None},
+        "BXD27": {"strain_name": "BXD27", "value": 8.12177, "variance": None, "ndata": None},
+        "BXD28": {"strain_name": "BXD28", "value": 7.67688, "variance": None, "ndata": None},
+        "BXD32": {"strain_name": "BXD32", "value": 7.79062, "variance": None, "ndata": None},
+        "BXD39": {"strain_name": "BXD39", "value": 8.27641, "variance": None, "ndata": None},
+        "BXD40": {"strain_name": "BXD40", "value": 8.18012, "variance": None, "ndata": None},
+        "BXD42": {"strain_name": "BXD42", "value": 7.82433, "variance": None, "ndata": None},
+        "BXD6": {"strain_name": "BXD6", "value": 8.09718, "variance": None, "ndata": None},
+        "BXH14": {"strain_name": "BXH14", "value": 7.97475, "variance": None, "ndata": None},
+        "BXH19": {"strain_name": "BXH19", "value": 7.67223, "variance": None, "ndata": None},
+        "BXH2": {"strain_name": "BXH2", "value": 7.93622, "variance": None, "ndata": None},
+        "BXH22": {"strain_name": "BXH22", "value": 7.43692, "variance": None, "ndata": None},
+        "BXH4": {"strain_name": "BXH4", "value": 7.96336, "variance": None, "ndata": None},
+        "BXH6": {"strain_name": "BXH6", "value": 7.75132, "variance": None, "ndata": None},
+        "BXH7": {"strain_name": "BXH7", "value": 8.12927, "variance": None, "ndata": None},
+        "BXH8": {"strain_name": "BXH8", "value": 6.77338, "variance": None, "ndata": None},
+        "BXH9": {"strain_name": "BXH9", "value": 8.03836, "variance": None, "ndata": None},
+        "C3H/HeJ": {"strain_name": "C3H/HeJ", "value": 7.42795, "variance": None, "ndata": None},
+        "C57BL/6J": {"strain_name": "C57BL/6J", "value": 7.50606, "variance": None, "ndata": None},
+        "DBA/2J": {"strain_name": "DBA/2J", "value": 7.72588, "variance": None, "ndata": None}}}
+
+class TestHeatmap(TestCase):
+    """Class for testing heatmap computation functions"""
+
+    def test_export_trait_data_dtype(self):
+        """
+        Test `export_trait_data` with different values for the `dtype` keyword
+        argument
+        """
+        for dtype, expected in [
+                ["val", (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)],
+                ["var", (None, None, None, None, None, None)],
+                ["N", (None, None, None, None, None, None)],
+                ["all", (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)]]:
+            with self.subTest(dtype=dtype):
+                self.assertEqual(
+                    export_trait_data(trait_data, strainlist, dtype=dtype),
+                    expected)
+
+    def test_export_trait_data_dtype_all_flags(self):
+        """
+        Test `export_trait_data` with different values for the `dtype` keyword
+        argument and the different flags set up
+        """
+        for dtype, vflag, nflag, expected in [
+                ["val", False, False,
+                 (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)],
+                ["val", False, True,
+                 (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)],
+                ["val", True, False,
+                 (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)],
+                ["val", True, True,
+                 (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)],
+                ["var", False, False, (None, None, None, None, None, None)],
+                ["var", False, True, (None, None, None, None, None, None)],
+                ["var", True, False, (None, None, None, None, None, None)],
+                ["var", True, True, (None, None, None, None, None, None)],
+                ["N", False, False, (None, None, None, None, None, None)],
+                ["N", False, True, (None, None, None, None, None, None)],
+                ["N", True, False, (None, None, None, None, None, None)],
+                ["N", True, True, (None, None, None, None, None, None)],
+                ["all", False, False,
+                 (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)],
+                ["all", False, True,
+                 (7.51879, None, 7.77141, None, 8.39265, None, 8.17443, None,
+                  8.30401, None, 7.80944, None)],
+                ["all", True, False,
+                 (7.51879, None, 7.77141, None, 8.39265, None, 8.17443, None,
+                  8.30401, None, 7.80944, None)],
+                ["all", True, True,
+                 (7.51879, None, None, 7.77141, None, None, 8.39265, None, None,
+                  8.17443, None, None, 8.30401, None, None, 7.80944, None, None)]
+        ]:
+            with self.subTest(dtype=dtype, vflag=vflag, nflag=nflag):
+                self.assertEqual(
+                    export_trait_data(
+                        trait_data, strainlist, dtype=dtype, var_exists=vflag,
+                        n_exists=nflag),
+                    expected)
+
+    def test_cluster_traits(self):
+        """
+        Test that the clustering is working as expected.
+        """
+        traits_data_list = [
+            (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944),
+            (6.1427, 6.50588, 7.73705, 6.68328, 7.49293, 7.27398),
+            (8.4211, 8.30581, 9.24076, 8.51173, 9.18455, 8.36077),
+            (10.0904, 10.6509, 9.36716, 9.91202, 8.57444, 10.5731),
+            (10.188, 9.76652, 9.54813, 9.05074, 9.52319, 9.10505),
+            (6.74676, 7.01029, 7.54169, 6.48574, 7.01427, 7.26815),
+            (6.39359, 6.85321, 5.78337, 7.11141, 6.22101, 6.16544),
+            (6.84118, 7.08432, 7.59844, 7.08229, 7.26774, 7.24991),
+            (9.45215, 10.6943, 8.64719, 10.1592, 7.75044, 8.78615),
+            (7.04737, 6.87185, 7.58586, 6.92456, 6.84243, 7.36913)]
+        self.assertEqual(
+            cluster_traits(traits_data_list),
+            ((0.0, 0.20337048635536847, 0.16381088984330505, 1.7388553629398245,
+              1.5025235756329178, 0.6952839500255574, 1.271661230252733,
+              0.2100487290977544, 1.4699690641062024, 0.7934461515867415),
+             (0.20337048635536847, 0.0, 0.2198321044997198, 1.5753041735592204,
+              1.4815755944537086, 0.26087293140686374, 1.6939790104301427,
+              0.06024619831474998, 1.7430082449189215, 0.4497104244247795),
+             (0.16381088984330505, 0.2198321044997198, 0.0, 1.9073926868549234,
+              1.0396738891139845, 0.5278328671176757, 1.6275069061182947,
+              0.2636503792482082, 1.739617877037615, 0.7127042590637039),
+             (1.7388553629398245, 1.5753041735592204, 1.9073926868549234, 0.0,
+              0.9936846292920328, 1.1169999189889366, 0.6007483980555253,
+              1.430209221053372, 0.25879514152086425, 0.9313185954797953),
+             (1.5025235756329178, 1.4815755944537086, 1.0396738891139845,
+              0.9936846292920328, 0.0, 1.027827186339337, 1.1441743109173244,
+              1.4122477962364253, 0.8968250491499363, 1.1683723389247052),
+             (0.6952839500255574, 0.26087293140686374, 0.5278328671176757,
+              1.1169999189889366, 1.027827186339337, 0.0, 1.8420471110023269,
+              0.19179284676938602, 1.4875072385631605, 0.23451785425383564),
+             (1.271661230252733, 1.6939790104301427, 1.6275069061182947,
+              0.6007483980555253, 1.1441743109173244, 1.8420471110023269, 0.0,
+              1.6540234785929928, 0.2140799896286565, 1.7413442197913358),
+             (0.2100487290977544, 0.06024619831474998, 0.2636503792482082,
+              1.430209221053372, 1.4122477962364253, 0.19179284676938602,
+              1.6540234785929928, 0.0, 1.5225640692832796, 0.33370067057028485),
+             (1.4699690641062024, 1.7430082449189215, 1.739617877037615,
+              0.25879514152086425, 0.8968250491499363, 1.4875072385631605,
+              0.2140799896286565, 1.5225640692832796, 0.0, 1.3256191648260216),
+             (0.7934461515867415, 0.4497104244247795, 0.7127042590637039,
+              0.9313185954797953, 1.1683723389247052, 0.23451785425383564,
+              1.7413442197913358, 0.33370067057028485, 1.3256191648260216,
+              0.0)))
diff --git a/tests/unit/computations/test_slink.py b/tests/unit/computations/test_slink.py
new file mode 100644
index 0000000..995393b
--- /dev/null
+++ b/tests/unit/computations/test_slink.py
@@ -0,0 +1,311 @@
+"""Module contains tests for slink"""
+from unittest import TestCase
+
+from gn3.computations.slink import slink
+from gn3.computations.slink import nearest
+from gn3.computations.slink import LengthError
+from gn3.computations.slink import MirrorError
+
+class TestSlink(TestCase):
+    """Class for testing slink functions"""
+
+    def test_nearest_expects_list_of_lists(self):
+        """Test that function only accepts a list of lists."""
+        # This might be better handled with type-hints and mypy
+        for item in [9, "some string", 5.432,
+                     [1, 2, 3], ["test", 7.4]]:
+            with self.subTest(item=item):
+                with self.assertRaises(ValueError, msg="Expected list or tuple"):
+                    nearest(item, 1, 1)
+
+    def test_nearest_does_not_allow_empty_lists(self):
+        """Test that function does not accept an empty list, or any of the child
+        lists to be empty."""
+        for lst in [[],
+                    [[], []],
+                    [[], [], []],
+                    [[0, 1, 2], [], [1, 2, 0]]]:
+            with self.subTest(lst=lst):
+                with self.assertRaises(ValueError):
+                    nearest(lst, 1, 1)
+
+    def test_nearest_expects_children_are_same_length_as_parent(self):
+        """Test that children lists are same length as parent list."""
+        for lst in [[[0, 1]],
+                    [[0, 1, 2], [3, 4, 5]],
+                    [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9, 0]],
+                    [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [1, 2, 3, 4, 5], [2, 3],
+                     [3, 4, 5, 6, 7]]]:
+            with self.subTest(lst=lst):
+                with self.assertRaises(LengthError):
+                    nearest(lst, 1, 1)
+
+    def test_nearest_expects_member_is_zero_distance_from_itself(self):
+        """Test that distance of a member from itself is zero"""
+        for lst in [[[1]],
+                    [[1, 2], [3, 4]],
+                    [1, 0, 0], [0, 0, 5], [0, 3, 4],
+                    [0, 0, 0, 0], [0, 0, 3, 3], [0, 1, 2, 3], [0, 3, 2, 0]]:
+            with self.subTest(lst=lst):
+                with self.assertRaises(ValueError):
+                    nearest(lst, 1, 1)
+
+    def test_nearest_expects_distance_atob_is_equal_to_distance_btoa(self):
+        """Test that the distance from member A to member B is the same as that
+        from member B to member A."""
+        for lst in [[[0, 1], [2, 0]],
+                    [[0, 1, 2], [1, 0, 3], [9, 7, 0]],
+                    [[0, 1, 2, 3], [7, 0, 2, 3], [2, 3, 0, 1], [8, 9, 5, 0]]]:
+            with self.subTest(lst=lst):
+                with self.assertRaises(MirrorError):
+                    nearest(lst, 1, 1)
+
+    def test_nearest_expects_zero_or_positive_distances(self):
+        """Test that all distances are either zero, or greater than zero."""
+        # Based on:
+        # https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/slink.py#L87-L89
+        for lst in [[[0, -1, 2, 3], [-1, 0, 3, 4], [2, 3, 0, 5], [3, 4, 5, 0]],
+                    [[0, 1, -2, 3], [1, 0, 3, 4], [-2, 3, 0, 5], [3, 4, 5, 0]],
+                    [[0, 1, 2, 3], [1, 0, -3, 4], [2, -3, 0, 5], [3, 4, 5, 0]],
+                    [[0, 1, 2, -3], [1, 0, 3, 4], [2, 3, 0, 5], [-3, 4, 5, 0]],
+                    [[0, 1, 2, 3], [1, 0, 3, -4], [2, 3, 0, 5], [3, -4, 5, 0]],
+                    [[0, 1, 2, 3], [1, 0, 3, 4], [2, 3, 0, -5], [3, 4, -5, 0]]]:
+            with self.subTest(lst=lst):
+                with self.assertRaises(ValueError, msg="Distances should be positive."):
+                    nearest(lst, 1, 1)
+
+    def test_nearest_returns_shortest_distance_given_coordinates_to_both_group_members(self):
+        """Test that the shortest distance is returned."""
+        # This test is named wrong - at least I think it is, from the expected results
+        # This tests distance when both `i`, and `j` are integers
+        # We still need to add tests for when (either one/both) (is/are) not (an) integer(s)
+        # https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/slink.py#L39-L40
+        for lst, i, j, expected in [
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 0, 0, 0],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 0, 1, 9],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 0, 2, 3],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 0, 3, 6],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 0, 4, 11],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 1, 0, 9],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 1, 1, 0],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 1, 2, 7],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 1, 3, 5],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 1, 4, 10],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 2, 0, 3],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 2, 1, 7],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 2, 2, 0],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 2, 3, 9],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 2, 4, 2],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 3, 0, 6],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 3, 1, 5],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 3, 2, 9],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 3, 3, 0],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 3, 4, 8],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 4, 0, 11],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 4, 1, 10],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 4, 2, 2],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 4, 3, 8],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 4, 4, 0],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 0, 0, 0],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 0, 1, 9],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 0, 2, 5.5],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 0, 3, 6],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 0, 4, 11],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 1, 0, 9],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 1, 1, 0],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 1, 2, 7],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 1, 3, 5],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 1, 4, 10],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 2, 0, 5.5],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 2, 1, 7],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 2, 2, 0],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 2, 3, 9],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 2, 4, 2],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 3, 0, 6],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 3, 1, 5],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 3, 2, 9],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 3, 3, 0],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 3, 4, 3],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 4, 0, 11],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 4, 1, 10],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 4, 2, 2],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 4, 3, 3],
+                [[[0, 9, 5.5, 6, 11], [9, 0, 7, 5, 10], [5.5, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 3], [11, 10, 2, 3, 0]],
+                 4, 4, 0]]:
+            with self.subTest(lst=lst):
+                self.assertEqual(nearest(lst, i, j), expected)
+
+    def test_nearest_gives_shortest_distance_between_list_of_members_and_member(self):
+        """Test that the shortest distance is returned."""
+        for members_distances, members_list, member_coordinate, expected_distance in [
+                [[[0, 9, 3], [9, 0, 7], [3, 7, 0]], (0, 2, 3), 1, 7],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], [0, 1, 2, 3, 4], 3, 0],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], [0, 1, 2, 4], 3, 5],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], [0, 2, 4], 3, 6],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]], [2, 4], 3, 9]]:
+            with self.subTest(
+                    members_distances=members_distances,
+                    members_list=members_list,
+                    member_coordinate=member_coordinate,
+                    expected_distance=expected_distance):
+                self.assertEqual(
+                    nearest(
+                        members_distances, members_list, member_coordinate),
+                    expected_distance)
+                self.assertEqual(
+                    nearest(
+                        members_distances, member_coordinate, members_list),
+                    expected_distance)
+
+    def test_nearest_returns_shortest_distance_given_two_lists_of_members(self):
+        """Test that the shortest distance is returned."""
+        for members_distances, members_list, member_list2, expected_distance in [
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], 0],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 [0, 1], [3, 4], 6],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 [0, 1], [2, 3, 4], 3],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8], [11, 10, 2, 8, 0]],
+                 [0, 2], [3, 4], 6]]:
+            with self.subTest(
+                    members_distances=members_distances,
+                    members_list=members_list,
+                    member_list2=member_list2,
+                    expected_distance=expected_distance):
+                self.assertEqual(
+                    nearest(
+                        members_distances, members_list, member_list2),
+                    expected_distance)
+                self.assertEqual(
+                    nearest(
+                        members_distances, member_list2, members_list),
+                    expected_distance)
+
+    def test_slink_wrong_data_returns_empty_list(self):
+        """Test that empty list is returned for wrong data."""
+        for data in [1, "test", [], 2.945, nearest, [0]]:
+            with self.subTest(data=data):
+                self.assertEqual(slink(data), [])
+
+    def test_slink_with_data(self):
+        """Test slink with example data, and expected results for each data
+        sample."""
+        for data, expected in [
+                [[[0, 9], [9, 0]], [0, 1, 9]],
+                [[[0, 9, 3], [9, 0, 7], [3, 7, 0]], [(0, 2, 3), 1, 7]],
+                [[[0, 9, 3, 6], [9, 0, 7, 5], [3, 7, 0, 9], [6, 5, 9, 0]],
+                 [(0, 2, 3), (1, 3, 5), 6]],
+                [[[0, 9, 3, 6, 11], [9, 0, 7, 5, 10], [3, 7, 0, 9, 2],
+                  [6, 5, 9, 0, 8],
+                  [11, 10, 2, 8, 0]],
+                 [(0, (2, 4, 2), 3), (1, 3, 5), 6]]]:
+            with self.subTest(data=data):
+                self.assertEqual(slink(data), expected)
diff --git a/tests/unit/db/test_datasets.py b/tests/unit/db/test_datasets.py
new file mode 100644
index 0000000..38de0e2
--- /dev/null
+++ b/tests/unit/db/test_datasets.py
@@ -0,0 +1,133 @@
+"""Tests for gn3/db/datasets.py"""
+
+from unittest import mock, TestCase
+from gn3.db.datasets import (
+    retrieve_dataset_name,
+    retrieve_riset_fields,
+    retrieve_geno_riset_fields,
+    retrieve_publish_riset_fields,
+    retrieve_probeset_riset_fields)
+
+class TestDatasetsDBFunctions(TestCase):
+    """Test cases for datasets functions."""
+
+    def test_retrieve_dataset_name(self):
+        """Test that the function is called correctly."""
+        for trait_type, thresh, trait_name, dataset_name, columns, table in [
+                ["ProbeSet", 9, "probesetTraitName", "probesetDatasetName",
+                 "Id, Name, FullName, ShortName, DataScale", "ProbeSetFreeze"],
+                ["Geno", 3, "genoTraitName", "genoDatasetName",
+                 "Id, Name, FullName, ShortName", "GenoFreeze"],
+                ["Publish", 6, "publishTraitName", "publishDatasetName",
+                 "Id, Name, FullName, ShortName", "PublishFreeze"],
+                ["Temp", 4, "tempTraitName", "tempTraitName",
+                 "Id, Name, FullName, ShortName", "TempFreeze"]]:
+            db_mock = mock.MagicMock()
+            with self.subTest(trait_type=trait_type):
+                with db_mock.cursor() as cursor:
+                    cursor.fetchone.return_value = {}
+                    self.assertEqual(
+                        retrieve_dataset_name(
+                            trait_type, thresh, trait_name, dataset_name, db_mock),
+                        {})
+                    cursor.execute.assert_called_once_with(
+                        "SELECT {cols} "
+                        "FROM {table} "
+                        "WHERE public > %(threshold)s AND "
+                        "(Name = %(name)s "
+                        "OR FullName = %(name)s "
+                        "OR ShortName = %(name)s)".format(
+                            table=table, cols=columns),
+                        {"threshold": thresh, "name": dataset_name})
+
+    def test_retrieve_probeset_riset_fields(self):
+        """
+        Test that the `riset` and `riset_id` fields are retrieved appropriately
+        for the 'ProbeSet' trait type.
+        """
+        for trait_name, expected in [
+                ["testProbeSetName", {}]]:
+            db_mock = mock.MagicMock()
+            with self.subTest(trait_name=trait_name, expected=expected):
+                with db_mock.cursor() as cursor:
+                    cursor.execute.return_value = ()
+                    self.assertEqual(
+                        retrieve_probeset_riset_fields(trait_name, db_mock),
+                        expected)
+                    cursor.execute.assert_called_once_with(
+                        (
+                            "SELECT InbredSet.Name, InbredSet.Id"
+                            " FROM InbredSet, ProbeSetFreeze, ProbeFreeze"
+                            " WHERE ProbeFreeze.InbredSetId = InbredSet.Id"
+                            " AND ProbeFreeze.Id = ProbeSetFreeze.ProbeFreezeId"
+                            " AND ProbeSetFreeze.Name = %(name)s"),
+                        {"name": trait_name})
+
+    def test_retrieve_riset_fields(self):
+        """
+        Test that the riset fields are set up correctly for the different trait
+        types.
+        """
+        for trait_type, trait_name, dataset_info, expected in [
+                ["Publish", "pubTraitName01", {"dataset_name": "pubDBName01"},
+                 {"dataset_name": "pubDBName01", "riset": ""}],
+                ["ProbeSet", "prbTraitName01", {"dataset_name": "prbDBName01"},
+                 {"dataset_name": "prbDBName01", "riset": ""}],
+                ["Geno", "genoTraitName01", {"dataset_name": "genoDBName01"},
+                 {"dataset_name": "genoDBName01", "riset": ""}],
+                ["Temp", "tempTraitName01", {}, {"riset": ""}],
+                ]:
+            db_mock = mock.MagicMock()
+            with self.subTest(
+                    trait_type=trait_type, trait_name=trait_name,
+                    dataset_info=dataset_info):
+                with db_mock.cursor() as cursor:
+                    cursor.execute.return_value = ("riset_name", 0)
+                    self.assertEqual(
+                        retrieve_riset_fields(
+                            trait_type, trait_name, dataset_info, db_mock),
+                        expected)
+
+    def test_retrieve_publish_riset_fields(self):
+        """
+        Test that the `riset` and `riset_id` fields are retrieved appropriately
+        for the 'Publish' trait type.
+        """
+        for trait_name, expected in [
+                ["testPublishName", {}]]:
+            db_mock = mock.MagicMock()
+            with self.subTest(trait_name=trait_name, expected=expected):
+                with db_mock.cursor() as cursor:
+                    cursor.execute.return_value = ()
+                    self.assertEqual(
+                        retrieve_publish_riset_fields(trait_name, db_mock),
+                        expected)
+                    cursor.execute.assert_called_once_with(
+                        (
+                            "SELECT InbredSet.Name, InbredSet.Id"
+                            " FROM InbredSet, PublishFreeze"
+                            " WHERE PublishFreeze.InbredSetId = InbredSet.Id"
+                            " AND PublishFreeze.Name = %(name)s"),
+                        {"name": trait_name})
+
+    def test_retrieve_geno_riset_fields(self):
+        """
+        Test that the `riset` and `riset_id` fields are retrieved appropriately
+        for the 'Geno' trait type.
+        """
+        for trait_name, expected in [
+                ["testGenoName", {}]]:
+            db_mock = mock.MagicMock()
+            with self.subTest(trait_name=trait_name, expected=expected):
+                with db_mock.cursor() as cursor:
+                    cursor.execute.return_value = ()
+                    self.assertEqual(
+                        retrieve_geno_riset_fields(trait_name, db_mock),
+                        expected)
+                    cursor.execute.assert_called_once_with(
+                        (
+                            "SELECT InbredSet.Name, InbredSet.Id"
+                            " FROM InbredSet, GenoFreeze"
+                            " WHERE GenoFreeze.InbredSetId = InbredSet.Id"
+                            " AND GenoFreeze.Name = %(name)s"),
+                        {"name": trait_name})
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
new file mode 100644
index 0000000..ee98893
--- /dev/null
+++ b/tests/unit/db/test_traits.py
@@ -0,0 +1,224 @@
+"""Tests for gn3/db/traits.py"""
+from unittest import mock, TestCase
+from gn3.db.traits import (
+    build_trait_name,
+    set_haveinfo_field,
+    update_sample_data,
+    retrieve_trait_info,
+    set_confidential_field,
+    set_homologene_id_field,
+    retrieve_geno_trait_info,
+    retrieve_temp_trait_info,
+    retrieve_publish_trait_info,
+    retrieve_probeset_trait_info)
+
+class TestTraitsDBFunctions(TestCase):
+    "Test cases for traits functions"
+
+    def test_retrieve_publish_trait_info(self):
+        """Test retrieval of type `Publish` traits."""
+        db_mock = mock.MagicMock()
+        with db_mock.cursor() as cursor:
+            cursor.fetchone.return_value = tuple()
+            trait_source = {
+                "trait_name": "PublishTraitName", "trait_dataset_id": 1}
+            self.assertEqual(
+                retrieve_publish_trait_info(trait_source, db_mock), {})
+            cursor.execute.assert_called_once_with(
+                ("SELECT "
+                 "PublishXRef.Id, Publication.PubMed_ID,"
+                 " Phenotype.Pre_publication_description,"
+                 " Phenotype.Post_publication_description,"
+                 " Phenotype.Original_description,"
+                 " Phenotype.Pre_publication_abbreviation,"
+                 " Phenotype.Post_publication_abbreviation,"
+                 " Phenotype.Lab_code, Phenotype.Submitter, Phenotype.Owner,"
+                 " Phenotype.Authorized_Users,"
+                 " CAST(Publication.Authors AS BINARY),"
+                 " Publication.Title, Publication.Abstract,"
+                 " Publication.Journal,"
+                 " Publication.Volume, Publication.Pages, Publication.Month,"
+                 " Publication.Year, PublishXRef.Sequence, Phenotype.Units,"
+                 " PublishXRef.comments"
+                 " FROM"
+                 " PublishXRef, Publication, Phenotype, PublishFreeze"
+                 " WHERE"
+                 " PublishXRef.Id = %(trait_name)s"
+                 " AND Phenotype.Id = PublishXRef.PhenotypeId"
+                 " AND Publication.Id = PublishXRef.PublicationId"
+                 " AND PublishXRef.InbredSetId = PublishFreeze.InbredSetId"
+                 " AND PublishFreeze.Id =%(trait_dataset_id)s"),
+                trait_source)
+
+    def test_retrieve_probeset_trait_info(self):
+        """Test retrieval of type `Probeset` traits."""
+        db_mock = mock.MagicMock()
+        with db_mock.cursor() as cursor:
+            cursor.fetchone.return_value = tuple()
+            trait_source = {
+                "trait_name": "ProbeSetTraitName",
+                "trait_dataset_name": "ProbeSetDatasetTraitName"}
+            self.assertEqual(
+                retrieve_probeset_trait_info(trait_source, db_mock), {})
+            cursor.execute.assert_called_once_with(
+                (
+                    "SELECT "
+                    "ProbeSet.name, ProbeSet.symbol, ProbeSet.description, "
+                    "ProbeSet.probe_target_description, ProbeSet.chr, "
+                    "ProbeSet.mb, ProbeSet.alias, ProbeSet.geneid, "
+                    "ProbeSet.genbankid, ProbeSet.unigeneid, ProbeSet.omim, "
+                    "ProbeSet.refseq_transcriptid, ProbeSet.blatseq, "
+                    "ProbeSet.targetseq, ProbeSet.chipid, ProbeSet.comments, "
+                    "ProbeSet.strand_probe, ProbeSet.strand_gene, "
+                    "ProbeSet.probe_set_target_region, ProbeSet.proteinid, "
+                    "ProbeSet.probe_set_specificity, "
+                    "ProbeSet.probe_set_blat_score, "
+                    "ProbeSet.probe_set_blat_mb_start, "
+                    "ProbeSet.probe_set_blat_mb_end, "
+                    "ProbeSet.probe_set_strand, ProbeSet.probe_set_note_by_rw, "
+                    "ProbeSet.flag "
+                    "FROM "
+                    "ProbeSet, ProbeSetFreeze, ProbeSetXRef "
+                    "WHERE "
+                    "ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id "
+                    "AND ProbeSetXRef.ProbeSetId = ProbeSet.Id "
+                    "AND ProbeSetFreeze.Name = %(trait_dataset_name)s "
+                    "AND ProbeSet.Name = %(trait_name)s"), trait_source)
+
+    def test_retrieve_geno_trait_info(self):
+        """Test retrieval of type `Geno` traits."""
+        db_mock = mock.MagicMock()
+        with db_mock.cursor() as cursor:
+            cursor.fetchone.return_value = tuple()
+            trait_source = {
+                "trait_name": "GenoTraitName",
+                "trait_dataset_name": "GenoDatasetTraitName"}
+            self.assertEqual(
+                retrieve_geno_trait_info(trait_source, db_mock), {})
+            cursor.execute.assert_called_once_with(
+                (
+                    "SELECT "
+                    "Geno.name, Geno.chr, Geno.mb, Geno.source2, Geno.sequence "
+                    "FROM "
+                    "Geno, GenoFreeze, GenoXRef "
+                    "WHERE "
+                    "GenoXRef.GenoFreezeId = GenoFreeze.Id "
+                    "AND GenoXRef.GenoId = Geno.Id "
+                    "AND GenoFreeze.Name = %(trait_dataset_name)s "
+                    "AND Geno.Name = %(trait_name)s"),
+                trait_source)
+
+    def test_retrieve_temp_trait_info(self):
+        """Test retrieval of type `Temp` traits."""
+        db_mock = mock.MagicMock()
+        with db_mock.cursor() as cursor:
+            cursor.fetchone.return_value = tuple()
+            trait_source = {"trait_name": "TempTraitName"}
+            self.assertEqual(
+                retrieve_temp_trait_info(trait_source, db_mock), {})
+            cursor.execute.assert_called_once_with(
+                "SELECT name, description FROM Temp WHERE Name = %(trait_name)s",
+                trait_source)
+
+    def test_build_trait_name_with_good_fullnames(self):
+        """
+        Check that the name is built correctly.
+        """
+        for fullname, expected in [
+                ["testdb::testname",
+                 {"db": {"dataset_name": "testdb", "dataset_type": "ProbeSet"},
+                  "trait_name": "testname", "cellid": "",
+                  "trait_fullname": "testdb::testname"}],
+                ["testdb::testname::testcell",
+                 {"db": {"dataset_name": "testdb", "dataset_type": "ProbeSet"},
+                  "trait_name": "testname", "cellid": "testcell",
+                  "trait_fullname": "testdb::testname::testcell"}]]:
+            with self.subTest(fullname=fullname):
+                self.assertEqual(build_trait_name(fullname), expected)
+
+    def test_build_trait_name_with_bad_fullnames(self):
+        """
+        Check that an exception is raised if the full name format is wrong.
+        """
+        for fullname in ["", "test", "test:test"]:
+            with self.subTest(fullname=fullname):
+                with self.assertRaises(AssertionError, msg="Name format error"):
+                    build_trait_name(fullname)
+
+    def test_retrieve_trait_info(self):
+        """Test that information on traits is retrieved as appropriate."""
+        for threshold, trait_fullname, expected in [
+                [9, "pubDb::PublishTraitName::pubCell", {"haveinfo": 0}],
+                [5, "prbDb::ProbeSetTraitName::prbCell", {"haveinfo": 0}],
+                [12, "genDb::GenoTraitName", {"haveinfo": 0}],
+                [6, "tmpDb::TempTraitName", {"haveinfo": 0}]]:
+            db_mock = mock.MagicMock()
+            with self.subTest(trait_fullname=trait_fullname):
+                with db_mock.cursor() as cursor:
+                    cursor.fetchone.return_value = tuple()
+                    self.assertEqual(
+                        retrieve_trait_info(
+                            threshold, trait_fullname, db_mock),
+                        expected)
+
+    def test_update_sample_data(self):
+        """Test that the SQL queries when calling update_sample_data are called with
+        the right calls.
+
+        """
+        db_mock = mock.MagicMock()
+
+        STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
+        PUBLISH_DATA_SQL: str = ("UPDATE PublishData SET value = %s "
+                                 "WHERE StrainId = %s AND Id = %s")
+        PUBLISH_SE_SQL: str = ("UPDATE PublishSE SET error = %s "
+                               "WHERE StrainId = %s AND DataId = %s")
+        N_STRAIN_SQL: str = ("UPDATE NStrain SET count = %s "
+                             "WHERE StrainId = %s AND DataId = %s")
+
+        with db_mock.cursor() as cursor:
+            type(cursor).rowcount = 1
+            self.assertEqual(update_sample_data(
+                conn=db_mock, strain_name="BXD11",
+                strain_id=10, publish_data_id=8967049,
+                value=18.7, error=2.3, count=2),
+                             (1, 1, 1, 1))
+            cursor.execute.assert_has_calls(
+                [mock.call(STRAIN_ID_SQL, ('BXD11', 10)),
+                 mock.call(PUBLISH_DATA_SQL, (18.7, 10, 8967049)),
+                 mock.call(PUBLISH_SE_SQL, (2.3, 10, 8967049)),
+                 mock.call(N_STRAIN_SQL, (2, 10, 8967049))]
+            )
+
+    def test_set_haveinfo_field(self):
+        """Test that the `haveinfo` field is set up correctly"""
+        for trait_info, expected in [
+                [{}, {"haveinfo": 0}],
+                [{"k1": "v1"}, {"k1": "v1", "haveinfo": 1}]]:
+            with self.subTest(trait_info=trait_info, expected=expected):
+                self.assertEqual(set_haveinfo_field(trait_info), expected)
+
+    def test_set_homologene_id_field(self):
+        """Test that the `homologene_id` field is set up correctly"""
+        for trait_type, trait_info, expected in [
+                ["Publish", {}, {"homologeneid": None}],
+                ["ProbeSet", {}, {"homologeneid": None}],
+                ["Geno", {}, {"homologeneid": None}],
+                ["Temp", {}, {"homologeneid": None}]]:
+            db_mock = mock.MagicMock()
+            with self.subTest(trait_info=trait_info, expected=expected):
+                with db_mock.cursor() as cursor:
+                    cursor.fetchone.return_value = ()
+                    self.assertEqual(
+                        set_homologene_id_field(trait_type, trait_info, db_mock), expected)
+
+    def test_set_confidential_field(self):
+        """Test that the `confidential` field is set up correctly"""
+        for trait_type, trait_info, expected in [
+                ["Publish", {}, {"confidential": 0}],
+                ["ProbeSet", {}, {}],
+                ["Geno", {}, {}],
+                ["Temp", {}, {}]]:
+            with self.subTest(trait_info=trait_info, expected=expected):
+                self.assertEqual(
+                    set_confidential_field(trait_type, trait_info), expected)