diff options
-rw-r--r-- | gn3/db/traits.py | 13 | ||||
-rw-r--r-- | tests/unit/db/test_traits.py | 39 |
2 files changed, 45 insertions, 7 deletions
diff --git a/gn3/db/traits.py b/gn3/db/traits.py index 902eb8b..ce6298f 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -160,11 +160,14 @@ def set_confidential_field(trait_info): """Post processing function for 'Publish' trait types. It sets the value for the 'confidential' key.""" - return { - **trait_info, - "confidential": 1 if ( - trait_info.get("pre_publication_description", None) - and not trait_info.get("pubmed_id", None)) else 0} + if trait_info["type"] == "Publish": + return { + **trait_info, + "confidential": 1 if ( + trait_info.get("pre_publication_description", None) + and not trait_info.get("pubmed_id", None)) else 0} + else: + return trait_info def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any): """Retrieve trait information for type `ProbeSet` traits. diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py index 3840dd1..7e8b29c 100644 --- a/tests/unit/db/test_traits.py +++ b/tests/unit/db/test_traits.py @@ -1,13 +1,16 @@ """Tests for gn3/db/traits.py""" from unittest import mock, TestCase from gn3.db.traits import ( + set_haveinfo_field, + update_sample_data, retrieve_trait_info, + set_confidential_field, + set_homologene_id_field, retrieve_geno_trait_info, retrieve_temp_trait_info, retrieve_trait_dataset_name, retrieve_publish_trait_info, - retrieve_probeset_trait_info, - update_sample_data) + retrieve_probeset_trait_info) class TestTraitsDBFunctions(TestCase): "Test cases for traits functions" @@ -198,3 +201,35 @@ class TestTraitsDBFunctions(TestCase): mock.call(PUBLISH_SE_SQL, (2.3, 10, 8967049)), mock.call(N_STRAIN_SQL, (2, 10, 8967049))] ) + + def test_set_haveinfo_field(self): + for trait_info, expected in [ + [{}, {"haveinfo": 0}], + [{"k1": "v1"}, {"k1": "v1", "haveinfo": 1}]]: + with self.subTest(trait_info=trait_info, expected=expected): + self.assertEqual(set_haveinfo_field(trait_info), expected) + + def test_set_homologene_id_field(self): + for trait_info, expected in [ + [{"type": "Publish"}, + {"type": "Publish", "homologeneid": None}], + [{"type": "ProbeSet"}, + {"type": "ProbeSet", "homologeneid": None}], + [{"type": "Geno"}, {"type": "Geno", "homologeneid": None}], + [{"type": "Temp"}, {"type": "Temp", "homologeneid": None}]]: + db_mock = mock.MagicMock() + with self.subTest(trait_info=trait_info, expected=expected): + with db_mock.cursor() as cursor: + cursor.fetchone.return_value = () + self.assertEqual( + set_homologene_id_field(trait_info, db_mock), expected) + + def test_set_confidential_field(self): + for trait_info, expected in [ + [{"type": "Publish"}, {"type": "Publish", "confidential": 0}], + [{"type": "ProbeSet"}, {"type": "ProbeSet"}], + [{"type": "Geno"}, {"type": "Geno"}], + [{"type": "Temp"}, {"type": "Temp"}]]: + with self.subTest(trait_info=trait_info, expected=expected): + self.assertEqual( + set_confidential_field(trait_info), expected) |