about summary refs log tree commit diff
path: root/tests/unit/db
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unit/db')
-rw-r--r--tests/unit/db/test_audit.py3
-rw-r--r--tests/unit/db/test_correlation.py100
-rw-r--r--tests/unit/db/test_datasets.py35
-rw-r--r--tests/unit/db/test_db.py8
-rw-r--r--tests/unit/db/test_genotypes.py6
-rw-r--r--tests/unit/db/test_genotypes2.py13
-rw-r--r--tests/unit/db/test_sample_data.py188
-rw-r--r--tests/unit/db/test_species.py5
-rw-r--r--tests/unit/db/test_traits.py60
9 files changed, 362 insertions, 56 deletions
diff --git a/tests/unit/db/test_audit.py b/tests/unit/db/test_audit.py
index 7480169..884afc6 100644
--- a/tests/unit/db/test_audit.py
+++ b/tests/unit/db/test_audit.py
@@ -3,6 +3,8 @@ import json
 from unittest import TestCase
 from unittest import mock
 
+import pytest
+
 from gn3.db import insert
 from gn3.db.metadata_audit import MetadataAudit
 
@@ -10,6 +12,7 @@ from gn3.db.metadata_audit import MetadataAudit
 class TestMetadatAudit(TestCase):
     """Test cases for fetching chromosomes"""
 
