Diffstat (limited to 'gn3')
-rw-r--r--  gn3/computations/partial_correlations.py |  4
-rw-r--r--  gn3/db/correlations.py                    | 24
2 files changed, 15 insertions, 13 deletions
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 4bd26a2..f43c4d4 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -200,11 +200,13 @@ def good_dataset_samples_indexes(
         samples_from_file.index(good) for good in
         set(samples).intersection(set(samples_from_file))))
 
-def compute_partial_correlations_fast(# pylint: disable=[R0913, R0914]
+def partial_correlations_fast(# pylint: disable=[R0913, R0914]
         samples, primary_vals, control_vals, database_filename,
         fetched_correlations, method: str, correlation_type: str) -> Tuple[
             float, Tuple[float, ...]]:
     """
+    Computes partial correlation coefficients using data from a CSV file.
+
     This is a partial migration of the
     `web.webqtl.correlation.PartialCorrDBPage.getPartialCorrelationsFast`
     function in GeneNetwork1.
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
index 5c3e7b8..a1daa3c 100644
--- a/gn3/db/correlations.py
+++ b/gn3/db/correlations.py
@@ -398,8 +398,8 @@ def fetch_sample_ids(
         "AND Species.name=%(species_name)s")
     with conn.cursor() as cursor:
         cursor.execute(
-            query, samples_names=tuple(samples),
-            species_name=species)
+            query, samples_names=tuple(sample_names),
+            species_name=species_name)
         return cursor.fetchall()
 
 def build_query_sgo_lit_corr(
@@ -419,7 +419,7 @@ def build_query_sgo_lit_corr(
         f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
         f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " +
         " ".join(joins) +
-        f" WHERE ProbeSet.GeneId IS NOT NULL " +
+        " WHERE ProbeSet.GeneId IS NOT NULL " +
         f"AND {temp_table}.value IS NOT NULL " +
         f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
         f"AND {db_type}Freeze.Name = %(db_name)s " +
@@ -443,7 +443,7 @@ def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins):
         f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
         f"LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol " +
         " ".join(joins) +
-        f" WHERE ProbeSet.Symbol IS NOT NULL " +
+        " WHERE ProbeSet.Symbol IS NOT NULL " +
         f"AND {temp_table}.Correlation IS NOT NULL " +
         f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
         f"AND {db_type}Freeze.Name = %(db_name)s " +
@@ -451,17 +451,17 @@ def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins):
         f"ORDER BY {db_type}.Id"),
         3)
 
-def fetch_all_database_data(
-        conn: Any, species: str, gene_id: int, gene_symbol: str,
+def fetch_all_database_data(# pylint: disable=[R0913, R0914]
+        conn: Any, species: str, gene_id: int, trait_symbol: str,
         samples: Tuple[str, ...], db_type: str, db_name: str, method: str,
-        returnNumber: int, tissueProbeSetFreezeId: int) -> Tuple[Any, Any]:
+        return_number: int, probeset_freeze_id: int) -> Tuple[Any, Any]:
     """
     This is a migration of the
     `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
     GeneNetwork1.
     """
     def __build_query__(sample_ids, temp_table):
-        sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in samples_ids)
+        sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
         if db_type == "Publish":
             joins = tuple(
                 ("LEFT JOIN PublishData AS T{item} "
@@ -484,12 +484,12 @@ def fetch_all_database_data(
                 for item in sample_ids)
         if method.lower() == "sgo literature correlation":
             return build_query_sgo_lit_corr(
-                sample_ids, temp_table, sample_id_columns)
+                sample_ids, temp_table, sample_id_columns, joins)
         if method.lower() in (
                 "tissue correlation, pearson's r",
                 "tissue correlation, spearman's rho"):
             return build_query_tissue_corr(
-                sample_ids, temp_table, sample_id_columns)
+                sample_ids, temp_table, sample_id_columns, joins)
         joins = tuple(
             (f"LEFT JOIN {db_type}Data AS T{item} "
              f"ON T{item}.Id = {db_type}XRef.DataId "
@@ -513,7 +513,7 @@ def fetch_all_database_data(
             cursor.execute(
                 query, db_name=db_name,
                 **{f"T{item}_sample_id": item for item in sample_ids})
-            return cursor.fetchall()
+            return (cursor.fetchall(), data_start_pos)
 
     sample_ids = tuple(
         # look into graduating this to an argument and removing the `samples`
@@ -543,4 +543,4 @@ def fetch_all_database_data(
     with conn.cursor() as cursor:
         cursor.execute(f"DROP TEMPORARY TABLE {temp_table}")
 
-    return trait_database, data_start_pos
+    return (tuple(item[0] for item in trait_database), trait_database[0][1])