about summary refs log tree commit diff
path: root/tests/unit/db/test_traits.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unit/db/test_traits.py')
-rw-r--r--tests/unit/db/test_traits.py60
1 files changed, 19 insertions, 41 deletions
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
index 4aa9389..434f758 100644
--- a/tests/unit/db/test_traits.py
+++ b/tests/unit/db/test_traits.py
@@ -1,11 +1,11 @@
 """Tests for gn3/db/traits.py"""
 from unittest import mock, TestCase
+import pytest
 from gn3.db.traits import (
     build_trait_name,
     export_trait_data,
     export_informative,
     set_haveinfo_field,
-    update_sample_data,
     retrieve_trait_info,
     set_confidential_field,
     set_homologene_id_field,
@@ -49,6 +49,7 @@ trait_data = {
 class TestTraitsDBFunctions(TestCase):
     "Test cases for traits functions"
 
+    @pytest.mark.unit_test
     def test_retrieve_publish_trait_info(self):
         """Test retrieval of type `Publish` traits."""
         db_mock = mock.MagicMock()
@@ -75,15 +76,15 @@ class TestTraitsDBFunctions(TestCase):
                  " Publication.Year, PublishXRef.Sequence, Phenotype.Units,"
                  " PublishXRef.comments"
                  " FROM"
-                 " PublishXRef, Publication, Phenotype, PublishFreeze"
+                 " PublishXRef, Publication, Phenotype"
                  " WHERE"
                  " PublishXRef.Id = %(trait_name)s"
                  " AND Phenotype.Id = PublishXRef.PhenotypeId"
                  " AND Publication.Id = PublishXRef.PublicationId"
-                 " AND PublishXRef.InbredSetId = PublishFreeze.InbredSetId"
-                 " AND PublishFreeze.Id =%(trait_dataset_id)s"),
+                 " AND PublishXRef.InbredSetId = %(trait_dataset_id)s"),
                 trait_source)
 
+    @pytest.mark.unit_test
     def test_retrieve_probeset_trait_info(self):
         """Test retrieval of type `Probeset` traits."""
         db_mock = mock.MagicMock()
@@ -119,6 +120,7 @@ class TestTraitsDBFunctions(TestCase):
                     "AND ProbeSetFreeze.Name = %(trait_dataset_name)s "
                     "AND ProbeSet.Name = %(trait_name)s"), trait_source)
 
+    @pytest.mark.unit_test
     def test_retrieve_geno_trait_info(self):
         """Test retrieval of type `Geno` traits."""
         db_mock = mock.MagicMock()
@@ -134,14 +136,14 @@ class TestTraitsDBFunctions(TestCase):
                     "SELECT "
                     "Geno.name, Geno.chr, Geno.mb, Geno.source2, Geno.sequence "
                     "FROM "
-                    "Geno, GenoFreeze, GenoXRef "
+                    "Geno INNER JOIN GenoXRef ON GenoXRef.GenoId = Geno.Id "
+                    "INNER JOIN GenoFreeze ON GenoFreeze.Id = GenoXRef.GenoFreezeId "
                     "WHERE "
-                    "GenoXRef.GenoFreezeId = GenoFreeze.Id "
-                    "AND GenoXRef.GenoId = Geno.Id "
-                    "AND GenoFreeze.Name = %(trait_dataset_name)s "
+                    "GenoFreeze.Name = %(trait_dataset_name)s "
                     "AND Geno.Name = %(trait_name)s"),
                 trait_source)
 
+    @pytest.mark.unit_test
     def test_retrieve_temp_trait_info(self):
         """Test retrieval of type `Temp` traits."""
         db_mock = mock.MagicMock()
@@ -154,6 +156,7 @@ class TestTraitsDBFunctions(TestCase):
                 "SELECT name, description FROM Temp WHERE Name = %(trait_name)s",
                 trait_source)
 
+    @pytest.mark.unit_test
     def test_build_trait_name_with_good_fullnames(self):
         """
         Check that the name is built correctly.
@@ -170,6 +173,7 @@ class TestTraitsDBFunctions(TestCase):
             with self.subTest(fullname=fullname):
                 self.assertEqual(build_trait_name(fullname), expected)
 
+    @pytest.mark.unit_test
     def test_build_trait_name_with_bad_fullnames(self):
         """
         Check that an exception is raised if the full name format is wrong.