+    @pytest.mark.unit_test
     def test_insert_into_metadata_audit(self):
         """Test that data is inserted correctly in the audit table
 
diff --git a/tests/unit/db/test_correlation.py b/tests/unit/db/test_correlation.py
new file mode 100644
index 0000000..5afe55f
--- /dev/null
+++ b/tests/unit/db/test_correlation.py
@@ -0,0 +1,100 @@
+"""
+Tests for the gn3.db.correlations module
+"""
+
+from unittest import TestCase
+
+import pytest
+
+from gn3.db.correlations import (
+    build_query_sgo_lit_corr,
+    build_query_tissue_corr)
+
+class TestCorrelation(TestCase):
+    """Test cases for correlation data fetching functions"""
+    maxDiff = None
+
+    @pytest.mark.unit_test
+    def test_build_query_sgo_lit_corr(self):
+        """
+        Test that the literature correlation query is built correctly.
+        """
+        self.assertEqual(
+            build_query_sgo_lit_corr(
+                "Probeset",
+                "temp_table_xy45i7wd",
+                "T1.value, T2.value, T3.value",
+                (("LEFT JOIN ProbesetData AS T1 "
+                  "ON T1.Id = ProbesetXRef.DataId "
+                  "AND T1.StrainId=%(T1_sample_id)s"),
+                 (
+                     "LEFT JOIN ProbesetData AS T2 "
+                     "ON T2.Id = ProbesetXRef.DataId "
+                     "AND T2.StrainId=%(T2_sample_id)s"),
+                 (
+                     "LEFT JOIN ProbesetData AS T3 "
+                     "ON T3.Id = ProbesetXRef.DataId "
+                     "AND T3.StrainId=%(T3_sample_id)s"))),
+            (("SELECT Probeset.Name, temp_table_xy45i7wd.value, "
+              "T1.value, T2.value, T3.value "
+              "FROM (Probeset, ProbesetXRef, ProbesetFreeze) "
+              "LEFT JOIN temp_table_xy45i7wd ON temp_table_xy45i7wd.GeneId2=ProbeSet.GeneId "
+              "LEFT JOIN ProbesetData AS T1 "
+              "ON T1.Id = ProbesetXRef.DataId "
+              "AND T1.StrainId=%(T1_sample_id)s "
+              "LEFT JOIN ProbesetData AS T2 "
+              "ON T2.Id = ProbesetXRef.DataId "
+              "AND T2.StrainId=%(T2_sample_id)s "
+              "LEFT JOIN ProbesetData AS T3 "
+              "ON T3.Id = ProbesetXRef.DataId "
+              "AND T3.StrainId=%(T3_sample_id)s "
+              "WHERE ProbeSet.GeneId IS NOT NULL "
+              "AND temp_table_xy45i7wd.value IS NOT NULL "
+              "AND ProbesetXRef.ProbesetFreezeId = ProbesetFreeze.Id "
+              "AND ProbesetFreeze.Name = %(db_name)s "
+              "AND Probeset.Id = ProbesetXRef.ProbesetId "
+              "ORDER BY Probeset.Id"),
+             2))
+
+    @pytest.mark.unit_test
+    def test_build_query_tissue_corr(self):
+        """
+        Test that the tissue correlation query is built correctly.
+        """
+        self.assertEqual(
+            build_query_tissue_corr(
+                "Probeset",
+                "temp_table_xy45i7wd",
+                "T1.value, T2.value, T3.value",
+                (("LEFT JOIN ProbesetData AS T1 "
+                  "ON T1.Id = ProbesetXRef.DataId "
+                  "AND T1.StrainId=%(T1_sample_id)s"),
+                 (
+                     "LEFT JOIN ProbesetData AS T2 "
+                     "ON T2.Id = ProbesetXRef.DataId "
+                     "AND T2.StrainId=%(T2_sample_id)s"),
+                 (
+                     "LEFT JOIN ProbesetData AS T3 "
+                     "ON T3.Id = ProbesetXRef.DataId "
+                     "AND T3.StrainId=%(T3_sample_id)s"))),
+            (("SELECT Probeset.Name, temp_table_xy45i7wd.Correlation, "
+              "temp_table_xy45i7wd.PValue, "
+              "T1.value, T2.value, T3.value "
+              "FROM (Probeset, ProbesetXRef, ProbesetFreeze) "
+              "LEFT JOIN temp_table_xy45i7wd ON temp_table_xy45i7wd.Symbol=ProbeSet.Symbol "
+              "LEFT JOIN ProbesetData AS T1 "
+              "ON T1.Id = ProbesetXRef.DataId "
+              "AND T1.StrainId=%(T1_sample_id)s "
+              "LEFT JOIN ProbesetData AS T2 "
+              "ON T2.Id = ProbesetXRef.DataId "
+              "AND T2.StrainId=%(T2_sample_id)s "
+              "LEFT JOIN ProbesetData AS T3 "
+              "ON T3.Id = ProbesetXRef.DataId "
+              "AND T3.StrainId=%(T3_sample_id)s "
+              "WHERE ProbeSet.Symbol IS NOT NULL "
+              "AND temp_table_xy45i7wd.Correlation IS NOT NULL "
+              "AND ProbesetXRef.ProbesetFreezeId = ProbesetFreeze.Id "
+              "AND ProbesetFreeze.Name = %(db_name)s "
+              "AND Probeset.Id = ProbesetXRef.ProbesetId "
+              "ORDER BY Probeset.Id"),
+             3))
diff --git a/tests/unit/db/test_datasets.py b/tests/unit/db/test_datasets.py
index 39f4af9..e4abd2f 100644
--- a/tests/unit/db/test_datasets.py
+++ b/tests/unit/db/test_datasets.py
@@ -1,6 +1,7 @@
 """Tests for gn3/db/datasets.py"""
 
 from unittest import mock, TestCase
+import pytest
 from gn3.db.datasets import (
     retrieve_dataset_name,
     retrieve_group_fields,
@@ -11,35 +12,36 @@ from gn3.db.datasets import (
 class TestDatasetsDBFunctions(TestCase):
     """Test cases for datasets functions."""
 
+    @pytest.mark.unit_test
     def test_retrieve_dataset_name(self):
         """Test that the function is called correctly."""
-        for trait_type, thresh, trait_name, dataset_name, columns, table in [
-                ["ProbeSet", 9, "probesetTraitName", "probesetDatasetName",
-                 "Id, Name, FullName, ShortName, DataScale", "ProbeSetFreeze"],
-                ["Geno", 3, "genoTraitName", "genoDatasetName",
-                 "Id, Name, FullName, ShortName", "GenoFreeze"],
-                ["Publish", 6, "publishTraitName", "publishDatasetName",
-                 "Id, Name, FullName, ShortName", "PublishFreeze"],
-                ["Temp", 4, "tempTraitName", "tempTraitName",
-                 "Id, Name, FullName, ShortName", "TempFreeze"]]:
+        for trait_type, thresh, dataset_name, columns, table, expected in [
+                ["ProbeSet", 9, "probesetDatasetName",
+                 "Id, Name, FullName, ShortName, DataScale", "ProbeSetFreeze",
+                 {"dataset_id": None, "dataset_name": "probesetDatasetName",
+                  "dataset_fullname": "probesetDatasetName"}],
+                ["Geno", 3, "genoDatasetName",
+                 "Id, Name, FullName, ShortName", "GenoFreeze", {}],
+                ["Publish", 6, "publishDatasetName",
+                 "Id, Name, FullName, ShortName", "PublishFreeze", {}]]:
             db_mock = mock.MagicMock()
             with self.subTest(trait_type=trait_type):
                 with db_mock.cursor() as cursor:
                     cursor.fetchone.return_value = {}
                     self.assertEqual(
                         retrieve_dataset_name(
-                            trait_type, thresh, trait_name, dataset_name, db_mock),
-                        {})
+                            trait_type, thresh, dataset_name, db_mock),
+                        expected)
                     cursor.execute.assert_called_once_with(
-                        "SELECT {cols} "
-                        "FROM {table} "
+                        f"SELECT {columns} "
+                        f"FROM {table} "
                         "WHERE public > %(threshold)s AND "
                         "(Name = %(name)s "
                         "OR FullName = %(name)s "
-                        "OR ShortName = %(name)s)".format(
-                            table=table, cols=columns),
+                        "OR ShortName = %(name)s)",
                         {"threshold": thresh, "name": dataset_name})
 
+    @pytest.mark.unit_test
     def test_retrieve_probeset_group_fields(self):
         """
         Test that the `group` and `group_id` fields are retrieved appropriately
