about summary refs log tree commit diff
path: root/gn3/db
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-02-18 07:17:50 +0300
committerFrederick Muriuki Muriithi2022-02-18 07:17:50 +0300
commit83a7aa7533f8f4ecac049dc0e93aff6429e6e5ae (patch)
tree0d0e77acba6570a7362c1380fa606535bc90e216 /gn3/db
parent12db8134081bc679565dfff1aaa2da81f913dd9d (diff)
downloadgenenetwork3-83a7aa7533f8f4ecac049dc0e93aff6429e6e5ae.tar.gz
Test partial correlations endpoint with non-existent primary traits
Test that the partial correlations endpoint responds with an appropriate
"not-found" message and the corresponding 404 status code in the case where a
request is made and the primary trait requested for does not exist in the
database.

Summary of the changes in each file:
* gn3/api/correlation.py: generalise the building of the response
* gn3/computations/partial_correlations.py: return with a "not-found" if the
  primary trait does not exist in the database
* gn3/db/partial_correlations.py: Fix a number of bugs that led to exceptions
  in the case that the primary trait did not exist
* pytest.ini: register a `slow` pytest marker
* tests/integration/test_partial_correlations.py: Add a new test to check for
  an appropriate 404 response in case of a primary trait that does not exist
  in the database.
Diffstat (limited to 'gn3/db')
-rw-r--r--gn3/db/partial_correlations.py116
1 files changed, 64 insertions, 52 deletions
diff --git a/gn3/db/partial_correlations.py b/gn3/db/partial_correlations.py
index 157f8ee..3e77367 100644
--- a/gn3/db/partial_correlations.py
+++ b/gn3/db/partial_correlations.py
@@ -62,7 +62,9 @@ def publish_traits_data(conn, traits):
     """
     Retrieve trait data for `Publish` traits.
     """
-    dataset_ids = tuple(set(trait["db"]["dataset_id"] for trait in traits))
+    dataset_ids = tuple(set(
+        trait["db"]["dataset_id"] for trait in traits
+        if trait["db"].get("dataset_id") is not None))
     query = (
         "SELECT "
         "PublishXRef.Id AS trait_name, Strain.Name AS sample_name, "
@@ -83,12 +85,13 @@ def publish_traits_data(conn, traits):
         "ORDER BY Strain.Name").format(
             trait_names=", ".join(["%s"] * len(traits)),
             dataset_ids=", ".join(["%s"] * len(dataset_ids)))
-    with conn.cursor(cursorclass=DictCursor) as cursor:
-        cursor.execute(
-            query,
-            tuple(trait["trait_name"] for trait in traits) +
-            tuple(dataset_ids))
-        return organise_trait_data_by_trait(cursor.fetchall())
+    if len(dataset_ids) > 0:
+        with conn.cursor(cursorclass=DictCursor) as cursor:
+            cursor.execute(
+                query,
+                tuple(trait["trait_name"] for trait in traits) +
+                tuple(dataset_ids))
+            return organise_trait_data_by_trait(cursor.fetchall())
     return {}
 
 def cellid_traits_data(conn, traits):
@@ -161,15 +164,18 @@ def species_ids(conn, traits):
     """
     Retrieve the IDS of the related species from the given list of traits.
     """
-    groups = tuple(set(trait["db"]["group"] for trait in traits))
+    groups = tuple(set(
+        trait["db"]["group"] for trait in traits
+        if trait["db"].get("group") is not None))
     query = (
         "SELECT Name AS `group`, SpeciesId AS species_id "
         "FROM InbredSet "
         "WHERE Name IN ({groups})").format(
             groups=", ".join(["%s"] * len(groups)))
-    with conn.cursor(cursorclass=DictCursor) as cursor:
-        cursor.execute(query, groups)
-        return tuple(row for row in cursor.fetchall())
+    if len(groups) > 0:
+        with conn.cursor(cursorclass=DictCursor) as cursor:
+            cursor.execute(query, groups)
+            return tuple(row for row in cursor.fetchall())
     return tuple()
 
 def geno_traits_data(conn, traits):
@@ -194,12 +200,13 @@ def geno_traits_data(conn, traits):
             species_ids=sp_ids,
             trait_names=", ".join(["%s"] * len(traits)),
             dataset_names=", ".join(["%s"] * len(dataset_names)))