@@ -179,6 +183,7 @@ class TestTraitsDBFunctions(TestCase):
                 with self.assertRaises(AssertionError, msg="Name format error"):
                     build_trait_name(fullname)
 
+    @pytest.mark.unit_test
     def test_retrieve_trait_info(self):
         """Test that information on traits is retrieved as appropriate."""
         for threshold, trait_fullname, expected in [
@@ -195,39 +200,7 @@ class TestTraitsDBFunctions(TestCase):
                             threshold, trait_fullname, db_mock),
                         expected)
 
-    def test_update_sample_data(self):
-        """Test that the SQL queries when calling update_sample_data are called with
-        the right calls.
-
-        """
-        # pylint: disable=C0103
-        db_mock = mock.MagicMock()
-
-        STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
-        PUBLISH_DATA_SQL: str = (
-            "UPDATE PublishData SET value = %s "
-            "WHERE StrainId = %s AND Id = %s")
-        PUBLISH_SE_SQL: str = (
-            "UPDATE PublishSE SET error = %s "
-            "WHERE StrainId = %s AND DataId = %s")
-        N_STRAIN_SQL: str = (
-            "UPDATE NStrain SET count = %s "
-            "WHERE StrainId = %s AND DataId = %s")
-
-        with db_mock.cursor() as cursor:
-            type(cursor).rowcount = 1
-            self.assertEqual(update_sample_data(
-                conn=db_mock, strain_name="BXD11",
-                strain_id=10, publish_data_id=8967049,
-                value=18.7, error=2.3, count=2),
-                             (1, 1, 1, 1))
-            cursor.execute.assert_has_calls(
-                [mock.call(STRAIN_ID_SQL, ('BXD11', 10)),
-                 mock.call(PUBLISH_DATA_SQL, (18.7, 10, 8967049)),
-                 mock.call(PUBLISH_SE_SQL, (2.3, 10, 8967049)),
-                 mock.call(N_STRAIN_SQL, (2, 10, 8967049))]
-            )
-
+    @pytest.mark.unit_test
     def test_set_haveinfo_field(self):
         """Test that the `haveinfo` field is set up correctly"""
         for trait_info, expected in [
@@ -236,6 +209,7 @@ class TestTraitsDBFunctions(TestCase):
             with self.subTest(trait_info=trait_info, expected=expected):
                 self.assertEqual(set_haveinfo_field(trait_info), expected)
 
+    @pytest.mark.unit_test
     def test_set_homologene_id_field(self):
         """Test that the `homologene_id` field is set up correctly"""
         for trait_type, trait_info, expected in [
@@ -250,6 +224,7 @@ class TestTraitsDBFunctions(TestCase):
                     self.assertEqual(
                         set_homologene_id_field(trait_type, trait_info, db_mock), expected)
 
+    @pytest.mark.unit_test
     def test_set_confidential_field(self):
         """Test that the `confidential` field is set up correctly"""
         for trait_type, trait_info, expected in [
@@ -261,6 +236,7 @@ class TestTraitsDBFunctions(TestCase):
                 self.assertEqual(
                     set_confidential_field(trait_type, trait_info), expected)
 
+    @pytest.mark.unit_test
     def test_export_trait_data_dtype(self):
         """
         Test `export_trait_data` with different values for the `dtype` keyword
@@ -276,6 +252,7 @@ class TestTraitsDBFunctions(TestCase):
                     export_trait_data(trait_data, samplelist, dtype=dtype),
                     expected)
 
+    @pytest.mark.unit_test
     def test_export_trait_data_dtype_all_flags(self):
         """
         Test `export_trait_data` with different values for the `dtype` keyword
@@ -317,6 +294,7 @@ class TestTraitsDBFunctions(TestCase):
                         n_exists=nflag),
                     expected)
 
+    @pytest.mark.unit_test
     def test_export_informative(self):
         """Test that the function exports appropriate data."""
         # pylint: disable=W0621