diff options
-rw-r--r-- | gn3/db/traits.py | 57 | ||||
-rw-r--r-- | tests/unit/db/test_traits.py | 105 |
2 files changed, 162 insertions, 0 deletions
diff --git a/gn3/db/traits.py b/gn3/db/traits.py index ce6298f..ea35d7e 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -286,6 +286,62 @@ def set_homologene_id_field(trait_info, conn): } return functions_table[trait_info["type"]](trait_info) +def set_geno_riset_fields(name, conn): + """ + Retrieve the RISet, and RISetID values for various Geno trait types. + """ + query = ( + "SELECT InbredSet.Name, InbredSet.Id " + "FROM InbredSet, GenoFreeze " + "WHERE GenoFreeze.InbredSetId = InbredSet.Id " + "AND GenoFreeze.Name = %(name)s") + with conn.cursor() as cursor: + return cursor.execute(query, {"name": name}) + +def set_publish_riset_fields(name, conn): + """ + Retrieve the RISet, and RISetID values for various Publish trait types. + """ + query = ( + "SELECT InbredSet.Name, InbredSet.Id " + "FROM InbredSet, PublishFreeze " + "WHERE PublishFreeze.InbredSetId = InbredSet.Id " + "AND PublishFreeze.Name = %(name)s") + with conn.cursor() as cursor: + return cursor.execute(query, {"name": name}) + +def set_probeset_riset_fields(name, conn): + """ + Retrieve the RISet, and RISetID values for various ProbeSet trait types. + """ + query = ( + "SELECT InbredSet.Name, InbredSet.Id " + "FROM InbredSet, ProbeSetFreeze, ProbeFreeze " + "WHERE ProbeFreeze.InbredSetId = InbredSet.Id " + "AND ProbeFreeze.Id = ProbeSetFreeze.ProbeFreezeId " + "AND ProbeSetFreeze.Name = %(name)s") + with conn.cursor() as cursor: + return cursor.execute(query, {"name": name}) + +def set_riset_fields(trait_info, conn): + """ + Retrieve the RISet, and RISetID values for various trait types. + """ + riset_functions_map = { + "Temp": lambda ti, con: (None, None), + "Geno": set_geno_riset_fields, + "Publish": set_publish_riset_fields, + "ProbeSet": set_probeset_riset_fields + } + if not trait_info.get("haveinfo", None): + return trait_info + + riset, riid = riset_functions_map[trait_info["type"]]( + trait_info["name"], conn) + return { + **trait_info, "risetid": riid, + "riset": "BXD" if riset == "BXD300" else riset} + def retrieve_trait_info( trait_type: str, trait_name: str, trait_dataset_id: int, trait_dataset_name: str, conn: Any, QTL=None): @@ -303,6 +359,7 @@ def retrieve_trait_info( } common_post_processing_fn = compose( + lambda ti: set_riset_fields(ti, conn), lambda ti: set_homologene_id_field(ti, conn), lambda ti: {"type": trait_type, **ti}, set_haveinfo_field) diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py index 7e8b29c..2445d26 100644 --- a/tests/unit/db/test_traits.py +++ b/tests/unit/db/test_traits.py @@ -1,13 +1,17 @@ """Tests for gn3/db/traits.py""" from unittest import mock, TestCase from gn3.db.traits import ( + set_riset_fields, set_haveinfo_field, update_sample_data, retrieve_trait_info, + set_geno_riset_fields, set_confidential_field, set_homologene_id_field, retrieve_geno_trait_info, retrieve_temp_trait_info, + set_publish_riset_fields, + set_probeset_riset_fields, retrieve_trait_dataset_name, retrieve_publish_trait_info, retrieve_probeset_trait_info) @@ -233,3 +237,104 @@ class TestTraitsDBFunctions(TestCase): with self.subTest(trait_info=trait_info, expected=expected): self.assertEqual( set_confidential_field(trait_info), expected) + + def test_set_geno_riset_fields(self): + """ + Test that the `riset` and `riset_id` fields are retrieved appropriately + for the 'Geno' trait type. + """ + for trait_name, expected in [ + ["testGenoName", ()]]: + db_mock = mock.MagicMock() + with self.subTest(trait_name=trait_name, expected=expected): + with db_mock.cursor() as cursor: + cursor.execute.return_value = () + self.assertEqual( + set_geno_riset_fields(trait_name, db_mock), expected) + cursor.execute.assert_called_once_with( + ( + "SELECT InbredSet.Name, InbredSet.Id" + " FROM InbredSet, GenoFreeze" + " WHERE GenoFreeze.InbredSetId = InbredSet.Id" + " AND GenoFreeze.Name = %(name)s"), + {"name": trait_name}) + + + def test_set_publish_riset_fields(self): + """ + Test that the `riset` and `riset_id` fields are retrieved appropriately + for the 'Publish' trait type. + """ + for trait_name, expected in [ + ["testPublishName", ()]]: + db_mock = mock.MagicMock() + with self.subTest(trait_name=trait_name, expected=expected): + with db_mock.cursor() as cursor: + cursor.execute.return_value = () + self.assertEqual( + set_publish_riset_fields(trait_name, db_mock), expected) + cursor.execute.assert_called_once_with( + ( + "SELECT InbredSet.Name, InbredSet.Id" + " FROM InbredSet, PublishFreeze" + " WHERE PublishFreeze.InbredSetId = InbredSet.Id" + " AND PublishFreeze.Name = %(name)s"), + {"name": trait_name}) + + + def test_set_probeset_riset_fields(self): + """ + Test that the `riset` and `riset_id` fields are retrieved appropriately + for the 'ProbeSet' trait type. + """ + for trait_name, expected in [ + ["testProbeSetName", ()]]: + db_mock = mock.MagicMock() + with self.subTest(trait_name=trait_name, expected=expected): + with db_mock.cursor() as cursor: + cursor.execute.return_value = () + self.assertEqual( + set_probeset_riset_fields(trait_name, db_mock), expected) + cursor.execute.assert_called_once_with( + ( + "SELECT InbredSet.Name, InbredSet.Id" + " FROM InbredSet, ProbeSetFreeze, ProbeFreeze" + " WHERE ProbeFreeze.InbredSetId = InbredSet.Id" + " AND ProbeFreeze.Id = ProbeSetFreeze.ProbeFreezeId" + " AND ProbeSetFreeze.Name = %(name)s"), + {"name": trait_name}) + + def test_set_riset_fields(self): + """ + Test that the riset fields are set up correctly for the different trait + types. + """ + for trait_info, expected in [ + [{}, {}], + [{"haveinfo": 0, "type": "Publish"}, + {"haveinfo": 0, "type": "Publish"}], + [{"haveinfo": 0, "type": "ProbeSet"}, + {"haveinfo": 0, "type": "ProbeSet"}], + [{"haveinfo": 0, "type": "Geno"}, + {"haveinfo": 0, "type": "Geno"}], + [{"haveinfo": 0, "type": "Temp"}, + {"haveinfo": 0, "type": "Temp"}], + [{"haveinfo": 1, "type": "Publish", "name": "test"}, + {"haveinfo": 1, "type": "Publish", "name": "test", + "riset": "riset_name", "risetid": 0}], + [{"haveinfo": 1, "type": "ProbeSet", "name": "test"}, + {"haveinfo": 1, "type": "ProbeSet", "name": "test", + "riset": "riset_name", "risetid": 0}], + [{"haveinfo": 1, "type": "Geno", "name": "test"}, + {"haveinfo": 1, "type": "Geno", "name": "test", + "riset": "riset_name", "risetid": 0}], + [{"haveinfo": 1, "type": "Temp", "name": "test"}, + {"haveinfo": 1, "type": "Temp", "name": "test", "riset": None, + "risetid": None}] + ]: + db_mock = mock.MagicMock() + with self.subTest(trait_info=trait_info, expected=expected): + with db_mock.cursor() as cursor: + cursor.execute.return_value = ("riset_name", 0) + self.assertEqual( + set_riset_fields(trait_info, db_mock), expected) |