-    with conn.cursor(cursorclass=DictCursor) as cursor:
-        cursor.execute(
-            query,
-            tuple(trait["trait_name"] for trait in traits) +
-            tuple(dataset_names))
-        return organise_trait_data_by_trait(cursor.fetchall())
+    if len(sp_ids) > 0 and len(dataset_names) > 0:
+        with conn.cursor(cursorclass=DictCursor) as cursor:
+            cursor.execute(
+                query,
+                tuple(trait["trait_name"] for trait in traits) +
+                tuple(dataset_names))
+            return organise_trait_data_by_trait(cursor.fetchall())
     return {}
 
 def traits_data(
@@ -283,7 +290,9 @@ def publish_traits_info(
     this one fetches multiple items in a single query, unlike the original that
     fetches one item per query.
     """
-    trait_dataset_ids = set(trait["db"]["dataset_id"] for trait in traits)
+    trait_dataset_ids = set(
+        trait["db"]["dataset_id"] for trait in traits
+        if trait["db"].get("dataset_id") is not None)
     columns = (
         "PublishXRef.Id, Publication.PubMed_ID, "
         "Phenotype.Pre_publication_description, "
@@ -311,13 +320,14 @@ def publish_traits_info(
             columns=columns,
             trait_names=", ".join(["%s"] * len(traits)),
             trait_dataset_ids=", ".join(["%s"] * len(trait_dataset_ids)))
-    with conn.cursor(cursorclass=DictCursor) as cursor:
-        cursor.execute(
-            query,
-            (
-                tuple(trait["trait_name"] for trait in traits) +
-                tuple(trait_dataset_ids)))
-        return merge_traits_and_info(traits, cursor.fetchall())
+    if trait_dataset_ids:
+        with conn.cursor(cursorclass=DictCursor) as cursor:
+            cursor.execute(
+                query,
+                (
+                    tuple(trait["trait_name"] for trait in traits) +
+                    tuple(trait_dataset_ids)))
+            return merge_traits_and_info(traits, cursor.fetchall())
     return tuple({**trait, "haveinfo": False} for trait in traits)
 
 def probeset_traits_info(
@@ -728,33 +738,35 @@ def set_homologene_id(conn, traits):
     """
     Retrieve and set the 'homologene_id' values for ProbeSet traits.
     """
-    geneids = set(trait["geneid"] for trait in traits)
-    groups = set(trait["db"]["group"] for trait in traits)
-    query = (
-        "SELECT InbredSet.Name AS `group`, Homologene.GeneId AS geneid, "
-        "HomologeneId "
-        "FROM Homologene, Species, InbredSet "
-        "WHERE Homologene.GeneId IN ({geneids}) "
-        "AND InbredSet.Name IN ({groups}) "
-        "AND InbredSet.SpeciesId = Species.Id "
-        "AND Species.TaxonomyId = Homologene.TaxonomyId").format(
-            geneids=", ".join(["%s"] * len(geneids)),
-            groups=", ".join(["%s"] * len(groups)))
-    with conn.cursor(cursorclass=DictCursor) as cursor:
-        cursor.execute(query, (tuple(geneids) + tuple(groups)))
-        results = {
-            row["group"]: {
-                row["geneid"]: {
-                    key: val for key, val in row.items()
-                    if key not in ("group", "geneid")
-                }
-            } for row in cursor.fetchall()
-        }
-        return tuple(
-            {
-                **trait, **results.get(
-                    trait["db"]["group"], {}).get(trait["geneid"], {})
-            } for trait in traits)
+    geneids = set(trait.get("geneid") for trait in traits if trait["haveinfo"])
+    groups = set(
+        trait["db"].get("group") for trait in traits if trait["haveinfo"])
+    if len(geneids) > 1 and len(groups) > 1:
+        query = (
+            "SELECT InbredSet.Name AS `group`, Homologene.GeneId AS geneid, "
+            "HomologeneId "
+            "FROM Homologene, Species, InbredSet "
+            "WHERE Homologene.GeneId IN ({geneids}) "
+            "AND InbredSet.Name IN ({groups}) "
+            "AND InbredSet.SpeciesId = Species.Id "
+            "AND Species.TaxonomyId = Homologene.TaxonomyId").format(
+                geneids=", ".join(["%s"] * len(geneids)),
+                groups=", ".join(["%s"] * len(groups)))
+        with conn.cursor(cursorclass=DictCursor) as cursor:
+            cursor.execute(query, (tuple(geneids) + tuple(groups)))
+            results = {
+                row["group"]: {
+                    row["geneid"]: {
+                        key: val for key, val in row.items()
+                        if key not in ("group", "geneid")
+                    }
+                } for row in cursor.fetchall()
+            }
+            return tuple(
+                {
+                    **trait, **results.get(
+                        trait["db"]["group"], {}).get(trait["geneid"], {})
+                } for trait in traits)
     return traits
 
 def traits_datasets(conn, threshold, traits):