aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gn3/computations/partial_correlations.py4
-rw-r--r--gn3/db/correlations.py24
-rw-r--r--tests/unit/db/test_correlation.py6
3 files changed, 21 insertions, 13 deletions
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 4bd26a2..f43c4d4 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -200,11 +200,13 @@ def good_dataset_samples_indexes(
samples_from_file.index(good) for good in
set(samples).intersection(set(samples_from_file))))
-def compute_partial_correlations_fast(# pylint: disable=[R0913, R0914]
+def partial_correlations_fast(# pylint: disable=[R0913, R0914]
samples, primary_vals, control_vals, database_filename,
fetched_correlations, method: str, correlation_type: str) -> Tuple[
float, Tuple[float, ...]]:
"""
+ Computes partial correlation coefficients using data from a CSV file.
+
This is a partial migration of the
`web.webqtl.correlation.PartialCorrDBPage.getPartialCorrelationsFast`
function in GeneNetwork1.
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
index 5c3e7b8..a1daa3c 100644
--- a/gn3/db/correlations.py
+++ b/gn3/db/correlations.py
@@ -398,8 +398,8 @@ def fetch_sample_ids(
"AND Species.name=%(species_name)s")
with conn.cursor() as cursor:
cursor.execute(
- query, samples_names=tuple(samples),
- species_name=species)
+ query, samples_names=tuple(sample_names),
+ species_name=species_name)
return cursor.fetchall()
def build_query_sgo_lit_corr(
@@ -419,7 +419,7 @@ def build_query_sgo_lit_corr(
f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " +
" ".join(joins) +
- f" WHERE ProbeSet.GeneId IS NOT NULL " +
+ " WHERE ProbeSet.GeneId IS NOT NULL " +
f"AND {temp_table}.value IS NOT NULL " +
f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
f"AND {db_type}Freeze.Name = %(db_name)s " +
@@ -443,7 +443,7 @@ def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins):
f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
f"LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol " +
" ".join(joins) +
- f" WHERE ProbeSet.Symbol IS NOT NULL " +
+ " WHERE ProbeSet.Symbol IS NOT NULL " +
f"AND {temp_table}.Correlation IS NOT NULL " +
f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
f"AND {db_type}Freeze.Name = %(db_name)s " +
@@ -451,17 +451,17 @@ def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins):
f"ORDER BY {db_type}.Id"),
3)
-def fetch_all_database_data(
- conn: Any, species: str, gene_id: int, gene_symbol: str,
+def fetch_all_database_data(# pylint: disable=[R0913, R0914]
+ conn: Any, species: str, gene_id: int, trait_symbol: str,
samples: Tuple[str, ...], db_type: str, db_name: str, method: str,
- returnNumber: int, tissueProbeSetFreezeId: int) -> Tuple[Any, Any]:
+ return_number: int, probeset_freeze_id: int) -> Tuple[Any, Any]:
"""
This is a migration of the
`web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
GeneNetwork1.
"""
def __build_query__(sample_ids, temp_table):
- sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in samples_ids)
+ sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
if db_type == "Publish":
joins = tuple(
("LEFT JOIN PublishData AS T{item} "
@@ -484,12 +484,12 @@ def fetch_all_database_data(
for item in sample_ids)
if method.lower() == "sgo literature correlation":
return build_query_sgo_lit_corr(
- sample_ids, temp_table, sample_id_columns)
+ sample_ids, temp_table, sample_id_columns, joins)
if method.lower() in (
"tissue correlation, pearson's r",
"tissue correlation, spearman's rho"):
return build_query_tissue_corr(
- sample_ids, temp_table, sample_id_columns)
+ sample_ids, temp_table, sample_id_columns, joins)
joins = tuple(
(f"LEFT JOIN {db_type}Data AS T{item} "
f"ON T{item}.Id = {db_type}XRef.DataId "
@@ -513,7 +513,7 @@ def fetch_all_database_data(
cursor.execute(
query, db_name=db_name,
**{f"T{item}_sample_id": item for item in sample_ids})
- return cursor.fetchall()
+ return (cursor.fetchall(), data_start_pos)
sample_ids = tuple(
# look into graduating this to an argument and removing the `samples`
@@ -543,4 +543,4 @@ def fetch_all_database_data(
with conn.cursor() as cursor:
cursor.execute(f"DROP TEMPORARY TABLE {temp_table}")
- return trait_database, data_start_pos
+ return (tuple(item[0] for item in trait_database), trait_database[0][1])
diff --git a/tests/unit/db/test_correlation.py b/tests/unit/db/test_correlation.py
index 866d28d..3f940b2 100644
--- a/tests/unit/db/test_correlation.py
+++ b/tests/unit/db/test_correlation.py
@@ -13,6 +13,9 @@ class TestCorrelation(TestCase):
maxDiff = None
def test_build_query_sgo_lit_corr(self):
+ """
+ Test that the literature correlation query is built correctly.
+ """
self.assertEqual(
build_query_sgo_lit_corr(
"Probeset",
@@ -51,6 +54,9 @@ class TestCorrelation(TestCase):
2))
def test_build_query_tissue_corr(self):
+ """
+ Test that the tissue correlation query is built correctly.
+ """
self.assertEqual(
build_query_tissue_corr(
"Probeset",