about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/db/traits.py33
-rw-r--r--tests/unit/db/test_traits.py27
2 files changed, 36 insertions, 24 deletions
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index a740352..6ea24be 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -326,10 +326,23 @@ def build_trait_name(trait_fullname):
     """
     Initialises the trait's name, and other values from the search data provided
     """
+    def dataset_type(dset_name):
+        if dset_name.find('Temp') >= 0:
+            return "Temp"
+        if dset_name.find('Geno') >= 0:
+            return "Geno"
+        if dset_name.find('Publish') >= 0:
+            return "Publish"
+        return "ProbeSet"
+
     name_parts = trait_fullname.split("::")
     assert len(name_parts) >= 2, "Name format error"
+    dataset_name = name_parts[0]
+    dataset_type = dataset_type(dataset_name)
     return {
-        "db": {"dataset_name": name_parts[0]},
+        "db": {
+            "dataset_name": dataset_name,
+            "dataset_type": dataset_type},
         "trait_fullname": trait_fullname,
         "trait_name": name_parts[1],
         "cellid": name_parts[2] if len(name_parts) == 3 else ""
@@ -357,7 +370,7 @@ def retrieve_probeset_sequence(trait, conn):
         return {**trait, "sequence": seq[0] if seq else ""}
 
 def retrieve_trait_info(
-        trait_type: str, threshold: int, trait_full_name: str, conn: Any,
+        threshold: int, trait_full_name: str, conn: Any,
         qtl=None):
     """Retrieves the trait information.
 
@@ -366,6 +379,7 @@ def retrieve_trait_info(
     This function, or the dependent functions, might be incomplete as they are
     currently."""
     trait = build_trait_name(trait_full_name)
+    trait_dataset_type = trait["db"]["dataset_type"]
     trait_info_function_table = {
         "Publish": retrieve_publish_trait_info,
         "ProbeSet": retrieve_probeset_trait_info,
@@ -374,14 +388,14 @@ def retrieve_trait_info(
     }
 
     common_post_processing_fn = compose(
-        lambda ti: load_qtl_info(qtl, trait_type, ti, conn),
-        lambda ti: set_homologene_id_field(trait_type, ti, conn),
-        lambda ti: {"trait_type": trait_type, **ti},
+        lambda ti: load_qtl_info(qtl, trait_dataset_type, ti, conn),
+        lambda ti: set_homologene_id_field(trait_dataset_type, ti, conn),
+        lambda ti: {"trait_type": trait_dataset_type, **ti},
         lambda ti: {**trait, **ti})
 
     trait_post_processing_functions_table = {
         "Publish": compose(
-            lambda ti: set_confidential_field(trait_type, ti),
+            lambda ti: set_confidential_field(trait_dataset_type, ti),
             common_post_processing_fn),
         "ProbeSet": compose(
             lambda ti: retrieve_probeset_sequence(ti, conn),
@@ -391,9 +405,10 @@ def retrieve_trait_info(
     }
 
     retrieve_info = compose(
-        set_haveinfo_field, trait_info_function_table[trait_type])
+        set_haveinfo_field, trait_info_function_table[trait_dataset_type])
 
-    trait_dataset = retrieve_trait_dataset(trait_type, trait, threshold, conn)
+    trait_dataset = retrieve_trait_dataset(
+        trait_dataset_type, trait, threshold, conn)
     trait_info = retrieve_info(
         {
             "trait_name": trait["trait_name"],
@@ -403,7 +418,7 @@ def retrieve_trait_info(
         conn)
     if trait_info["haveinfo"]:
         return {
-            **trait_post_processing_functions_table[trait_type](trait_info),
+            **trait_post_processing_functions_table[trait_dataset_type](trait_info),
             "db": {**trait["db"], **trait_dataset},
             "riset": trait_dataset["riset"]
         }
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
index d9d7bbb..ee98893 100644
--- a/tests/unit/db/test_traits.py
+++ b/tests/unit/db/test_traits.py
@@ -126,11 +126,12 @@ class TestTraitsDBFunctions(TestCase):
         """
         for fullname, expected in [
                 ["testdb::testname",
-                 {"db": {"dataset_name": "testdb"}, "trait_name": "testname",
-                  "cellid": "", "trait_fullname": "testdb::testname"}],
+                 {"db": {"dataset_name": "testdb", "dataset_type": "ProbeSet"},
+                  "trait_name": "testname", "cellid": "",
+                  "trait_fullname": "testdb::testname"}],
                 ["testdb::testname::testcell",
-                 {"db": {"dataset_name": "testdb"}, "trait_name": "testname",
-                  "cellid": "testcell",
+                 {"db": {"dataset_name": "testdb", "dataset_type": "ProbeSet"},
+                  "trait_name": "testname", "cellid": "testcell",
                   "trait_fullname": "testdb::testname::testcell"}]]:
             with self.subTest(fullname=fullname):
                 self.assertEqual(build_trait_name(fullname), expected)
@@ -146,22 +147,18 @@ class TestTraitsDBFunctions(TestCase):
 
     def test_retrieve_trait_info(self):
         """Test that information on traits is retrieved as appropriate."""
-        for trait_type, threshold, trait_fullname, expected in [
-                ["Publish", 9, "pubDb::PublishTraitName::pubCell",
-                 {"haveinfo": 0}],
-                ["ProbeSet", 5, "prbDb::ProbeSetTraitName::prbCell",
-                 {"haveinfo": 0}],
-                ["Geno", 12, "genDb::GenoTraitName",
-                 {"haveinfo": 0}],
-                ["Temp", 6, "tmpDb::TempTraitName",
-                 {"haveinfo": 0}]]:
+        for threshold, trait_fullname, expected in [
+                [9, "pubDb::PublishTraitName::pubCell", {"haveinfo": 0}],
+                [5, "prbDb::ProbeSetTraitName::prbCell", {"haveinfo": 0}],
+                [12, "genDb::GenoTraitName", {"haveinfo": 0}],
+                [6, "tmpDb::TempTraitName", {"haveinfo": 0}]]:
             db_mock = mock.MagicMock()
-            with self.subTest(trait_type=trait_type):
+            with self.subTest(trait_fullname=trait_fullname):
                 with db_mock.cursor() as cursor:
                     cursor.fetchone.return_value = tuple()
                     self.assertEqual(
                         retrieve_trait_info(
-                            trait_type, threshold, trait_fullname, db_mock),
+                            threshold, trait_fullname, db_mock),
                         expected)
 
     def test_update_sample_data(self):