about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/db/traits.py13
-rw-r--r--tests/unit/db/test_traits.py39
2 files changed, 45 insertions, 7 deletions
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index 902eb8b..ce6298f 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -160,11 +160,14 @@ def set_confidential_field(trait_info):
     """Post processing function for 'Publish' trait types.
 
     It sets the value for the 'confidential' key."""
-    return {
-        **trait_info,
-        "confidential": 1 if (
-            trait_info.get("pre_publication_description", None)
-            and not trait_info.get("pubmed_id", None)) else 0}
+    if trait_info["type"] == "Publish":
+        return {
+            **trait_info,
+            "confidential": 1 if (
+                trait_info.get("pre_publication_description", None)
+                and not trait_info.get("pubmed_id", None)) else 0}
+    else:
+        return trait_info
 
 def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any):
     """Retrieve trait information for type `ProbeSet` traits.
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
index 3840dd1..7e8b29c 100644
--- a/tests/unit/db/test_traits.py
+++ b/tests/unit/db/test_traits.py
@@ -1,13 +1,16 @@
 """Tests for gn3/db/traits.py"""
 from unittest import mock, TestCase
 from gn3.db.traits import (
+    set_haveinfo_field,
+    update_sample_data,
     retrieve_trait_info,
+    set_confidential_field,
+    set_homologene_id_field,
     retrieve_geno_trait_info,
     retrieve_temp_trait_info,
     retrieve_trait_dataset_name,
     retrieve_publish_trait_info,
-    retrieve_probeset_trait_info,
-    update_sample_data)
+    retrieve_probeset_trait_info)
 
 class TestTraitsDBFunctions(TestCase):
     "Test cases for traits functions"
@@ -198,3 +201,35 @@ class TestTraitsDBFunctions(TestCase):
                  mock.call(PUBLISH_SE_SQL, (2.3, 10, 8967049)),
                  mock.call(N_STRAIN_SQL, (2, 10, 8967049))]
             )
+
+    def test_set_haveinfo_field(self):
+        for trait_info, expected in [
+                [{}, {"haveinfo": 0}],
+                [{"k1": "v1"}, {"k1": "v1", "haveinfo": 1}]]:
+            with self.subTest(trait_info=trait_info, expected=expected):
+                self.assertEqual(set_haveinfo_field(trait_info), expected)
+
+    def test_set_homologene_id_field(self):
+        for trait_info, expected in [
+                [{"type": "Publish"},
+                 {"type": "Publish", "homologeneid": None}],
+                [{"type": "ProbeSet"},
+                 {"type": "ProbeSet", "homologeneid": None}],
+                [{"type": "Geno"}, {"type": "Geno", "homologeneid": None}],
+                [{"type": "Temp"}, {"type": "Temp", "homologeneid": None}]]:
+            db_mock = mock.MagicMock()
+            with self.subTest(trait_info=trait_info, expected=expected):
+                with db_mock.cursor() as cursor:
+                    cursor.fetchone.return_value = ()
+                    self.assertEqual(
+                        set_homologene_id_field(trait_info, db_mock), expected)
+
+    def test_set_confidential_field(self):
+        for trait_info, expected in [
+                [{"type": "Publish"}, {"type": "Publish", "confidential": 0}],
+                [{"type": "ProbeSet"}, {"type": "ProbeSet"}],
+                [{"type": "Geno"}, {"type": "Geno"}],
+                [{"type": "Temp"}, {"type": "Temp"}]]:
+            with self.subTest(trait_info=trait_info, expected=expected):
+                self.assertEqual(
+                    set_confidential_field(trait_info), expected)