@@ -63,6 +65,7 @@ class TestDatasetsDBFunctions(TestCase):
                             " AND ProbeSetFreeze.Name = %(name)s"),
                         {"name": trait_name})
 
+    @pytest.mark.unit_test
     def test_retrieve_group_fields(self):
         """
         Test that the group fields are set up correctly for the different trait
@@ -88,6 +91,7 @@ class TestDatasetsDBFunctions(TestCase):
                             trait_type, trait_name, dataset_info, db_mock),
                         expected)
 
+    @pytest.mark.unit_test
     def test_retrieve_publish_group_fields(self):
         """
         Test that the `group` and `group_id` fields are retrieved appropriately
@@ -110,6 +114,7 @@ class TestDatasetsDBFunctions(TestCase):
                             " AND PublishFreeze.Name = %(name)s"),
                         {"name": trait_name})
 
+    @pytest.mark.unit_test
     def test_retrieve_geno_group_fields(self):
         """
         Test that the `group` and `group_id` fields are retrieved appropriately
diff --git a/tests/unit/db/test_db.py b/tests/unit/db/test_db.py
index e47c9fd..8ac468c 100644
--- a/tests/unit/db/test_db.py
+++ b/tests/unit/db/test_db.py
@@ -2,6 +2,8 @@
 from unittest import TestCase
 from unittest import mock
 
+import pytest
+
 from gn3.db import fetchall
 from gn3.db import fetchone
 from gn3.db import update
@@ -14,6 +16,7 @@ from gn3.db.metadata_audit import MetadataAudit
 class TestCrudMethods(TestCase):
     """Test cases for CRUD methods"""
 
+    @pytest.mark.unit_test
     def test_update_phenotype_with_no_data(self):
         """Test that a phenotype is updated correctly if an empty Phenotype dataclass
         is provided
@@ -24,6 +27,7 @@ class TestCrudMethods(TestCase):
             conn=db_mock, table="Phenotype",
             data=Phenotype(), where=Phenotype()), None)
 
+    @pytest.mark.unit_test
     def test_update_phenotype_with_data(self):
         """
         Test that a phenotype is updated correctly if some
@@ -46,6 +50,7 @@ class TestCrudMethods(TestCase):
                 "Submitter = %s WHERE id = %s AND Owner = %s",
                 ('Test Pre Pub', 'Test Post Pub', 'Rob', 1, 'Rob'))
 
+    @pytest.mark.unit_test
     def test_fetch_phenotype(self):
         """Test that a single phenotype is fetched properly
 
