about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/db/traits.py26
-rw-r--r--tests/unit/db/test_traits.py40
2 files changed, 41 insertions, 25 deletions
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index 37b111e..fddb8be 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -91,20 +91,24 @@ def insert_publication(pubmed_id: int, publication: Optional[Dict],
         with conn.cursor() as cursor:
             cursor.execute(insert_query, tuple(publication.values()))
 
-def retrieve_probeset_trait_name(threshold, name, connection):
+def retrieve_type_trait_name(trait_type, threshold, name, connection):
     """
-    Retrieve the name for a Probeset trait
+    Retrieve the name of a trait given the trait's name
 
-    This is extracted from the `webqtlDataset.retrieveName` function,
-    specifically the section dealing with 'ProbeSet' type traits
-    https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlDataset.py#L140-154"""
+    This is extracted from the `webqtlDataset.retrieveName` function as is
+    implemented at
+    https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlDataset.py#L140-L169
+    """
+    columns = "Id, Name, FullName, ShortName{}".format(
+        ", DataScale" if trait_type == "ProbeSet" else "")
     query = (
-        'SELECT Id, Name, FullName, ShortName, DataScale '
-        'FROM ProbeSetFreeze '
-        'WHERE '
-        'public > %(threshold)s '
-        'AND '
-        '(Name = %(name)s OR FullName = %(name)s OR ShortName = %(name)s)')
+        "SELECT {columns} "
+        "FROM {trait_type}Freeze "
+        "WHERE "
+        "public > %(threshold)s "
+        "AND "
+        "(Name = %(name)s OR FullName = %(name)s OR ShortName = %(name)s)").format(
+            columns=columns, trait_type=trait_type)
     with connection.cursor() as cursor:
         cursor.execute(query, {"threshold": threshold, "name": name})
         return cursor.fetchone()
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
index 6d2ba4d..95c5b27 100644
--- a/tests/unit/db/test_traits.py
+++ b/tests/unit/db/test_traits.py
@@ -1,22 +1,34 @@
 """Tests for gn3/db/traits.py"""
 from unittest import mock, TestCase
-from gn3.db.traits import retrieve_probeset_trait_name
+from gn3.db.traits import retrieve_type_trait_name
 
 class TestTraitsDBFunctions(TestCase):
     "Test cases for traits functions"
 
     def test_retrieve_probeset_trait_name(self):
         """Test that the function is called correctly."""
-        db_mock = mock.MagicMock()
-        with db_mock.cursor() as cursor:
-            cursor.fetchone.return_value = (
-                "testName", "testNameFull", "testNameShort", "dataScale")
-            self.assertEqual(
-                retrieve_probeset_trait_name(9, "testName", db_mock),
-                ("testName", "testNameFull", "testNameShort", "dataScale"))
-            cursor.execute.assert_called_once_with(
-                "SELECT Id, Name, FullName, ShortName, DataScale "
-                "FROM ProbeSetFreeze "
-                "WHERE public > %(threshold)s AND "
-                "(Name = %(name)s OR FullName = %(name)s OR ShortName = %(name)s)",
-                {"threshold": 9, "name": "testName"})
+        for trait_type, thresh, trait_name, columns in [
+                ["ProbeSet", 9, "testName",
+                 "Id, Name, FullName, ShortName, DataScale"],
+                ["Geno", 3, "genoTraitName", "Id, Name, FullName, ShortName"],
+                ["Publish", 6, "publishTraitName",
+                 "Id, Name, FullName, ShortName"],
+                ["Temp", 4, "tempTraitName", "Id, Name, FullName, ShortName"]]:
+            db_mock = mock.MagicMock()
+            with self.subTest(trait_type=trait_type):
+                with db_mock.cursor() as cursor:
+                    cursor.fetchone.return_value = (
+                        "testName", "testNameFull", "testNameShort",
+                        "dataScale")
+                    self.assertEqual(
+                        retrieve_type_trait_name(
+                            trait_type, thresh, trait_name, db_mock),
+                        ("testName", "testNameFull", "testNameShort",
+                         "dataScale"))
+                    cursor.execute.assert_called_once_with(
+                        "SELECT {cols} "
+                        "FROM {ttype}Freeze "
+                        "WHERE public > %(threshold)s AND "
+                        "(Name = %(name)s OR FullName = %(name)s OR ShortName = %(name)s)".format(
+                            cols=columns, ttype=trait_type),
+                        {"threshold": thresh, "name": trait_name})