about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/api/correlation.py26
-rw-r--r--gn3/computations/partial_correlations.py5
-rw-r--r--gn3/db/partial_correlations.py116
-rw-r--r--pytest.ini1
-rw-r--r--tests/integration/test_partial_correlations.py51
5 files changed, 133 insertions, 66 deletions
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index 57b808e..cbe01d8 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -118,25 +118,25 @@ def partial_correlation():
return str(o, encoding="utf-8")
return json.JSONEncoder.default(self, o)
+ def __build_response__(data):
+ status_codes = {"error": 400, "not-found": 404, "success": 200}
+ response = make_response(
+ json.dumps(data, cls=OutputEncoder),
+ status_codes[data["status"]])
+ response.headers["Content-Type"] = "application/json"
+ return response
+
args = request.get_json()
request_errors = __errors__(
args, ("primary_trait", "control_traits", "target_db", "method"))
if request_errors:
- response = make_response(
- json.dumps({
- "status": "error",
- "messages": request_errors,
- "error_type": "Client Error"}),
- 400)
- response.headers["Content-Type"] = "application/json"
- return response
+ return __build_response__({
+ "status": "error",
+ "messages": request_errors,
+ "error_type": "Client Error"})
conn, _cursor_object = database_connector()
corr_results = partial_correlations_entry(
conn, trait_fullname(args["primary_trait"]),
tuple(trait_fullname(trait) for trait in args["control_traits"]),
args["method"], int(args.get("criteria", 500)), args["target_db"])
- response = make_response(
- json.dumps(corr_results, cls=OutputEncoder),
- 400 if "error" in corr_results.keys() else 200)
- response.headers["Content-Type"] = "application/json"
- return response
+ return __build_response__(corr_results)
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 85e3c11..16cbbdb 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -616,6 +616,11 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
primary_trait = tuple(
trait for trait in all_traits
if trait["trait_fullname"] == primary_trait_name)[0]
+ if not primary_trait["haveinfo"]:
+ return {
+ "status": "not-found",
+ "message": f"Could not find primary trait {primary_trait['trait_fullname']}"
+ }
group = primary_trait["db"]["group"]
primary_trait_data = all_traits_data[primary_trait["trait_name"]]
primary_samples, primary_values, _primary_variances = export_informative(
diff --git a/gn3/db/partial_correlations.py b/gn3/db/partial_correlations.py
index 157f8ee..3e77367 100644
--- a/gn3/db/partial_correlations.py
+++ b/gn3/db/partial_correlations.py
@@ -62,7 +62,9 @@ def publish_traits_data(conn, traits):
"""
Retrieve trait data for `Publish` traits.
"""
- dataset_ids = tuple(set(trait["db"]["dataset_id"] for trait in traits))
+ dataset_ids = tuple(set(
+ trait["db"]["dataset_id"] for trait in traits
+ if trait["db"].get("dataset_id") is not None))
query = (
"SELECT "
"PublishXRef.Id AS trait_name, Strain.Name AS sample_name, "
@@ -83,12 +85,13 @@ def publish_traits_data(conn, traits):
"ORDER BY Strain.Name").format(
trait_names=", ".join(["%s"] * len(traits)),
dataset_ids=", ".join(["%s"] * len(dataset_ids)))
- with conn.cursor(cursorclass=DictCursor) as cursor:
- cursor.execute(
- query,
- tuple(trait["trait_name"] for trait in traits) +
- tuple(dataset_ids))
- return organise_trait_data_by_trait(cursor.fetchall())
+ if len(dataset_ids) > 0:
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ tuple(trait["trait_name"] for trait in traits) +
+ tuple(dataset_ids))
+ return organise_trait_data_by_trait(cursor.fetchall())
return {}
def cellid_traits_data(conn, traits):
@@ -161,15 +164,18 @@ def species_ids(conn, traits):
"""
Retrieve the IDS of the related species from the given list of traits.
"""
- groups = tuple(set(trait["db"]["group"] for trait in traits))
+ groups = tuple(set(
+ trait["db"]["group"] for trait in traits
+ if trait["db"].get("group") is not None))
query = (
"SELECT Name AS `group`, SpeciesId AS species_id "
"FROM InbredSet "
"WHERE Name IN ({groups})").format(
groups=", ".join(["%s"] * len(groups)))
- with conn.cursor(cursorclass=DictCursor) as cursor:
- cursor.execute(query, groups)
- return tuple(row for row in cursor.fetchall())
+ if len(groups) > 0:
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(query, groups)
+ return tuple(row for row in cursor.fetchall())
return tuple()
def geno_traits_data(conn, traits):
@@ -194,12 +200,13 @@ def geno_traits_data(conn, traits):
species_ids=sp_ids,
trait_names=", ".join(["%s"] * len(traits)),
dataset_names=", ".join(["%s"] * len(dataset_names)))
- with conn.cursor(cursorclass=DictCursor) as cursor:
- cursor.execute(
- query,
- tuple(trait["trait_name"] for trait in traits) +
- tuple(dataset_names))
- return organise_trait_data_by_trait(cursor.fetchall())
+ if len(sp_ids) > 0 and len(dataset_names) > 0:
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ tuple(trait["trait_name"] for trait in traits) +
+ tuple(dataset_names))
+ return organise_trait_data_by_trait(cursor.fetchall())
return {}
def traits_data(
@@ -283,7 +290,9 @@ def publish_traits_info(
this one fetches multiple items in a single query, unlike the original that
fetches one item per query.
"""
- trait_dataset_ids = set(trait["db"]["dataset_id"] for trait in traits)
+ trait_dataset_ids = set(
+ trait["db"]["dataset_id"] for trait in traits
+ if trait["db"].get("dataset_id") is not None)
columns = (
"PublishXRef.Id, Publication.PubMed_ID, "
"Phenotype.Pre_publication_description, "
@@ -311,13 +320,14 @@ def publish_traits_info(
columns=columns,
trait_names=", ".join(["%s"] * len(traits)),
trait_dataset_ids=", ".join(["%s"] * len(trait_dataset_ids)))
- with conn.cursor(cursorclass=DictCursor) as cursor:
- cursor.execute(
- query,
- (
- tuple(trait["trait_name"] for trait in traits) +
- tuple(trait_dataset_ids)))
- return merge_traits_and_info(traits, cursor.fetchall())
+ if trait_dataset_ids:
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ (
+ tuple(trait["trait_name"] for trait in traits) +
+ tuple(trait_dataset_ids)))
+ return merge_traits_and_info(traits, cursor.fetchall())
return tuple({**trait, "haveinfo": False} for trait in traits)
def probeset_traits_info(
@@ -728,33 +738,35 @@ def set_homologene_id(conn, traits):
"""
Retrieve and set the 'homologene_id' values for ProbeSet traits.
"""
- geneids = set(trait["geneid"] for trait in traits)
- groups = set(trait["db"]["group"] for trait in traits)
- query = (
- "SELECT InbredSet.Name AS `group`, Homologene.GeneId AS geneid, "
- "HomologeneId "
- "FROM Homologene, Species, InbredSet "
- "WHERE Homologene.GeneId IN ({geneids}) "
- "AND InbredSet.Name IN ({groups}) "
- "AND InbredSet.SpeciesId = Species.Id "
- "AND Species.TaxonomyId = Homologene.TaxonomyId").format(
- geneids=", ".join(["%s"] * len(geneids)),
- groups=", ".join(["%s"] * len(groups)))
- with conn.cursor(cursorclass=DictCursor) as cursor:
- cursor.execute(query, (tuple(geneids) + tuple(groups)))
- results = {
- row["group"]: {
- row["geneid"]: {
- key: val for key, val in row.items()
- if key not in ("group", "geneid")
- }
- } for row in cursor.fetchall()
- }
- return tuple(
- {
- **trait, **results.get(
- trait["db"]["group"], {}).get(trait["geneid"], {})
- } for trait in traits)
+ geneids = set(trait.get("geneid") for trait in traits if trait["haveinfo"])
+ groups = set(
+ trait["db"].get("group") for trait in traits if trait["haveinfo"])
+ if len(geneids) > 0 and len(groups) > 0:
+ query = (
+ "SELECT InbredSet.Name AS `group`, Homologene.GeneId AS geneid, "
+ "HomologeneId "
+ "FROM Homologene, Species, InbredSet "
+ "WHERE Homologene.GeneId IN ({geneids}) "
+ "AND InbredSet.Name IN ({groups}) "
+ "AND InbredSet.SpeciesId = Species.Id "
+ "AND Species.TaxonomyId = Homologene.TaxonomyId").format(
+ geneids=", ".join(["%s"] * len(geneids)),
+ groups=", ".join(["%s"] * len(groups)))
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(query, (tuple(geneids) + tuple(groups)))
+ results = {
+ row["group"]: {
+ row["geneid"]: {
+ key: val for key, val in row.items()
+ if key not in ("group", "geneid")
+ }
+ } for row in cursor.fetchall()
+ }
+ return tuple(
+ {
+ **trait, **results.get(
+ trait["db"]["group"], {}).get(trait["geneid"], {})
+ } for trait in traits)
return traits
def traits_datasets(conn, threshold, traits):
diff --git a/pytest.ini b/pytest.ini
index 58eba11..ba87787 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,6 +1,7 @@
[pytest]
addopts = --strict-markers
markers =
+ slow
unit_test
integration_test
performance_test
\ No newline at end of file
diff --git a/tests/integration/test_partial_correlations.py b/tests/integration/test_partial_correlations.py
index 5b520e0..17ea539 100644
--- a/tests/integration/test_partial_correlations.py
+++ b/tests/integration/test_partial_correlations.py
@@ -83,8 +83,57 @@ from tests.integration.conftest import client
"target_db": None
}))
def test_partial_correlation_api_with_missing_request_data(client, post_data):
- "Test /api/correlations/partial"
+ """
+ Test /api/correlations/partial endpoint with various expected request data
+ missing.
+ """
response = client.post("/api/correlation/partial", json=post_data)
assert (
response.status_code == 400 and response.is_json and
response.json.get("status") == "error")
+
+
+@pytest.mark.integration_test
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "post_data",
+ ({# ProbeSet
+ "primary_trait": {"dataset": "a_dataset", "name": "a_name"},
+ "control_traits": [
+ {"dataset": "a_dataset", "name": "a_name"},
+ {"dataset": "a_dataset2", "name": "a_name2"}],
+ "method": "a_method",
+ "target_db": "a_db"
+ }, {# Publish
+ "primary_trait": {"dataset": "a_Publish_dataset", "name": "a_name"},
+ "control_traits": [
+ {"dataset": "a_dataset", "name": "a_name"},
+ {"dataset": "a_dataset2", "name": "a_name2"}],
+ "method": "a_method",
+ "target_db": "a_db"
+ }, {# Geno
+ "primary_trait": {"dataset": "a_Geno_dataset", "name": "a_name"},
+ "control_traits": [
+ {"dataset": "a_dataset", "name": "a_name"},
+ {"dataset": "a_dataset2", "name": "a_name2"}],
+ "method": "a_method",
+ "target_db": "a_db"
+ }, {# Temp -- Fails due to missing table. Remove this sample if it is
+ # confirmed that the deletion of the database table is on purpose, and
+ # that Temp traits are no longer a thing
+ "primary_trait": {"dataset": "a_Temp_dataset", "name": "a_name"},
+ "control_traits": [
+ {"dataset": "a_dataset", "name": "a_name"},
+ {"dataset": "a_dataset2", "name": "a_name2"}],
+ "method": "a_method",
+ "target_db": "a_db"
+ }))
+def test_partial_correlation_api_with_non_existent_traits(client, post_data):
+ """
+ Check that the system responds appropriately in the case where the user
+ makes a request with a non-existent primary trait.
+ """
+ response = client.post("/api/correlation/partial", json=post_data)
+ assert (
+ response.status_code == 404 and response.is_json and
+ response.json.get("status") != "error")