diff options
-rw-r--r-- | gn3/api/correlation.py | 26 | ||||
-rw-r--r-- | gn3/computations/partial_correlations.py | 5 | ||||
-rw-r--r-- | gn3/db/partial_correlations.py | 116 | ||||
-rw-r--r-- | pytest.ini | 1 | ||||
-rw-r--r-- | tests/integration/test_partial_correlations.py | 51 |
5 files changed, 133 insertions, 66 deletions
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py index 57b808e..cbe01d8 100644 --- a/gn3/api/correlation.py +++ b/gn3/api/correlation.py @@ -118,25 +118,25 @@ def partial_correlation(): return str(o, encoding="utf-8") return json.JSONEncoder.default(self, o) + def __build_response__(data): + status_codes = {"error": 400, "not-found": 404, "success": 200} + response = make_response( + json.dumps(data, cls=OutputEncoder), + status_codes[data["status"]]) + response.headers["Content-Type"] = "application/json" + return response + args = request.get_json() request_errors = __errors__( args, ("primary_trait", "control_traits", "target_db", "method")) if request_errors: - response = make_response( - json.dumps({ - "status": "error", - "messages": request_errors, - "error_type": "Client Error"}), - 400) - response.headers["Content-Type"] = "application/json" - return response + return __build_response__({ + "status": "error", + "messages": request_errors, + "error_type": "Client Error"}) conn, _cursor_object = database_connector() corr_results = partial_correlations_entry( conn, trait_fullname(args["primary_trait"]), tuple(trait_fullname(trait) for trait in args["control_traits"]), args["method"], int(args.get("criteria", 500)), args["target_db"]) - response = make_response( - json.dumps(corr_results, cls=OutputEncoder), - 400 if "error" in corr_results.keys() else 200) - response.headers["Content-Type"] = "application/json" - return response + return __build_response__(corr_results) diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py index 85e3c11..16cbbdb 100644 --- a/gn3/computations/partial_correlations.py +++ b/gn3/computations/partial_correlations.py @@ -616,6 +616,11 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911] primary_trait = tuple( trait for trait in all_traits if trait["trait_fullname"] == primary_trait_name)[0] + if not primary_trait["haveinfo"]: + return { + "status": "not-found", + "message": f"Could not find primary trait {primary_trait['trait_fullname']}" + } group = primary_trait["db"]["group"] primary_trait_data = all_traits_data[primary_trait["trait_name"]] primary_samples, primary_values, _primary_variances = export_informative( diff --git a/gn3/db/partial_correlations.py b/gn3/db/partial_correlations.py index 157f8ee..3e77367 100644 --- a/gn3/db/partial_correlations.py +++ b/gn3/db/partial_correlations.py @@ -62,7 +62,9 @@ def publish_traits_data(conn, traits): """ Retrieve trait data for `Publish` traits. """ - dataset_ids = tuple(set(trait["db"]["dataset_id"] for trait in traits)) + dataset_ids = tuple(set( + trait["db"]["dataset_id"] for trait in traits + if trait["db"].get("dataset_id") is not None)) query = ( "SELECT " "PublishXRef.Id AS trait_name, Strain.Name AS sample_name, " @@ -83,12 +85,13 @@ def publish_traits_data(conn, traits): "ORDER BY Strain.Name").format( trait_names=", ".join(["%s"] * len(traits)), dataset_ids=", ".join(["%s"] * len(dataset_ids))) - with conn.cursor(cursorclass=DictCursor) as cursor: - cursor.execute( - query, - tuple(trait["trait_name"] for trait in traits) + - tuple(dataset_ids)) - return organise_trait_data_by_trait(cursor.fetchall()) + if len(dataset_ids) > 0: + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + tuple(trait["trait_name"] for trait in traits) + + tuple(dataset_ids)) + return organise_trait_data_by_trait(cursor.fetchall()) return {} def cellid_traits_data(conn, traits): @@ -161,15 +164,18 @@ def species_ids(conn, traits): """ Retrieve the IDS of the related species from the given list of traits. """ - groups = tuple(set(trait["db"]["group"] for trait in traits)) + groups = tuple(set( + trait["db"]["group"] for trait in traits + if trait["db"].get("group") is not None)) query = ( "SELECT Name AS `group`, SpeciesId AS species_id " "FROM InbredSet " "WHERE Name IN ({groups})").format( groups=", ".join(["%s"] * len(groups))) - with conn.cursor(cursorclass=DictCursor) as cursor: - cursor.execute(query, groups) - return tuple(row for row in cursor.fetchall()) + if len(groups) > 0: + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute(query, groups) + return tuple(row for row in cursor.fetchall()) return tuple() def geno_traits_data(conn, traits): @@ -194,12 +200,13 @@ def geno_traits_data(conn, traits): species_ids=sp_ids, trait_names=", ".join(["%s"] * len(traits)), dataset_names=", ".join(["%s"] * len(dataset_names))) - with conn.cursor(cursorclass=DictCursor) as cursor: - cursor.execute( - query, - tuple(trait["trait_name"] for trait in traits) + - tuple(dataset_names)) - return organise_trait_data_by_trait(cursor.fetchall()) + if len(sp_ids) > 0 and len(dataset_names) > 0: + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + tuple(trait["trait_name"] for trait in traits) + + tuple(dataset_names)) + return organise_trait_data_by_trait(cursor.fetchall()) return {} def traits_data( @@ -283,7 +290,9 @@ def publish_traits_info( this one fetches multiple items in a single query, unlike the original that fetches one item per query. """ - trait_dataset_ids = set(trait["db"]["dataset_id"] for trait in traits) + trait_dataset_ids = set( + trait["db"]["dataset_id"] for trait in traits + if trait["db"].get("dataset_id") is not None) columns = ( "PublishXRef.Id, Publication.PubMed_ID, " "Phenotype.Pre_publication_description, " @@ -311,13 +320,14 @@ def publish_traits_info( columns=columns, trait_names=", ".join(["%s"] * len(traits)), trait_dataset_ids=", ".join(["%s"] * len(trait_dataset_ids))) - with conn.cursor(cursorclass=DictCursor) as cursor: - cursor.execute( - query, - ( - tuple(trait["trait_name"] for trait in traits) + - tuple(trait_dataset_ids))) - return merge_traits_and_info(traits, cursor.fetchall()) + if trait_dataset_ids: + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + ( + tuple(trait["trait_name"] for trait in traits) + + tuple(trait_dataset_ids))) + return merge_traits_and_info(traits, cursor.fetchall()) return tuple({**trait, "haveinfo": False} for trait in traits) def probeset_traits_info( @@ -728,33 +738,35 @@ def set_homologene_id(conn, traits): """ Retrieve and set the 'homologene_id' values for ProbeSet traits. """ - geneids = set(trait["geneid"] for trait in traits) - groups = set(trait["db"]["group"] for trait in traits) - query = ( - "SELECT InbredSet.Name AS `group`, Homologene.GeneId AS geneid, " - "HomologeneId " - "FROM Homologene, Species, InbredSet " - "WHERE Homologene.GeneId IN ({geneids}) " - "AND InbredSet.Name IN ({groups}) " - "AND InbredSet.SpeciesId = Species.Id " - "AND Species.TaxonomyId = Homologene.TaxonomyId").format( - geneids=", ".join(["%s"] * len(geneids)), - groups=", ".join(["%s"] * len(groups))) - with conn.cursor(cursorclass=DictCursor) as cursor: - cursor.execute(query, (tuple(geneids) + tuple(groups))) - results = { - row["group"]: { - row["geneid"]: { - key: val for key, val in row.items() - if key not in ("group", "geneid") - } - } for row in cursor.fetchall() - } - return tuple( - { - **trait, **results.get( - trait["db"]["group"], {}).get(trait["geneid"], {}) - } for trait in traits) + geneids = set(trait.get("geneid") for trait in traits if trait["haveinfo"]) + groups = set( + trait["db"].get("group") for trait in traits if trait["haveinfo"]) + if len(geneids) > 1 and len(groups) > 1: + query = ( + "SELECT InbredSet.Name AS `group`, Homologene.GeneId AS geneid, " + "HomologeneId " + "FROM Homologene, Species, InbredSet " + "WHERE Homologene.GeneId IN ({geneids}) " + "AND InbredSet.Name IN ({groups}) " + "AND InbredSet.SpeciesId = Species.Id " + "AND Species.TaxonomyId = Homologene.TaxonomyId").format( + geneids=", ".join(["%s"] * len(geneids)), + groups=", ".join(["%s"] * len(groups))) + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute(query, (tuple(geneids) + tuple(groups))) + results = { + row["group"]: { + row["geneid"]: { + key: val for key, val in row.items() + if key not in ("group", "geneid") + } + } for row in cursor.fetchall() + } + return tuple( + { + **trait, **results.get( + trait["db"]["group"], {}).get(trait["geneid"], {}) + } for trait in traits) return traits def traits_datasets(conn, threshold, traits): @@ -1,6 +1,7 @@ [pytest] addopts = --strict-markers markers = + slow unit_test integration_test performance_test
\ No newline at end of file diff --git a/tests/integration/test_partial_correlations.py b/tests/integration/test_partial_correlations.py index 5b520e0..17ea539 100644 --- a/tests/integration/test_partial_correlations.py +++ b/tests/integration/test_partial_correlations.py @@ -83,8 +83,57 @@ from tests.integration.conftest import client "target_db": None })) def test_partial_correlation_api_with_missing_request_data(client, post_data): - "Test /api/correlations/partial" + """ + Test /api/correlations/partial endpoint with various expected request data + missing. + """ response = client.post("/api/correlation/partial", json=post_data) assert ( response.status_code == 400 and response.is_json and response.json.get("status") == "error") + + +@pytest.mark.integration_test +@pytest.mark.slow +@pytest.mark.parametrize( + "post_data", + ({# ProbeSet + "primary_trait": {"dataset": "a_dataset", "name": "a_name"}, + "control_traits": [ + {"dataset": "a_dataset", "name": "a_name"}, + {"dataset": "a_dataset2", "name": "a_name2"}], + "method": "a_method", + "target_db": "a_db" + }, {# Publish + "primary_trait": {"dataset": "a_Publish_dataset", "name": "a_name"}, + "control_traits": [ + {"dataset": "a_dataset", "name": "a_name"}, + {"dataset": "a_dataset2", "name": "a_name2"}], + "method": "a_method", + "target_db": "a_db" + }, {# Geno + "primary_trait": {"dataset": "a_Geno_dataset", "name": "a_name"}, + "control_traits": [ + {"dataset": "a_dataset", "name": "a_name"}, + {"dataset": "a_dataset2", "name": "a_name2"}], + "method": "a_method", + "target_db": "a_db" + }, {# Temp -- Fails due to missing table. Remove this sample if it is + # confirmed that the deletion of the database table is on purpose, and + # that Temp traits are no longer a thing + "primary_trait": {"dataset": "a_Temp_dataset", "name": "a_name"}, + "control_traits": [ + {"dataset": "a_dataset", "name": "a_name"}, + {"dataset": "a_dataset2", "name": "a_name2"}], + "method": "a_method", + "target_db": "a_db" + })) +def test_partial_correlation_api_with_non_existent_traits(client, post_data): + """ + Check that the system responds appropriately in the case where the user + makes a request with a non-existent primary trait. + """ + response = client.post("/api/correlation/partial", json=post_data) + assert ( + response.status_code == 404 and response.is_json and + response.json.get("status") != "error") |