diff options
-rw-r--r-- | gn3/computations/partial_correlations.py | 430 | ||||
-rw-r--r-- | gn3/db/correlations.py | 211 | ||||
-rw-r--r-- | gn3/db/species.py | 17 | ||||
-rw-r--r-- | gn3/db/traits.py | 47 | ||||
-rw-r--r-- | tests/unit/computations/test_partial_correlations.py | 16 | ||||
-rw-r--r-- | tests/unit/db/test_correlation.py | 96 |
6 files changed, 744 insertions, 73 deletions
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py index 4f45159..231b0a7 100644 --- a/gn3/computations/partial_correlations.py +++ b/gn3/computations/partial_correlations.py @@ -6,15 +6,28 @@ GeneNetwork1. """ import math -from functools import reduce +from functools import reduce, partial from typing import Any, Tuple, Union, Sequence -from scipy.stats import pearsonr, spearmanr import pandas import pingouin +from scipy.stats import pearsonr, spearmanr from gn3.settings import TEXTDIR +from gn3.random import random_string +from gn3.function_helpers import compose from gn3.data_helpers import parse_csv_line +from gn3.db.traits import export_informative +from gn3.db.traits import retrieve_trait_info, retrieve_trait_data +from gn3.db.species import species_name, translate_to_mouse_gene_id +from gn3.db.correlations import ( + get_filename, + fetch_all_database_data, + check_for_literature_info, + fetch_tissue_correlations, + fetch_literature_correlations, + check_symbol_for_tissue_correlation, + fetch_gene_symbol_tissue_value_dict_for_trait) def control_samples(controls: Sequence[dict], sampleslist: Sequence[str]): """ @@ -112,7 +125,7 @@ def find_identical_traits( return acc + ident[1] def __dictify_controls__(acc, control_item): - ckey = "{:.3f}".format(control_item[0]) + ckey = tuple("{:.3f}".format(item) for item in control_item[0]) return {**acc, ckey: acc.get(ckey, tuple()) + (control_item[1],)} return (reduce(## for identical control traits @@ -200,33 +213,19 @@ def good_dataset_samples_indexes( samples_from_file.index(good) for good in set(samples).intersection(set(samples_from_file)))) -def determine_partials( - primary_vals, control_vals, all_target_trait_names, - all_target_trait_values, method): - """ - This **WILL** be a migration of - `web.webqtl.correlation.correlationFunction.determinePartialsByR` function - in GeneNetwork1. - - The function in GeneNetwork1 contains code written in R that is then used to - compute the partial correlations. - """ - ## This function is not implemented at this stage - return tuple( - primary_vals, control_vals, all_target_trait_names, - all_target_trait_values, method) - -def compute_partial_correlations_fast(# pylint: disable=[R0913, R0914] +def partial_correlations_fast(# pylint: disable=[R0913, R0914] samples, primary_vals, control_vals, database_filename, fetched_correlations, method: str, correlation_type: str) -> Tuple[ float, Tuple[float, ...]]: """ + Computes partial correlation coefficients using data from a CSV file. + This is a partial migration of the `web.webqtl.correlation.PartialCorrDBPage.getPartialCorrelationsFast` function in GeneNetwork1. """ assert method in ("spearman", "pearson") - with open(f"{TEXTDIR}/{database_filename}", "r") as dataset_file: + with open(database_filename, "r") as dataset_file: dataset = tuple(dataset_file.readlines()) good_dataset_samples = good_dataset_samples_indexes( @@ -300,33 +299,398 @@ def compute_partial( """ # replace the R code with `pingouin.partial_corr` def __compute_trait_info__(target): + targ_vals = target[0] + targ_name = target[1] primary = [ - prim for targ, prim in zip(target, primary_vals) + prim for targ, prim in zip(targ_vals, primary_vals) if targ is not None] + datafrm = build_data_frame( primary, - [targ for targ in target if targ is not None], - [cont for i, cont in enumerate(control_vals) - if target[i] is not None]) + tuple(targ for targ in targ_vals if targ is not None), + tuple(cont for i, cont in enumerate(control_vals) + if target[i] is not None)) covariates = "z" if datafrm.shape[1] == 3 else [ col for col in datafrm.columns if col not in ("x", "y")] ppc = pingouin.partial_corr( - data=datafrm, x="x", y="y", covar=covariates, method=method) - pc_coeff = ppc["r"] + data=datafrm, x="x", y="y", covar=covariates, method=( + "pearson" if "pearson" in method.lower() else "spearman")) + pc_coeff = ppc["r"][0] zero_order_corr = pingouin.corr( - datafrm["x"], datafrm["y"], method=method) + datafrm["x"], datafrm["y"], method=( + "pearson" if "pearson" in method.lower() else "spearman")) if math.isnan(pc_coeff): return ( - target[1], len(primary), pc_coeff, 1, zero_order_corr["r"], - zero_order_corr["p-val"]) + targ_name, len(primary), pc_coeff, 1, zero_order_corr["r"][0], + zero_order_corr["p-val"][0]) return ( - target[1], len(primary), pc_coeff, - (ppc["p-val"] if not math.isnan(ppc["p-val"]) else ( + targ_name, len(primary), pc_coeff, + (ppc["p-val"][0] if not math.isnan(ppc["p-val"][0]) else ( 0 if (abs(pc_coeff - 1) < 0.0000001) else 1)), - zero_order_corr["r"], zero_order_corr["p-val"]) + zero_order_corr["r"][0], zero_order_corr["p-val"][0]) return tuple( __compute_trait_info__(target) for target in zip(target_vals, target_names)) + +def partial_correlations_normal(# pylint: disable=R0913 + primary_vals, control_vals, input_trait_gene_id, trait_database, + data_start_pos: int, db_type: str, method: str) -> Tuple[ + float, Tuple[float, ...]]: + """ + Computes the correlation coefficients. + + This is a migration of the + `web.webqtl.correlation.PartialCorrDBPage.getPartialCorrelationsNormal` + function in GeneNetwork1. + """ + def __add_lit_and_tiss_corr__(item): + if method.lower() == "sgo literature correlation": + # if method is 'SGO Literature Correlation', `compute_partial` + # would give us LitCorr in the [1] position + return tuple(item) + trait_database[1] + if method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho"): + # if method is 'Tissue Correlation, *', `compute_partial` would give + # us Tissue Corr in the [1] position and Tissue Corr P Value in the + # [2] position + return tuple(item) + (trait_database[1], trait_database[2]) + return item + + target_trait_names, target_trait_vals = reduce( + lambda acc, item: (acc[0]+(item[0],), acc[1]+(item[data_start_pos:],)), + trait_database, (tuple(), tuple())) + + all_correlations = compute_partial( + primary_vals, control_vals, target_trait_vals, target_trait_names, + method) + + if (input_trait_gene_id and db_type == "ProbeSet" and method.lower() in ( + "sgo literature correlation", "tissue correlation, pearson's r", + "tissue correlation, spearman's rho")): + return ( + len(trait_database), + tuple( + __add_lit_and_tiss_corr__(item) + for idx, item in enumerate(all_correlations))) + + return len(trait_database), all_correlations + +def partial_corrs(# pylint: disable=[R0913] + conn, samples, primary_vals, control_vals, return_number, species, + input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id, + method, dataset, database_filename): + """ + Compute the partial correlations, selecting the fast or normal method + depending on the existence of the database text file. + + This is a partial migration of the + `web.webqtl.correlation.PartialCorrDBPage.__init__` function in + GeneNetwork1. + """ + if database_filename: + return partial_correlations_fast( + samples, primary_vals, control_vals, database_filename, + ( + fetch_literature_correlations( + species, input_trait_geneid, dataset, return_number, conn) + if "literature" in method.lower() else + fetch_tissue_correlations( + dataset, input_trait_symbol, tissue_probeset_freeze_id, + method, return_number, conn)), + method, + ("literature" if method.lower() == "sgo literature correlation" + else ("tissue" if "tissue" in method.lower() else "genetic"))) + + trait_database, data_start_pos = fetch_all_database_data( + conn, species, input_trait_geneid, input_trait_symbol, samples, dataset, + method, return_number, tissue_probeset_freeze_id) + return partial_correlations_normal( + primary_vals, control_vals, input_trait_geneid, trait_database, + data_start_pos, dataset, method) + +def literature_correlation_by_list( + conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict]: + """ + This is a migration of the + `web.webqtl.correlation.CorrelationPage.getLiteratureCorrelationByList` + function in GeneNetwork1. + """ + if any((lambda t: ( + bool(t.get("tissue_corr")) and + bool(t.get("tissue_p_value"))))(trait) + for trait in trait_list): + temporary_table_name = f"LITERATURE{random_string(8)}" + query1 = ( + f"CREATE TEMPORARY TABLE {temporary_table_name} " + "(GeneId1 INT(12) UNSIGNED, GeneId2 INT(12) UNSIGNED PRIMARY KEY, " + "value DOUBLE)") + query2 = ( + f"INSERT INTO {temporary_table_name}(GeneId1, GeneId2, value) " + "SELECT GeneId1, GeneId2, value FROM LCorrRamin3 " + "WHERE GeneId1=%(geneid)s") + query3 = ( + "INSERT INTO {temporary_table_name}(GeneId1, GeneId2, value) " + "SELECT GeneId2, GeneId1, value FROM LCorrRamin3 " + "WHERE GeneId2=%s AND GeneId1 != %(geneid)s") + + def __set_mouse_geneid__(trait): + if trait.get("geneid"): + return { + **trait, + "mouse_geneid": translate_to_mouse_gene_id( + species, trait.get("geneid"), conn) + } + return {**trait, "mouse_geneid": 0} + + def __retrieve_lcorr__(cursor, geneids): + cursor.execute( + f"SELECT GeneId2, value FROM {temporary_table_name} " + "WHERE GeneId2 IN %(geneids)s", + geneids=geneids) + return dict(cursor.fetchall()) + + with conn.cursor() as cursor: + cursor.execute(query1) + cursor.execute(query2) + cursor.execute(query3) + + traits = tuple(__set_mouse_geneid__(trait) for trait in trait_list) + lcorrs = __retrieve_lcorr__( + cursor, ( + trait["mouse_geneid"] for trait in traits + if (trait["mouse_geneid"] != 0 and + trait["mouse_geneid"].find(";") < 0))) + return tuple( + {**trait, "l_corr": lcorrs.get(trait["mouse_geneid"], None)} + for trait in traits) + + return trait_list + return trait_list + +def tissue_correlation_by_list( + conn: Any, primary_trait_symbol: str, tissue_probeset_freeze_id: int, + method: str, trait_list: Tuple[dict]) -> Tuple[dict]: + """ + This is a migration of the + `web.webqtl.correlation.CorrelationPage.getTissueCorrelationByList` + function in GeneNetwork1. + """ + def __add_tissue_corr__(trait, primary_trait_values, trait_values): + result = pingouin.corr( + primary_trait_values, trait_values, + method=("spearman" if "spearman" in method.lower() else "pearson")) + return { + **trait, + "tissue_corr": result["r"], + "tissue_p_value": result["p-val"] + } + + if any((lambda t: bool(t.get("l_corr")))(trait) for trait in trait_list): + prim_trait_symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait( + (primary_trait_symbol,), tissue_probeset_freeze_id, conn) + if primary_trait_symbol.lower() in prim_trait_symbol_value_dict: + primary_trait_value = prim_trait_symbol_value_dict[ + primary_trait_symbol.lower()] + gene_symbol_list = tuple( + trait for trait in trait_list if "symbol" in trait.keys()) + symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait( + gene_symbol_list, tissue_probeset_freeze_id, conn) + return tuple( + __add_tissue_corr__( + trait, primary_trait_value, + symbol_value_dict[trait["symbol"].lower()]) + for trait in trait_list + if ("symbol" in trait and + bool(trait["symbol"]) and + trait["symbol"].lower() in symbol_value_dict)) + return tuple({ + **trait, + "tissue_corr": None, + "tissue_p_value": None + } for trait in trait_list) + return trait_list + +def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911] + conn: Any, primary_trait_name: str, + control_trait_names: Tuple[str, ...], method: str, + criteria: int, group: str, target_db_name: str) -> dict: + """ + This is the 'ochestration' function for the partial-correlation feature. + + This function will dispatch the functions doing data fetches from the + database (and various other places) and feed that data to the functions + doing the conversions and computations. It will then return the results of + all of that work. + + This function is doing way too much. Look into splitting out the + functionality into smaller functions that do fewer things. + """ + threshold = 0 + corr_min_informative = 4 + + primary_trait = retrieve_trait_info(threshold, primary_trait_name, conn) + primary_trait_data = retrieve_trait_data(primary_trait, conn) + primary_samples, primary_values, _primary_variances = export_informative( + primary_trait_data) + + cntrl_traits = tuple( + retrieve_trait_info(threshold, trait_full_name, conn) + for trait_full_name in control_trait_names) + cntrl_traits_data = tuple( + retrieve_trait_data(cntrl_trait, conn) + for cntrl_trait in cntrl_traits) + species = species_name(conn, group) + + (cntrl_samples, + cntrl_values, + _cntrl_variances, + _cntrl_ns) = control_samples(cntrl_traits_data, primary_samples) + + common_primary_control_samples = primary_samples + fixed_primary_vals = primary_values + fixed_control_vals = cntrl_values + if not all(cnt_smp == primary_samples for cnt_smp in cntrl_samples): + (common_primary_control_samples, + fixed_primary_vals, + fixed_control_vals, + _primary_variances, + _cntrl_variances) = fix_samples(primary_trait, cntrl_traits) + + if len(common_primary_control_samples) < corr_min_informative: + return { + "status": "error", + "message": ( + f"Fewer than {corr_min_informative} samples data entered for " + f"{group} dataset. No calculation of correlation has been " + "attempted."), + "error_type": "Inadequate Samples"} + + identical_traits_names = find_identical_traits( + primary_trait_name, primary_values, control_trait_names, cntrl_values) + if len(identical_traits_names) > 0: + return { + "status": "error", + "message": ( + f"{identical_traits_names[0]} and {identical_traits_names[1]} " + "have the same values for the {len(fixed_primary_vals)} " + "samples that will be used to compute the partial correlation " + "(common for all primary and control traits). In such cases, " + "partial correlation cannot be computed. Please re-select your " + "traits."), + "error_type": "Identical Traits"} + + input_trait_geneid = primary_trait.get("geneid") + input_trait_symbol = primary_trait.get("symbol") + input_trait_mouse_geneid = translate_to_mouse_gene_id( + species, input_trait_geneid, conn) + + tissue_probeset_freeze_id = 1 + db_type = primary_trait["db"]["dataset_type"] + + if db_type == "ProbeSet" and method.lower() in ( + "sgo literature correlation", + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho"): + return { + "status": "error", + "message": ( + "Wrong correlation type: It is not possible to compute the " + f"{method} between your trait and data in the {target_db_name} " + "database. Please try again after selecting another type of " + "correlation."), + "error_type": "Correlation Type"} + + if (method.lower() == "sgo literature correlation" and ( + input_trait_geneid is None or + check_for_literature_info(conn, input_trait_mouse_geneid))): + return { + "status": "error", + "message": ( + "No Literature Information: This gene does not have any " + "associated Literature Information."), + "error_type": "Literature Correlation"} + + if ( + method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho") + and input_trait_symbol is None): + return { + "status": "error", + "message": ( + "No Tissue Correlation Information: This gene does not have " + "any associated Tissue Correlation Information."), + "error_type": "Tissue Correlation"} + + if ( + method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho") + and check_symbol_for_tissue_correlation( + conn, tissue_probeset_freeze_id, input_trait_symbol)): + return { + "status": "error", + "message": ( + "No Tissue Correlation Information: This gene does not have " + "any associated Tissue Correlation Information."), + "error_type": "Tissue Correlation"} + + database_filename = get_filename(conn, target_db_name, TEXTDIR) + _total_traits, all_correlations = partial_corrs( + conn, common_primary_control_samples, fixed_primary_vals, + fixed_control_vals, len(fixed_primary_vals), species, + input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id, + method, primary_trait["db"], database_filename) + + + def __make_sorter__(method): + def __sort_6__(row): + return row[6] + + def __sort_3__(row): + return row[3] + + if "literature" in method.lower(): + return __sort_6__ + + if "tissue" in method.lower(): + return __sort_6__ + + return __sort_3__ + + sorted_correlations = sorted( + all_correlations, key=__make_sorter__(method)) + + add_lit_corr_and_tiss_corr = compose( + partial(literature_correlation_by_list, conn, species), + partial( + tissue_correlation_by_list, conn, input_trait_symbol, + tissue_probeset_freeze_id, method)) + + trait_list = add_lit_corr_and_tiss_corr(tuple( + { + **retrieve_trait_info( + threshold, + f"{primary_trait['db']['dataset_name']}::{item[0]}", + conn), + "noverlap": item[1], + "partial_corr": item[2], + "partial_corr_p_value": item[3], + "corr": item[4], + "corr_p_value": item[5], + "rank_order": (1 if "spearman" in method.lower() else 0), + **({ + "tissue_corr": item[6], + "tissue_p_value": item[7]} + if len(item) == 8 else {}), + **({"l_corr": item[6]} + if len(item) == 7 else {}) + } + for item in + sorted_correlations[:min(criteria, len(all_correlations))])) + + return trait_list diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py index 06b3310..3d12019 100644 --- a/gn3/db/correlations.py +++ b/gn3/db/correlations.py @@ -2,17 +2,16 @@ This module will hold functions that are used in the (partial) correlations feature to access the database to retrieve data needed for computations. """ - +import os from functools import reduce -from typing import Any, Dict, Tuple +from typing import Any, Dict, Tuple, Union from gn3.random import random_string from gn3.data_helpers import partition_all from gn3.db.species import translate_to_mouse_gene_id -from gn3.computations.partial_correlations import correlations_of_all_tissue_traits - -def get_filename(target_db_name: str, conn: Any) -> str: +def get_filename(conn: Any, target_db_name: str, text_files_dir: str) -> Union[ + str, bool]: """ Retrieve the name of the reference database file with which correlations are computed. @@ -23,18 +22,20 @@ def get_filename(target_db_name: str, conn: Any) -> str: """ with conn.cursor() as cursor: cursor.execute( - "SELECT Id, FullName from ProbeSetFreeze WHERE Name-%s", - target_db_name) + "SELECT Id, FullName from ProbeSetFreeze WHERE Name=%s", + (target_db_name,)) result = cursor.fetchone() if result: - return "ProbeSetFreezeId_{tid}_FullName_{fname}.txt".format( + filename = "ProbeSetFreezeId_{tid}_FullName_{fname}.txt".format( tid=result[0], fname=result[1].replace(' ', '_').replace('/', '_')) + return ((filename in os.listdir(text_files_dir)) + and f"{text_files_dir}/{filename}") - return "" + return False def build_temporary_literature_table( - species: str, gene_id: int, return_number: int, conn: Any) -> str: + conn: Any, species: str, gene_id: int, return_number: int) -> str: """ Build and populate a temporary table to hold the literature correlation data to be used in computations. @@ -128,7 +129,7 @@ def fetch_literature_correlations( GeneNetwork1. """ temp_table = build_temporary_literature_table( - species, gene_id, return_number, conn) + conn, species, gene_id, return_number) query_fns = { "Geno": fetch_geno_literature_correlations, # "Temp": fetch_temp_literature_correlations, @@ -268,8 +269,8 @@ def fetch_gene_symbol_tissue_value_dict_for_trait( return {} def build_temporary_tissue_correlations_table( - trait_symbol: str, probeset_freeze_id: int, method: str, - return_number: int, conn: Any) -> str: + conn: Any, trait_symbol: str, probeset_freeze_id: int, method: str, + return_number: int) -> str: """ Build a temporary table to hold the tissue correlations data. @@ -279,6 +280,16 @@ def build_temporary_tissue_correlations_table( # We should probably pass the `correlations_of_all_tissue_traits` function # as an argument to this function and get rid of the one call immediately # following this comment. + from gn3.computations.partial_correlations import (#pylint: disable=[C0415, R0401] + correlations_of_all_tissue_traits) + # This import above is necessary within the function to avoid + # circular-imports. + # + # + # This import above is indicative of convoluted code, with the computation + # being interwoven with the data retrieval. This needs to be changed, such + # that the function being imported here is no longer necessary, or have the + # imported function passed to this function as an argument. symbol_corr_dict, symbol_p_value_dict = correlations_of_all_tissue_traits( fetch_gene_symbol_tissue_value_dict_for_trait( (trait_symbol,), probeset_freeze_id, conn), @@ -320,7 +331,7 @@ def fetch_tissue_correlations(# pylint: disable=R0913 GeneNetwork1. """ temp_table = build_temporary_tissue_correlations_table( - trait_symbol, probeset_freeze_id, method, return_number, conn) + conn, trait_symbol, probeset_freeze_id, method, return_number) with conn.cursor() as cursor: cursor.execute( ( @@ -379,3 +390,175 @@ def check_symbol_for_tissue_correlation( return True return False + +def fetch_sample_ids( + conn: Any, sample_names: Tuple[str, ...], species_name: str) -> Tuple[ + int, ...]: + """ + Given a sequence of sample names, and a species name, return the sample ids + that correspond to both. + + This is a partial migration of the + `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in + GeneNetwork1. + """ + query = ( + "SELECT Strain.Id FROM Strain, Species " + "WHERE Strain.Name IN %(samples_names)s " + "AND Strain.SpeciesId=Species.Id " + "AND Species.name=%(species_name)s") + with conn.cursor() as cursor: + cursor.execute( + query, + { + "samples_names": tuple(sample_names), + "species_name": species_name + }) + return tuple(row[0] for row in cursor.fetchall()) + +def build_query_sgo_lit_corr( + db_type: str, temp_table: str, sample_id_columns: str, + joins: Tuple[str, ...]) -> str: + """ + Build query for `SGO Literature Correlation` data, when querying the given + `temp_table` temporary table. + + This is a partial migration of the + `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in + GeneNetwork1. + """ + return ( + (f"SELECT {db_type}.Name, {temp_table}.value, " + + sample_id_columns + + f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " + + f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " + + " ".join(joins) + + " WHERE ProbeSet.GeneId IS NOT NULL " + + f"AND {temp_table}.value IS NOT NULL " + + f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " + + f"AND {db_type}Freeze.Name = %(db_name)s " + + f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " + + f"ORDER BY {db_type}.Id"), + 2) + +def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins): + """ + Build query for `Tissue Correlation` data, when querying the given + `temp_table` temporary table. + + This is a partial migration of the + `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in + GeneNetwork1. + """ + return ( + (f"SELECT {db_type}.Name, {temp_table}.Correlation, " + + f"{temp_table}.PValue, " + + sample_id_columns + + f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " + + f"LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol " + + " ".join(joins) + + " WHERE ProbeSet.Symbol IS NOT NULL " + + f"AND {temp_table}.Correlation IS NOT NULL " + + f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " + + f"AND {db_type}Freeze.Name = %(db_name)s " + + f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " + f"ORDER BY {db_type}.Id"), + 3) + +def fetch_all_database_data(# pylint: disable=[R0913, R0914] + conn: Any, species: str, gene_id: int, trait_symbol: str, + samples: Tuple[str, ...], dataset: dict, method: str, + return_number: int, probeset_freeze_id: int) -> Tuple[ + Tuple[float], int]: + """ + This is a migration of the + `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in + GeneNetwork1. + """ + db_type = dataset["dataset_type"] + db_name = dataset["dataset_name"] + def __build_query__(sample_ids, temp_table): + sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids) + if db_type == "Publish": + joins = tuple( + ("LEFT JOIN PublishData AS T{item} " + "ON T{item}.Id = PublishXRef.DataId " + "AND T{item}.StrainId = %(T{item}_sample_id)s") + for item in sample_ids) + return ( + ("SELECT PublishXRef.Id, " + + sample_id_columns + + "FROM (PublishXRef, PublishFreeze) " + + " ".join(joins) + + " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId " + "AND PublishFreeze.Name = %(db_name)s"), + 1) + if temp_table is not None: + joins = tuple( + (f"LEFT JOIN {db_type}Data AS T{item} " + f"ON T{item}.Id = {db_type}XRef.DataId " + f"AND T{item}.StrainId=%(T{item}_sample_id)s") + for item in sample_ids) + if method.lower() == "sgo literature correlation": + return build_query_sgo_lit_corr( + sample_ids, temp_table, sample_id_columns, joins) + if method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho"): + return build_query_tissue_corr( + sample_ids, temp_table, sample_id_columns, joins) + joins = tuple( + (f"LEFT JOIN {db_type}Data AS T{item} " + f"ON T{item}.Id = {db_type}XRef.DataId " + f"AND T{item}.StrainId = %(T{item}_sample_id)s") + for item in sample_ids) + return ( + ( + f"SELECT {db_type}.Name, " + + sample_id_columns + + f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " + + " ".join(joins) + + f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " + + f"AND {db_type}Freeze.Name = %(db_name)s " + + f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " + + f"ORDER BY {db_type}.Id"), + 1) + + def __fetch_data__(sample_ids, temp_table): + query, data_start_pos = __build_query__(sample_ids, temp_table) + with conn.cursor() as cursor: + cursor.execute( + query, + {"db_name": db_name, + **{f"T{item}_sample_id": item for item in sample_ids}}) + return (cursor.fetchall(), data_start_pos) + + sample_ids = tuple( + # look into graduating this to an argument and removing the `samples` + # and `species` argument: function currying and compositions might help + # with this + f"{sample_id}" for sample_id in + fetch_sample_ids(conn, samples, species)) + + temp_table = None + if gene_id and db_type == "probeset": + if method.lower() == "sgo literature correlation": + temp_table = build_temporary_literature_table( + conn, species, gene_id, return_number) + if method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho"): + temp_table = build_temporary_tissue_correlations_table( + conn, trait_symbol, probeset_freeze_id, method, return_number) + + trait_database = tuple( + item for sublist in + (__fetch_data__(ssample_ids, temp_table) + for ssample_ids in partition_all(25, sample_ids)) + for item in sublist) + + if temp_table: + with conn.cursor() as cursor: + cursor.execute(f"DROP TEMPORARY TABLE {temp_table}") + + return (trait_database[0], trait_database[1]) diff --git a/gn3/db/species.py b/gn3/db/species.py index 702a9a8..5b8e096 100644 --- a/gn3/db/species.py +++ b/gn3/db/species.py @@ -57,3 +57,20 @@ def translate_to_mouse_gene_id(species: str, geneid: int, conn: Any) -> int: return translated_gene_id[0] return 0 # default if all else fails + +def species_name(conn: Any, group: str) -> str: + """ + Retrieve the name of the species, given the group (RISet). + + This is a migration of the + `web.webqtl.dbFunction.webqtlDatabaseFunction.retrieveSpecies` function in + GeneNetwork1. + """ + with conn.cursor() as cursor: + cursor.execute( + ("SELECT Species.Name FROM Species, InbredSet " + "WHERE InbredSet.Name = %(group_name)s " + "AND InbredSet.SpeciesId = Species.Id"), + {"group_name": group}) + return cursor.fetchone()[0] + return None diff --git a/gn3/db/traits.py b/gn3/db/traits.py index 75de4f4..4098b08 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -4,6 +4,8 @@ import MySQLdb from functools import reduce from typing import Any, Dict, Union, Sequence +import MySQLdb + from gn3.settings import TMPDIR from gn3.random import random_string from gn3.function_helpers import compose @@ -81,10 +83,10 @@ def export_trait_data( def get_trait_csv_sample_data(conn: Any, trait_name: int, phenotype_id: int): """Fetch a trait and return it as a csv string""" - def __float_strip(n): - if str(n)[-2:] == ".0": - return str(int(n)) - return str(n) + def __float_strip(num_str): + if str(num_str)[-2:] == ".0": + return str(int(num_str)) + return str(num_str) sql = ("SELECT DISTINCT Strain.Name, PublishData.value, " "PublishSE.error, NStrain.count FROM " "(PublishData, Strain, PublishXRef, PublishFreeze) " @@ -108,7 +110,8 @@ def get_trait_csv_sample_data(conn: Any, return "\n".join(csv_data) -def update_sample_data(conn: Any, +def update_sample_data(conn: Any, #pylint: disable=[R0913] + trait_name: str, strain_name: str, phenotype_id: int, @@ -219,7 +222,7 @@ def delete_sample_data(conn: Any, "WHERE StrainId = %s AND DataId = %s" % (strain_id, data_id))) deleted_n_strains = cursor.rowcount - except Exception as e: + except Exception as e: #pylint: disable=[C0103, W0612] conn.rollback() raise MySQLdb.Error conn.commit() @@ -230,7 +233,7 @@ def delete_sample_data(conn: Any, deleted_se_data, deleted_n_strains) -def insert_sample_data(conn: Any, +def insert_sample_data(conn: Any, #pylint: disable=[R0913] trait_name: str, strain_name: str, phenotype_id: int, @@ -272,7 +275,7 @@ def insert_sample_data(conn: Any, "VALUES (%s, %s, %s)") % (strain_id, data_id, count)) inserted_n_strains = cursor.rowcount - except Exception as e: + except Exception as e: #pylint: disable=[C0103, W0612] conn.rollback() raise MySQLdb.Error return (inserted_published_data, @@ -450,7 +453,7 @@ def set_homologene_id_field(trait_type, trait_info, conn): Common postprocessing function for all trait types. Sets the value for the 'homologene' key.""" - def set_to_null(ti): return {**ti, "homologeneid": None} + def set_to_null(ti): return {**ti, "homologeneid": None} # pylint: disable=[C0103, C0321] functions_table = { "Temp": set_to_null, "Geno": set_to_null, @@ -656,8 +659,9 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any): query, {"trait_name": trait_info["trait_name"]}) return [dict(zip( - ["sample_name", "value", "se_error", "nstrain", "id"], row)) - for row in cursor.fetchall()] + ["sample_name", "value", "se_error", "nstrain", "id"], + row)) + for row in cursor.fetchall()] return [] @@ -696,8 +700,10 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any): "dataset_name": trait_info["db"]["dataset_name"], "species_id": retrieve_species_id( trait_info["db"]["group"], conn)}) - return [dict(zip( - ["sample_name", "value", "se_error", "id"], row)) + return [ + dict(zip( + ["sample_name", "value", "se_error", "id"], + row)) for row in cursor.fetchall()] return [] @@ -728,8 +734,9 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any): query, {"trait_name": trait_info["trait_name"], "dataset_id": trait_info["db"]["dataset_id"]}) - return [dict(zip( - ["sample_name", "value", "se_error", "nstrain", "id"], row)) + return [ + dict(zip( + ["sample_name", "value", "se_error", "nstrain", "id"], row)) for row in cursor.fetchall()] return [] @@ -762,8 +769,9 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any): {"cellid": trait_info["cellid"], "trait_name": trait_info["trait_name"], "dataset_id": trait_info["db"]["dataset_id"]}) - return [dict(zip( - ["sample_name", "value", "se_error", "id"], row)) + return [ + dict(zip( + ["sample_name", "value", "se_error", "id"], row)) for row in cursor.fetchall()] return [] @@ -792,8 +800,9 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any): query, {"trait_name": trait_info["trait_name"], "dataset_name": trait_info["db"]["dataset_name"]}) - return [dict(zip( - ["sample_name", "value", "se_error", "id"], row)) + return [ + dict(zip( + ["sample_name", "value", "se_error", "id"], row)) for row in cursor.fetchall()] return [] diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py index 3e1b6e1..3690ca4 100644 --- a/tests/unit/computations/test_partial_correlations.py +++ b/tests/unit/computations/test_partial_correlations.py @@ -193,7 +193,7 @@ class TestPartialCorrelations(TestCase): Given: - the name of a primary trait - - the value of a primary trait + - a sequence of values for the primary trait - a sequence of names of control traits - a sequence of values of control traits When: @@ -204,12 +204,14 @@ class TestPartialCorrelations(TestCase): decimal places are considered """ for primn, primv, contn, contv, expected in ( - ("pt", 12.98395, ("ct0", "ct1", "ct2"), - (0.1234, 2.3456, 3.4567), tuple()), - ("pt", 12.98395, ("ct0", "ct1", "ct2"), - (12.98354, 2.3456, 3.4567), ("pt", "ct0")), - ("pt", 12.98395, ("ct0", "ct1", "ct2", "ct3"), - (0.1234, 2.3456, 0.1233, 4.5678), ("ct0", "ct2")) + ("pt", (12.98395,), ("ct0", "ct1", "ct2"), + ((0.1234, 2.3456, 3.4567),), tuple()), + ("pt", (12.98395, 2.3456, 3.4567), ("ct0", "ct1", "ct2"), + ((12.98354, 2.3456, 3.4567), (64.2334, 6.3256, 64.2364), + (4.2374, 67.2345, 7.48234)), ("pt", "ct0")), + ("pt", (12.98395, 75.52382), ("ct0", "ct1", "ct2", "ct3"), + ((0.1234, 2.3456), (0.3621, 6543.572), (0.1234, 2.3456), + (0.1233, 4.5678)), ("ct0", "ct2")) ): with self.subTest( primary_name=primn, primary_value=primv, diff --git a/tests/unit/db/test_correlation.py b/tests/unit/db/test_correlation.py new file mode 100644 index 0000000..3f940b2 --- /dev/null +++ b/tests/unit/db/test_correlation.py @@ -0,0 +1,96 @@ +""" +Tests for the gn3.db.correlations module +""" + +from unittest import TestCase + +from gn3.db.correlations import ( + build_query_sgo_lit_corr, + build_query_tissue_corr) + +class TestCorrelation(TestCase): + """Test cases for correlation data fetching functions""" + maxDiff = None + + def test_build_query_sgo_lit_corr(self): + """ + Test that the literature correlation query is built correctly. + """ + self.assertEqual( + build_query_sgo_lit_corr( + "Probeset", + "temp_table_xy45i7wd", + "T1.value, T2.value, T3.value", + (("LEFT JOIN ProbesetData AS T1 " + "ON T1.Id = ProbesetXRef.DataId " + "AND T1.StrainId=%(T1_sample_id)s"), + ( + "LEFT JOIN ProbesetData AS T2 " + "ON T2.Id = ProbesetXRef.DataId " + "AND T2.StrainId=%(T2_sample_id)s"), + ( + "LEFT JOIN ProbesetData AS T3 " + "ON T3.Id = ProbesetXRef.DataId " + "AND T3.StrainId=%(T3_sample_id)s"))), + (("SELECT Probeset.Name, temp_table_xy45i7wd.value, " + "T1.value, T2.value, T3.value " + "FROM (Probeset, ProbesetXRef, ProbesetFreeze) " + "LEFT JOIN temp_table_xy45i7wd ON temp_table_xy45i7wd.GeneId2=ProbeSet.GeneId " + "LEFT JOIN ProbesetData AS T1 " + "ON T1.Id = ProbesetXRef.DataId " + "AND T1.StrainId=%(T1_sample_id)s " + "LEFT JOIN ProbesetData AS T2 " + "ON T2.Id = ProbesetXRef.DataId " + "AND T2.StrainId=%(T2_sample_id)s " + "LEFT JOIN ProbesetData AS T3 " + "ON T3.Id = ProbesetXRef.DataId " + "AND T3.StrainId=%(T3_sample_id)s " + "WHERE ProbeSet.GeneId IS NOT NULL " + "AND temp_table_xy45i7wd.value IS NOT NULL " + "AND ProbesetXRef.ProbesetFreezeId = ProbesetFreeze.Id " + "AND ProbesetFreeze.Name = %(db_name)s " + "AND Probeset.Id = ProbesetXRef.ProbesetId " + "ORDER BY Probeset.Id"), + 2)) + + def test_build_query_tissue_corr(self): + """ + Test that the tissue correlation query is built correctly. + """ + self.assertEqual( + build_query_tissue_corr( + "Probeset", + "temp_table_xy45i7wd", + "T1.value, T2.value, T3.value", + (("LEFT JOIN ProbesetData AS T1 " + "ON T1.Id = ProbesetXRef.DataId " + "AND T1.StrainId=%(T1_sample_id)s"), + ( + "LEFT JOIN ProbesetData AS T2 " + "ON T2.Id = ProbesetXRef.DataId " + "AND T2.StrainId=%(T2_sample_id)s"), + ( + "LEFT JOIN ProbesetData AS T3 " + "ON T3.Id = ProbesetXRef.DataId " + "AND T3.StrainId=%(T3_sample_id)s"))), + (("SELECT Probeset.Name, temp_table_xy45i7wd.Correlation, " + "temp_table_xy45i7wd.PValue, " + "T1.value, T2.value, T3.value " + "FROM (Probeset, ProbesetXRef, ProbesetFreeze) " + "LEFT JOIN temp_table_xy45i7wd ON temp_table_xy45i7wd.Symbol=ProbeSet.Symbol " + "LEFT JOIN ProbesetData AS T1 " + "ON T1.Id = ProbesetXRef.DataId " + "AND T1.StrainId=%(T1_sample_id)s " + "LEFT JOIN ProbesetData AS T2 " + "ON T2.Id = ProbesetXRef.DataId " + "AND T2.StrainId=%(T2_sample_id)s " + "LEFT JOIN ProbesetData AS T3 " + "ON T3.Id = ProbesetXRef.DataId " + "AND T3.StrainId=%(T3_sample_id)s " + "WHERE ProbeSet.Symbol IS NOT NULL " + "AND temp_table_xy45i7wd.Correlation IS NOT NULL " + "AND ProbesetXRef.ProbesetFreezeId = ProbesetFreeze.Id " + "AND ProbesetFreeze.Name = %(db_name)s " + "AND Probeset.Id = ProbesetXRef.ProbesetId " + "ORDER BY Probeset.Id"), + 3)) |