@@ -68,6 +73,7 @@ class TestCrudMethods(TestCase):
                 "SELECT * FROM Phenotype WHERE id = %s AND Owner = %s",
                 (35, 'Rob'))
 
+    @pytest.mark.unit_test
     def test_fetchall_metadataaudit(self):
         """Test that multiple metadata_audit entries are fetched properly
 
@@ -96,6 +102,7 @@ class TestCrudMethods(TestCase):
                  "dataset_id = %s AND editor = %s"),
                 (35, 'Rob'))
 
+    @pytest.mark.unit_test
     # pylint: disable=R0201
     def test_probeset_called_with_right_columns(self):
         """Given a columns argument, test that the correct sql query is
@@ -112,6 +119,7 @@ class TestCrudMethods(TestCase):
                 "Name = %s",
                 ("1446112_at",))
 
+    @pytest.mark.unit_test
     def test_diff_from_dict(self):
         """Test that a correct diff is generated"""
         self.assertEqual(diff_from_dict({"id": 1, "data": "a"},
diff --git a/tests/unit/db/test_genotypes.py b/tests/unit/db/test_genotypes.py
index c125224..28728bf 100644
--- a/tests/unit/db/test_genotypes.py
+++ b/tests/unit/db/test_genotypes.py
@@ -1,5 +1,6 @@
 """Tests gn3.db.genotypes"""
 from unittest import TestCase
+import pytest
 from gn3.db.genotypes import (
     parse_genotype_file,
     parse_genotype_labels,
@@ -10,6 +11,7 @@ from gn3.db.genotypes import (
 class TestGenotypes(TestCase):
     """Tests for functions in `gn3.db.genotypes`."""
 
+    @pytest.mark.unit_test
     def test_parse_genotype_labels(self):
         """Test that the genotype labels are parsed correctly."""
         self.assertEqual(
@@ -22,6 +24,7 @@ class TestGenotypes(TestCase):
              ("type", "test_type"), ("mat", "test_mat"), ("pat", "test_pat"),
              ("het", "test_het"), ("unk", "test_unk")))
 
+    @pytest.mark.unit_test
     def test_parse_genotype_header(self):
         """Test that the genotype header is parsed correctly."""
         for header, expected in [
@@ -43,6 +46,7 @@ class TestGenotypes(TestCase):
             with self.subTest(header=header):
                 self.assertEqual(parse_genotype_header(header), expected)
 
+    @pytest.mark.unit_test
     def test_parse_genotype_data_line(self):
         """Test parsing of data lines."""
         for line, geno_obj, parlist, expected in [
@@ -76,6 +80,7 @@ class TestGenotypes(TestCase):
                     parse_genotype_marker(line, geno_obj, parlist),
                     expected)
 
+    @pytest.mark.unit_test
     def test_build_genotype_chromosomes(self):
         """
         Given `markers` and `geno_obj`, test that `build_genotype_chromosomes`
@@ -115,6 +120,7 @@ class TestGenotypes(TestCase):
                     build_genotype_chromosomes(geno_obj, markers),
                     expected)
 
+    @pytest.mark.unit_test
     def test_parse_genotype_file(self):
         """Test the parsing of genotype files. """
         self.assertEqual(
diff --git a/tests/unit/db/test_genotypes2.py b/tests/unit/db/test_genotypes2.py
new file mode 100644
index 0000000..453120b
--- /dev/null
+++ b/tests/unit/db/test_genotypes2.py
@@ -0,0 +1,13 @@
+"""Module to test functions in gn3.db.genotypes"""
+
+import pytest
+
+from gn3.db.genotypes import load_genotype_samples
+
+@pytest.mark.unit_test
+@pytest.mark.parametrize(
+    "genotype_filename,file_type,expected", (
+        ("tests/unit/test_data/genotype.txt", "geno", ("BXD1","BXD2")),))
+def test_load_genotype_samples(genotype_filename, file_type, expected):
+    """Test that the genotype samples are loaded correctly"""
+    assert load_genotype_samples(genotype_filename, file_type) == expected
diff --git a/tests/unit/db/test_sample_data.py b/tests/unit/db/test_sample_data.py
new file mode 100644
index 0000000..2524e07
--- /dev/null
+++ b/tests/unit/db/test_sample_data.py
@@ -0,0 +1,188 @@
+"""Tests for gn3.db.sample_data"""
+import pytest
+import gn3
+
+from gn3.db.sample_data import __extract_actions
+from gn3.db.sample_data import delete_sample_data
+from gn3.db.sample_data import insert_sample_data
+from gn3.db.sample_data import update_sample_data
+
+
+@pytest.mark.unit_test
+def test_insert_sample_data(mocker):
+    """Test that inserts work properly"""
+    mock_conn = mocker.MagicMock()
+    strain_id, data_id, inbredset_id = 1, 17373, 20
+    with mock_conn.cursor() as cursor:
+        cursor.fetchone.side_effect = (
+            0,
+            [
+                19,
+            ],
+            0,
+        )
+        mocker.patch(
+            "gn3.db.sample_data.get_sample_data_ids",
+            return_value=(strain_id, data_id, inbredset_id),
+        )
+        insert_sample_data(
+            conn=mock_conn,
+            trait_name=35,
+            data="BXD1,18,3,0,M",
+            csv_header="Strain Name,Value,SE,Count,Sex",
+            phenotype_id=10007,
+        )
+        calls = [
+            mocker.call(
+                "SELECT Id FROM PublishData where Id = %s " "AND StrainId = %s",
+                (data_id, strain_id),
+            ),
+            mocker.call(
+                "INSERT INTO PublishData " "(StrainId, Id, value) VALUES (%s, %s, %s)",
+                (strain_id, data_id, "18"),
+            ),
+            mocker.call(
+                "INSERT INTO PublishSE "
+                "(StrainId, DataId, error) VALUES (%s, %s, %s)",
+                (strain_id, data_id, "3"),
+            ),
+            mocker.call(
+                "INSERT INTO NStrain " "(StrainId, DataId, count) VALUES (%s, %s, %s)",
+                (strain_id, data_id, "0"),
+            ),
+            mocker.call("SELECT Id FROM CaseAttribute WHERE Name = %s", ("Sex",)),
+            mocker.call(
+                "SELECT StrainId FROM CaseAttributeXRefNew "
+                "WHERE StrainId = %s AND "
+                "CaseAttributeId = %s AND InbredSetId = %s",
+                (strain_id, 19, inbredset_id),
+            ),
+            mocker.call(
+                "INSERT INTO CaseAttributeXRefNew "
+                "(StrainId, CaseAttributeId, Value, "
+                "InbredSetId) VALUES (%s, %s, %s, %s)",
+                (strain_id, 19, "M", inbredset_id),
+            ),
+        ]
+        cursor.execute.assert_has_calls(calls, any_order=False)
+
+
+@pytest.mark.unit_test
+def test_delete_sample_data(mocker):
+    """Test that deletes work properly"""
+    mock_conn = mocker.MagicMock()
+    strain_id, data_id, inbredset_id = 1, 17373, 20
+    with mock_conn.cursor() as cursor:
+        cursor.fetchone.side_effect = (
+            0,
+            [
+                19,
+            ],
+            0,
+        )
+        mocker.patch(
+            "gn3.db.sample_data.get_sample_data_ids",
+            return_value=(strain_id, data_id, inbredset_id),
+        )
+        delete_sample_data(
+            conn=mock_conn,
+            trait_name=35,
+            data="BXD1,18,3,0,M",
+            csv_header="Strain Name,Value,SE,Count,Sex",
+            phenotype_id=10007,
+        )
+        calls = [
+            mocker.call(
+                "DELETE FROM PublishData WHERE " "StrainId = %s AND Id = %s",
+                (strain_id, data_id),
+            ),
+            mocker.call(
+                "DELETE FROM PublishSE WHERE " "StrainId = %s AND DataId = %s",
+                (strain_id, data_id),
+            ),
+            mocker.call(
+                "DELETE FROM NStrain WHERE " "StrainId = %s AND DataId = %s",
+                (strain_id, data_id),
+            ),
+            mocker.call(
+                "DELETE FROM CaseAttributeXRefNew WHERE "
+                "StrainId = %s AND CaseAttributeId = "
+                "(SELECT CaseAttributeId FROM "
+                "CaseAttribute WHERE Name = %s) "
+                "AND InbredSetId = %s",
+                (strain_id, "Sex", inbredset_id),
+            ),
+        ]
+        cursor.execute.assert_has_calls(calls, any_order=False)
+
+
+@pytest.mark.unit_test
+def test_extract_actions():
+    """Test that extracting the correct dict of 'actions' work properly"""
+    assert __extract_actions(
+        original_data="BXD1,18,x,0,x",
+        updated_data="BXD1,x,2,1,F",
+        csv_header="Strain Name,Value,SE,Count,Sex",
+    ) == {
+        "delete": {"data": "BXD1,18", "csv_header": "Strain Name,Value"},
+        "insert": {"data": "BXD1,2,F", "csv_header": "Strain Name,SE,Sex"},
+        "update": {"data": "BXD1,1", "csv_header": "Strain Name,Count"},
+    }
+    assert __extract_actions(
+        original_data="BXD1,18,x,0,x",
+        updated_data="BXD1,19,2,1,F",
+        csv_header="Strain Name,Value,SE,Count,Sex",
+    ) == {
+        "delete": None,
+        "insert": {"data": "BXD1,2,F", "csv_header": "Strain Name,SE,Sex"},
+        "update": {"data": "BXD1,19,1", "csv_header": "Strain Name,Value,Count"},
+    }
+
+
+@pytest.mark.unit_test
+def test_update_sample_data(mocker):
+    """Test that updates work properly"""
+    mock_conn = mocker.MagicMock()
+    strain_id, data_id, inbredset_id = 1, 17373, 20
+    with mock_conn.cursor() as cursor:
+        # cursor.fetchone.side_effect = (0, [19, ], 0)
+        mocker.patch(
+            "gn3.db.sample_data.get_sample_data_ids",
+            return_value=(strain_id, data_id, inbredset_id),
+        )
+        mocker.patch("gn3.db.sample_data.insert_sample_data", return_value=1)
+        mocker.patch("gn3.db.sample_data.delete_sample_data", return_value=1)
+        update_sample_data(
+            conn=mock_conn,
+            trait_name=35,
+            original_data="BXD1,18,x,0,x",
+            updated_data="BXD1,x,2,1,F",
+            csv_header="Strain Name,Value,SE,Count,Sex",
+            phenotype_id=10007,
+        )
+        # pylint: disable=[E1101]
+        gn3.db.sample_data.insert_sample_data.assert_called_once_with(
+            conn=mock_conn,
+            trait_name=35,
+            data="BXD1,2,F",
+            csv_header="Strain Name,SE,Sex",
+            phenotype_id=10007,
+        )
+        # pylint: disable=[E1101]
+        gn3.db.sample_data.delete_sample_data.assert_called_once_with(
+            conn=mock_conn,
+            trait_name=35,
+            data="BXD1,18",
+            csv_header="Strain Name,Value",
+            phenotype_id=10007,
+        )
+        cursor.execute.assert_has_calls(
+            [
+                mocker.call(
+                    "UPDATE NStrain SET count = %s "
+                    "WHERE StrainId = %s AND DataId = %s",
+                    ("1", strain_id, data_id),
+                )
+            ],
+            any_order=False,
+        )
diff --git a/tests/unit/db/test_species.py b/tests/unit/db/test_species.py
index b2c4844..e883b21 100644
--- a/tests/unit/db/test_species.py
+++ b/tests/unit/db/test_species.py
@@ -2,6 +2,8 @@
 from unittest import TestCase
 from unittest import mock
 
+import pytest
+
 from gn3.db.species import get_chromosome
 from gn3.db.species import get_all_species
 
@@ -9,6 +11,7 @@ from gn3.db.species import get_all_species
 class TestChromosomes(TestCase):
     """Test cases for fetching chromosomes"""
 
+    @pytest.mark.unit_test
     def test_get_chromosome_using_species_name(self):
         """Test that the chromosome is fetched using a species name"""
         db_mock = mock.MagicMock()
@@ -24,6 +27,7 @@ class TestChromosomes(TestCase):
                 "Species.Name = 'TestCase' ORDER BY OrderId"
             )
 
+    @pytest.mark.unit_test
     def test_get_chromosome_using_group_name(self):
         """Test that the chromosome is fetched using a group name"""
         db_mock = mock.MagicMock()
@@ -39,6 +43,7 @@ class TestChromosomes(TestCase):
                 "InbredSet.Name = 'TestCase' ORDER BY OrderId"
             )
 
+    @pytest.mark.unit_test
     def test_get_all_species(self):
         """Test that species are fetched correctly"""
         db_mock = mock.MagicMock()
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
index 4aa9389..434f758 100644
--- a/tests/unit/db/test_traits.py
+++ b/tests/unit/db/test_traits.py
@@ -1,11 +1,11 @@
 """Tests for gn3/db/traits.py"""
 from unittest import mock, TestCase
+import pytest
 from gn3.db.traits import (
     build_trait_name,
     export_trait_data,
     export_informative,
     set_haveinfo_field,
-    update_sample_data,
     retrieve_trait_info,
     set_confidential_field,
     set_homologene_id_field,
@@ -49,6 +49,7 @@ trait_data = {
 class TestTraitsDBFunctions(TestCase):
     "Test cases for traits functions"
 
+    @pytest.mark.unit_test
     def test_retrieve_publish_trait_info(self):
         """Test retrieval of type `Publish` traits."""
         db_mock = mock.MagicMock()
@@ -75,15 +76,15 @@ class TestTraitsDBFunctions(TestCase):
                  " Publication.Year, PublishXRef.Sequence, Phenotype.Units,"
                  " PublishXRef.comments"
                  " FROM"
-                 " PublishXRef, Publication, Phenotype, PublishFreeze"
+                 " PublishXRef, Publication, Phenotype"
                  " WHERE"
                  " PublishXRef.Id = %(trait_name)s"
                  " AND Phenotype.Id = PublishXRef.PhenotypeId"
                  " AND Publication.Id = PublishXRef.PublicationId"
-                 " AND PublishXRef.InbredSetId = PublishFreeze.InbredSetId"
-                 " AND PublishFreeze.Id =%(trait_dataset_id)s"),
+                 " AND PublishXRef.InbredSetId = %(trait_dataset_id)s"),
                 trait_source)
 
+    @pytest.mark.unit_test
     def test_retrieve_probeset_trait_info(self):
         """Test retrieval of type `Probeset` traits."""
         db_mock = mock.MagicMock()
@@ -119,6 +120,7 @@ class TestTraitsDBFunctions(TestCase):
                     "AND ProbeSetFreeze.Name = %(trait_dataset_name)s "
                     "AND ProbeSet.Name = %(trait_name)s"), trait_source)
 
+    @pytest.mark.unit_test
     def test_retrieve_geno_trait_info(self):
         """Test retrieval of type `Geno` traits."""
         db_mock = mock.MagicMock()
@@ -134,14 +136,14 @@ class TestTraitsDBFunctions(TestCase):
                     "SELECT "
                     "Geno.name, Geno.chr, Geno.mb, Geno.source2, Geno.sequence "
                     "FROM "
-                    "Geno, GenoFreeze, GenoXRef "
+                    "Geno INNER JOIN GenoXRef ON GenoXRef.GenoId = Geno.Id "
+                    "INNER JOIN GenoFreeze ON GenoFreeze.Id = GenoXRef.GenoFreezeId "
                     "WHERE "
-                    "GenoXRef.GenoFreezeId = GenoFreeze.Id "
-                    "AND GenoXRef.GenoId = Geno.Id "
-                    "AND GenoFreeze.Name = %(trait_dataset_name)s "
+                    "GenoFreeze.Name = %(trait_dataset_name)s "
                     "AND Geno.Name = %(trait_name)s"),
                 trait_source)
 
+    @pytest.mark.unit_test
     def test_retrieve_temp_trait_info(self):
         """Test retrieval of type `Temp` traits."""
         db_mock = mock.MagicMock()
@@ -154,6 +156,7 @@ class TestTraitsDBFunctions(TestCase):
                 "SELECT name, description FROM Temp WHERE Name = %(trait_name)s",
                 trait_source)
 
+    @pytest.mark.unit_test
     def test_build_trait_name_with_good_fullnames(self):
         """
         Check that the name is built correctly.
@@ -170,6 +173,7 @@ class TestTraitsDBFunctions(TestCase):
             with self.subTest(fullname=fullname):
                 self.assertEqual(build_trait_name(fullname), expected)
 
+    @pytest.mark.unit_test
     def test_build_trait_name_with_bad_fullnames(self):
         """
         Check that an exception is raised if the full name format is wrong.
@@ -179,6 +183,7 @@ class TestTraitsDBFunctions(TestCase):
                 with self.assertRaises(AssertionError, msg="Name format error"):
                     build_trait_name(fullname)
 
+    @pytest.mark.unit_test
     def test_retrieve_trait_info(self):
         """Test that information on traits is retrieved as appropriate."""
         for threshold, trait_fullname, expected in [
@@ -195,39 +200,7 @@ class TestTraitsDBFunctions(TestCase):
                             threshold, trait_fullname, db_mock),
                         expected)
 
-    def test_update_sample_data(self):
-        """Test that the SQL queries when calling update_sample_data are called with
-        the right calls.
-
-        """
-        # pylint: disable=C0103
-        db_mock = mock.MagicMock()
-
-        STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
-        PUBLISH_DATA_SQL: str = (
-            "UPDATE PublishData SET value = %s "
-            "WHERE StrainId = %s AND Id = %s")
-        PUBLISH_SE_SQL: str = (
-            "UPDATE PublishSE SET error = %s "
-            "WHERE StrainId = %s AND DataId = %s")
-        N_STRAIN_SQL: str = (
-            "UPDATE NStrain SET count = %s "
-            "WHERE StrainId = %s AND DataId = %s")
-
-        with db_mock.cursor() as cursor:
-            type(cursor).rowcount = 1
-            self.assertEqual(update_sample_data(
-                conn=db_mock, strain_name="BXD11",
-                strain_id=10, publish_data_id=8967049,
-                value=18.7, error=2.3, count=2),
-                             (1, 1, 1, 1))
-            cursor.execute.assert_has_calls(
-                [mock.call(STRAIN_ID_SQL, ('BXD11', 10)),
-                 mock.call(PUBLISH_DATA_SQL, (18.7, 10, 8967049)),
-                 mock.call(PUBLISH_SE_SQL, (2.3, 10, 8967049)),
-                 mock.call(N_STRAIN_SQL, (2, 10, 8967049))]
-            )
-
+    @pytest.mark.unit_test
     def test_set_haveinfo_field(self):
         """Test that the `haveinfo` field is set up correctly"""
         for trait_info, expected in [
@@ -236,6 +209,7 @@ class TestTraitsDBFunctions(TestCase):
             with self.subTest(trait_info=trait_info, expected=expected):
                 self.assertEqual(set_haveinfo_field(trait_info), expected)
 
+    @pytest.mark.unit_test
     def test_set_homologene_id_field(self):
         """Test that the `homologene_id` field is set up correctly"""
         for trait_type, trait_info, expected in [
@@ -250,6 +224,7 @@ class TestTraitsDBFunctions(TestCase):
                     self.assertEqual(
                         set_homologene_id_field(trait_type, trait_info, db_mock), expected)
 
+    @pytest.mark.unit_test
     def test_set_confidential_field(self):
         """Test that the `confidential` field is set up correctly"""
         for trait_type, trait_info, expected in [
@@ -261,6 +236,7 @@ class TestTraitsDBFunctions(TestCase):
                 self.assertEqual(
                     set_confidential_field(trait_type, trait_info), expected)
 
+    @pytest.mark.unit_test
     def test_export_trait_data_dtype(self):
         """
         Test `export_trait_data` with different values for the `dtype` keyword
@@ -276,6 +252,7 @@ class TestTraitsDBFunctions(TestCase):
                     export_trait_data(trait_data, samplelist, dtype=dtype),
                     expected)
 
+    @pytest.mark.unit_test
     def test_export_trait_data_dtype_all_flags(self):
         """
         Test `export_trait_data` with different values for the `dtype` keyword
@@ -317,6 +294,7 @@ class TestTraitsDBFunctions(TestCase):
                         n_exists=nflag),
                     expected)
 
+    @pytest.mark.unit_test
     def test_export_informative(self):
         """Test that the function exports appropriate data."""
         # pylint: disable=W0621