about summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/unit/computations/test_correlation.py1
-rw-r--r--tests/unit/db/test_datasets.py14
-rw-r--r--tests/unit/db/test_traits.py33
3 files changed, 32 insertions, 16 deletions
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py
index 0de347d..7523d99 100644
--- a/tests/unit/computations/test_correlation.py
+++ b/tests/unit/computations/test_correlation.py
@@ -1,7 +1,6 @@
 """Module contains the tests for correlation"""
 from unittest import TestCase
 from unittest import mock
-import unittest
 
 from collections import namedtuple
 import math
diff --git a/tests/unit/db/test_datasets.py b/tests/unit/db/test_datasets.py
index 39f4af9..0b8c2fe 100644
--- a/tests/unit/db/test_datasets.py
+++ b/tests/unit/db/test_datasets.py
@@ -13,15 +13,17 @@ class TestDatasetsDBFunctions(TestCase):
 
     def test_retrieve_dataset_name(self):
         """Test that the function is called correctly."""
-        for trait_type, thresh, trait_name, dataset_name, columns, table in [
+        for trait_type, thresh, trait_name, dataset_name, columns, table, expected in [
                 ["ProbeSet", 9, "probesetTraitName", "probesetDatasetName",
-                 "Id, Name, FullName, ShortName, DataScale", "ProbeSetFreeze"],
+                 "Id, Name, FullName, ShortName, DataScale", "ProbeSetFreeze",
+                 {"dataset_id": None, "dataset_name": "probesetDatasetName",
+                  "dataset_fullname": "probesetDatasetName"}],
                 ["Geno", 3, "genoTraitName", "genoDatasetName",
-                 "Id, Name, FullName, ShortName", "GenoFreeze"],
+                 "Id, Name, FullName, ShortName", "GenoFreeze", {}],
                 ["Publish", 6, "publishTraitName", "publishDatasetName",
-                 "Id, Name, FullName, ShortName", "PublishFreeze"],
+                 "Id, Name, FullName, ShortName", "PublishFreeze", {}],
                 ["Temp", 4, "tempTraitName", "tempTraitName",
-                 "Id, Name, FullName, ShortName", "TempFreeze"]]:
+                 "Id, Name, FullName, ShortName", "TempFreeze", {}]]:
             db_mock = mock.MagicMock()
             with self.subTest(trait_type=trait_type):
                 with db_mock.cursor() as cursor:
@@ -29,7 +31,7 @@ class TestDatasetsDBFunctions(TestCase):
                     self.assertEqual(
                         retrieve_dataset_name(
                             trait_type, thresh, trait_name, dataset_name, db_mock),
-                        {})
+                        expected)
                     cursor.execute.assert_called_once_with(
                         "SELECT {cols} "
                         "FROM {table} "
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
index 4aa9389..75f3d4c 100644
--- a/tests/unit/db/test_traits.py
+++ b/tests/unit/db/test_traits.py
@@ -202,8 +202,6 @@ class TestTraitsDBFunctions(TestCase):
         """
         # pylint: disable=C0103
         db_mock = mock.MagicMock()
-
-        STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
         PUBLISH_DATA_SQL: str = (
             "UPDATE PublishData SET value = %s "
             "WHERE StrainId = %s AND Id = %s")
@@ -216,16 +214,33 @@ class TestTraitsDBFunctions(TestCase):
 
         with db_mock.cursor() as cursor:
             type(cursor).rowcount = 1
+            mock_fetchone = mock.MagicMock()
+            mock_fetchone.return_value = (1, 1)
+            type(cursor).fetchone = mock_fetchone
             self.assertEqual(update_sample_data(
                 conn=db_mock, strain_name="BXD11",
-                strain_id=10, publish_data_id=8967049,
-                value=18.7, error=2.3, count=2),
-                             (1, 1, 1, 1))
+                trait_name="1",
+                phenotype_id=10, value=18.7,
+                error=2.3, count=2),
+                             (1, 1, 1))
             cursor.execute.assert_has_calls(
-                [mock.call(STRAIN_ID_SQL, ('BXD11', 10)),
-                 mock.call(PUBLISH_DATA_SQL, (18.7, 10, 8967049)),
-                 mock.call(PUBLISH_SE_SQL, (2.3, 10, 8967049)),
-                 mock.call(N_STRAIN_SQL, (2, 10, 8967049))]
+                [mock.call('SELECT Strain.Id, PublishData.Id FROM'
+                           ' (PublishData, Strain, PublishXRef, '
+                           'PublishFreeze) LEFT JOIN PublishSE ON '
+                           '(PublishSE.DataId = PublishData.Id '
+                           'AND PublishSE.StrainId = '
+                           'PublishData.StrainId) LEFT JOIN NStrain ON '
+                           '(NStrain.DataId = PublishData.Id AND '
+                           'NStrain.StrainId = PublishData.StrainId) WHERE '
+                           'PublishXRef.InbredSetId = '
+                           'PublishFreeze.InbredSetId AND PublishData.Id = '
+                           'PublishXRef.DataId AND PublishXRef.Id = 1 AND '
+                           'PublishXRef.PhenotypeId = 10 AND '
+                           'PublishData.StrainId = Strain.Id AND '
+                           'Strain.Name = "BXD11"'),
+                 mock.call(PUBLISH_DATA_SQL, (18.7, 1, 1)),
+                 mock.call(PUBLISH_SE_SQL, (2.3, 1, 1)),
+                 mock.call(N_STRAIN_SQL, (2, 1, 1))]
             )
 
     def test_set_haveinfo_field(self):