aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gn3/db/traits.py33
-rw-r--r--tests/unit/db/test_traits.py27
2 files changed, 36 insertions, 24 deletions
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index a740352..6ea24be 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -326,10 +326,23 @@ def build_trait_name(trait_fullname):
"""
Initialises the trait's name, and other values from the search data provided
"""
+ def dataset_type(dset_name):
+ if dset_name.find('Temp') >= 0:
+ return "Temp"
+ if dset_name.find('Geno') >= 0:
+ return "Geno"
+ if dset_name.find('Publish') >= 0:
+ return "Publish"
+ return "ProbeSet"
+
name_parts = trait_fullname.split("::")
assert len(name_parts) >= 2, "Name format error"
+ dataset_name = name_parts[0]
+ dataset_type = dataset_type(dataset_name)
return {
- "db": {"dataset_name": name_parts[0]},
+ "db": {
+ "dataset_name": dataset_name,
+ "dataset_type": dataset_type},
"trait_fullname": trait_fullname,
"trait_name": name_parts[1],
"cellid": name_parts[2] if len(name_parts) == 3 else ""
@@ -357,7 +370,7 @@ def retrieve_probeset_sequence(trait, conn):
return {**trait, "sequence": seq[0] if seq else ""}
def retrieve_trait_info(
- trait_type: str, threshold: int, trait_full_name: str, conn: Any,
+ threshold: int, trait_full_name: str, conn: Any,
qtl=None):
"""Retrieves the trait information.
@@ -366,6 +379,7 @@ def retrieve_trait_info(
This function, or the dependent functions, might be incomplete as they are
currently."""
trait = build_trait_name(trait_full_name)
+ trait_dataset_type = trait["db"]["dataset_type"]
trait_info_function_table = {
"Publish": retrieve_publish_trait_info,
"ProbeSet": retrieve_probeset_trait_info,
@@ -374,14 +388,14 @@ def retrieve_trait_info(
}
common_post_processing_fn = compose(
- lambda ti: load_qtl_info(qtl, trait_type, ti, conn),
- lambda ti: set_homologene_id_field(trait_type, ti, conn),
- lambda ti: {"trait_type": trait_type, **ti},
+ lambda ti: load_qtl_info(qtl, trait_dataset_type, ti, conn),
+ lambda ti: set_homologene_id_field(trait_dataset_type, ti, conn),
+ lambda ti: {"trait_type": trait_dataset_type, **ti},
lambda ti: {**trait, **ti})
trait_post_processing_functions_table = {
"Publish": compose(
- lambda ti: set_confidential_field(trait_type, ti),
+ lambda ti: set_confidential_field(trait_dataset_type, ti),
common_post_processing_fn),
"ProbeSet": compose(
lambda ti: retrieve_probeset_sequence(ti, conn),
@@ -391,9 +405,10 @@ def retrieve_trait_info(
}
retrieve_info = compose(
- set_haveinfo_field, trait_info_function_table[trait_type])
+ set_haveinfo_field, trait_info_function_table[trait_dataset_type])
- trait_dataset = retrieve_trait_dataset(trait_type, trait, threshold, conn)
+ trait_dataset = retrieve_trait_dataset(
+ trait_dataset_type, trait, threshold, conn)
trait_info = retrieve_info(
{
"trait_name": trait["trait_name"],
@@ -403,7 +418,7 @@ def retrieve_trait_info(
conn)
if trait_info["haveinfo"]:
return {
- **trait_post_processing_functions_table[trait_type](trait_info),
+ **trait_post_processing_functions_table[trait_dataset_type](trait_info),
"db": {**trait["db"], **trait_dataset},
"riset": trait_dataset["riset"]
}
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
index d9d7bbb..ee98893 100644
--- a/tests/unit/db/test_traits.py
+++ b/tests/unit/db/test_traits.py
@@ -126,11 +126,12 @@ class TestTraitsDBFunctions(TestCase):
"""
for fullname, expected in [
["testdb::testname",
- {"db": {"dataset_name": "testdb"}, "trait_name": "testname",
- "cellid": "", "trait_fullname": "testdb::testname"}],
+ {"db": {"dataset_name": "testdb", "dataset_type": "ProbeSet"},
+ "trait_name": "testname", "cellid": "",
+ "trait_fullname": "testdb::testname"}],
["testdb::testname::testcell",
- {"db": {"dataset_name": "testdb"}, "trait_name": "testname",
- "cellid": "testcell",
+ {"db": {"dataset_name": "testdb", "dataset_type": "ProbeSet"},
+ "trait_name": "testname", "cellid": "testcell",
"trait_fullname": "testdb::testname::testcell"}]]:
with self.subTest(fullname=fullname):
self.assertEqual(build_trait_name(fullname), expected)
@@ -146,22 +147,18 @@ class TestTraitsDBFunctions(TestCase):
def test_retrieve_trait_info(self):
"""Test that information on traits is retrieved as appropriate."""
- for trait_type, threshold, trait_fullname, expected in [
- ["Publish", 9, "pubDb::PublishTraitName::pubCell",
- {"haveinfo": 0}],
- ["ProbeSet", 5, "prbDb::ProbeSetTraitName::prbCell",
- {"haveinfo": 0}],
- ["Geno", 12, "genDb::GenoTraitName",
- {"haveinfo": 0}],
- ["Temp", 6, "tmpDb::TempTraitName",
- {"haveinfo": 0}]]:
+ for threshold, trait_fullname, expected in [
+ [9, "pubDb::PublishTraitName::pubCell", {"haveinfo": 0}],
+ [5, "prbDb::ProbeSetTraitName::prbCell", {"haveinfo": 0}],
+ [12, "genDb::GenoTraitName", {"haveinfo": 0}],
+ [6, "tmpDb::TempTraitName", {"haveinfo": 0}]]:
db_mock = mock.MagicMock()
- with self.subTest(trait_type=trait_type):
+ with self.subTest(trait_fullname=trait_fullname):
with db_mock.cursor() as cursor:
cursor.fetchone.return_value = tuple()
self.assertEqual(
retrieve_trait_info(
- trait_type, threshold, trait_fullname, db_mock),
+ threshold, trait_fullname, db_mock),
expected)
def test_update_sample_data(self):