diff options
Diffstat (limited to 'gn3')
-rw-r--r-- | gn3/computations/partial_correlations.py | 131 | ||||
-rw-r--r-- | gn3/db/correlations.py | 5 | ||||
-rw-r--r-- | gn3/db/traits.py | 47 |
3 files changed, 100 insertions, 83 deletions
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py index 869bee4..231b0a7 100644 --- a/gn3/computations/partial_correlations.py +++ b/gn3/computations/partial_correlations.py @@ -14,12 +14,20 @@ import pingouin from scipy.stats import pearsonr, spearmanr from gn3.settings import TEXTDIR +from gn3.random import random_string from gn3.function_helpers import compose from gn3.data_helpers import parse_csv_line from gn3.db.traits import export_informative from gn3.db.traits import retrieve_trait_info, retrieve_trait_data from gn3.db.species import species_name, translate_to_mouse_gene_id -from gn3.db.correlations import get_filename, fetch_all_database_data +from gn3.db.correlations import ( + get_filename, + fetch_all_database_data, + check_for_literature_info, + fetch_tissue_correlations, + fetch_literature_correlations, + check_symbol_for_tissue_correlation, + fetch_gene_symbol_tissue_value_dict_for_trait) def control_samples(controls: Sequence[dict], sampleslist: Sequence[str]): """ @@ -311,7 +319,7 @@ def compute_partial( zero_order_corr = pingouin.corr( datafrm["x"], datafrm["y"], method=( - "pearson" if "pearson" in method.lower() else "spearman")) + "pearson" if "pearson" in method.lower() else "spearman")) if math.isnan(pc_coeff): return ( @@ -371,9 +379,10 @@ def partial_correlations_normal(# pylint: disable=R0913 return len(trait_database), all_correlations -def partial_corrs( - conn, samples , primary_vals, control_vals, return_number, species, input_trait_geneid, - input_trait_symbol, tissue_probeset_freeze_id, method, dataset, database_filename): +def partial_corrs(# pylint: disable=[R0913] + conn, samples, primary_vals, control_vals, return_number, species, + input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id, + method, dataset, database_filename): """ Compute the partial correlations, selecting the fast or normal method depending on the existence of the database text file. @@ -404,8 +413,7 @@ def partial_corrs( data_start_pos, dataset, method) def literature_correlation_by_list( - conn: Any, input_trait_mouse_geneid: int, species: str, - trait_list: Tuple[dict]) -> Tuple[dict]: + conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict]: """ This is a migration of the `web.webqtl.correlation.CorrelationPage.getLiteratureCorrelationByList` @@ -415,16 +423,16 @@ def literature_correlation_by_list( bool(t.get("tissue_corr")) and bool(t.get("tissue_p_value"))))(trait) for trait in trait_list): - temp_table_name = f"LITERATURE{random_string(8)}" - q1 = ( + temporary_table_name = f"LITERATURE{random_string(8)}" + query1 = ( f"CREATE TEMPORARY TABLE {temporary_table_name} " "(GeneId1 INT(12) UNSIGNED, GeneId2 INT(12) UNSIGNED PRIMARY KEY, " "value DOUBLE)") - q2 = ( + query2 = ( f"INSERT INTO {temporary_table_name}(GeneId1, GeneId2, value) " "SELECT GeneId1, GeneId2, value FROM LCorrRamin3 " "WHERE GeneId1=%(geneid)s") - q3 = ( + query3 = ( "INSERT INTO {temporary_table_name}(GeneId1, GeneId2, value) " "SELECT GeneId2, GeneId1, value FROM LCorrRamin3 " "WHERE GeneId2=%s AND GeneId1 != %(geneid)s") @@ -433,7 +441,8 @@ def literature_correlation_by_list( if trait.get("geneid"): return { **trait, - "mouse_geneid": translate_to_mouse_gene_id(trait.get("geneid")) + "mouse_geneid": translate_to_mouse_gene_id( + species, trait.get("geneid"), conn) } return {**trait, "mouse_geneid": 0} @@ -441,13 +450,13 @@ def literature_correlation_by_list( cursor.execute( f"SELECT GeneId2, value FROM {temporary_table_name} " "WHERE GeneId2 IN %(geneids)s", - geneids = geneids) - return {geneid: value for geneid, value in cursor.fetchall()} + geneids=geneids) + return dict(cursor.fetchall()) with conn.cursor() as cursor: - cursor.execute(q1) - cursor.execute(q2) - cursor.execute(q3) + cursor.execute(query1) + cursor.execute(query2) + cursor.execute(query3) traits = tuple(__set_mouse_geneid__(trait) for trait in trait_list) lcorrs = __retrieve_lcorr__( @@ -470,9 +479,9 @@ def tissue_correlation_by_list( `web.webqtl.correlation.CorrelationPage.getTissueCorrelationByList` function in GeneNetwork1. """ - def __add_tissue_corr__(trait, primary_trait_value, trait_value): + def __add_tissue_corr__(trait, primary_trait_values, trait_values): result = pingouin.corr( - primary_trait_values, target_trait_values, + primary_trait_values, trait_values, method=("spearman" if "spearman" in method.lower() else "pearson")) return { **trait, @@ -484,7 +493,8 @@ def tissue_correlation_by_list( prim_trait_symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait( (primary_trait_symbol,), tissue_probeset_freeze_id, conn) if primary_trait_symbol.lower() in prim_trait_symbol_value_dict: - primary_trait_value = prim_trait_symbol_value_dict[prim_trait_symbol.lower()] + primary_trait_value = prim_trait_symbol_value_dict[ + primary_trait_symbol.lower()] gene_symbol_list = tuple( trait for trait in trait_list if "symbol" in trait.keys()) symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait( @@ -504,7 +514,7 @@ def tissue_correlation_by_list( } for trait in trait_list) return trait_list -def partial_correlations_entry( +def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911] conn: Any, primary_trait_name: str, control_trait_names: Tuple[str, ...], method: str, criteria: int, group: str, target_db_name: str) -> dict: @@ -524,7 +534,7 @@ def partial_correlations_entry( primary_trait = retrieve_trait_info(threshold, primary_trait_name, conn) primary_trait_data = retrieve_trait_data(primary_trait, conn) - primary_samples, primary_values, primary_variances = export_informative( + primary_samples, primary_values, _primary_variances = export_informative( primary_trait_data) cntrl_traits = tuple( @@ -537,8 +547,8 @@ def partial_correlations_entry( (cntrl_samples, cntrl_values, - cntrl_variances, - cntrl_ns) = control_samples(cntrl_traits_data, primary_samples) + _cntrl_variances, + _cntrl_ns) = control_samples(cntrl_traits_data, primary_samples) common_primary_control_samples = primary_samples fixed_primary_vals = primary_values @@ -547,8 +557,8 @@ def partial_correlations_entry( (common_primary_control_samples, fixed_primary_vals, fixed_control_vals, - primary_variances, - cntrl_variances) = fix_samples(primary_trait, cntrl_traits) + _primary_variances, + _cntrl_variances) = fix_samples(primary_trait, cntrl_traits) if len(common_primary_control_samples) < corr_min_informative: return { @@ -580,7 +590,6 @@ def partial_correlations_entry( tissue_probeset_freeze_id = 1 db_type = primary_trait["db"]["dataset_type"] - db_name = primary_trait["db"]["dataset_name"] if db_type == "ProbeSet" and method.lower() in ( "sgo literature correlation", @@ -605,10 +614,11 @@ def partial_correlations_entry( "associated Literature Information."), "error_type": "Literature Correlation"} - if (method.lower() in ( - "tissue correlation, pearson's r", - "tissue correlation, spearman's rho") - and input_trait_symbol is None): + if ( + method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho") + and input_trait_symbol is None): return { "status": "error", "message": ( @@ -616,11 +626,12 @@ def partial_correlations_entry( "any associated Tissue Correlation Information."), "error_type": "Tissue Correlation"} - if (method.lower() in ( - "tissue correlation, pearson's r", - "tissue correlation, spearman's rho") - and check_symbol_for_tissue_correlation( - conn, tissue_probeset_freeze_id, input_trait_symbol)): + if ( + method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho") + and check_symbol_for_tissue_correlation( + conn, tissue_probeset_freeze_id, input_trait_symbol)): return { "status": "error", "message": ( @@ -629,7 +640,7 @@ def partial_correlations_entry( "error_type": "Tissue Correlation"} database_filename = get_filename(conn, target_db_name, TEXTDIR) - total_traits, all_correlations = partial_corrs( + _total_traits, all_correlations = partial_corrs( conn, common_primary_control_samples, fixed_primary_vals, fixed_control_vals, len(fixed_primary_vals), species, input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id, @@ -637,11 +648,11 @@ def partial_correlations_entry( def __make_sorter__(method): - def __sort_6__(x): - return x[6] + def __sort_6__(row): + return row[6] - def __sort_3__(x): - return x[3] + def __sort_3__(row): + return row[3] if "literature" in method.lower(): return __sort_6__ @@ -655,33 +666,31 @@ def partial_correlations_entry( all_correlations, key=__make_sorter__(method)) add_lit_corr_and_tiss_corr = compose( - partial( - literature_correlation_by_list, conn, input_trait_mouse_geneid, - species), + partial(literature_correlation_by_list, conn, species), partial( tissue_correlation_by_list, conn, input_trait_symbol, tissue_probeset_freeze_id, method)) trait_list = add_lit_corr_and_tiss_corr(tuple( - { - **retrieve_trait_info( - threshold, - f"{primary_trait['db']['dataset_name']}::{item[0]}", - conn), - "noverlap": item[1], - "partial_corr": item[2], - "partial_corr_p_value": item[3], - "corr": item[4], - "corr_p_value": item[5], - "rank_order": (1 if "spearman" in method.lower() else 0), - **({ - "tissue_corr": item[6], - "tissue_p_value": item[7]} + { + **retrieve_trait_info( + threshold, + f"{primary_trait['db']['dataset_name']}::{item[0]}", + conn), + "noverlap": item[1], + "partial_corr": item[2], + "partial_corr_p_value": item[3], + "corr": item[4], + "corr_p_value": item[5], + "rank_order": (1 if "spearman" in method.lower() else 0), + **({ + "tissue_corr": item[6], + "tissue_p_value": item[7]} if len(item) == 8 else {}), - **({"l_corr": item[6]} + **({"l_corr": item[6]} if len(item) == 7 else {}) - } + } for item in - sorted_correlations[:min(criteria, len(all_correlations))])) + sorted_correlations[:min(criteria, len(all_correlations))])) return trait_list diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py index 2a38bae..3d12019 100644 --- a/gn3/db/correlations.py +++ b/gn3/db/correlations.py @@ -29,7 +29,7 @@ def get_filename(conn: Any, target_db_name: str, text_files_dir: str) -> Union[ filename = "ProbeSetFreezeId_{tid}_FullName_{fname}.txt".format( tid=result[0], fname=result[1].replace(' ', '_').replace('/', '_')) - return ((filename in os.listdir(text_file_dir)) + return ((filename in os.listdir(text_files_dir)) and f"{text_files_dir}/{filename}") return False @@ -280,7 +280,8 @@ def build_temporary_tissue_correlations_table( # We should probably pass the `correlations_of_all_tissue_traits` function # as an argument to this function and get rid of the one call immediately # following this comment. - from gn3.computations.partial_correlations import correlations_of_all_tissue_traits + from gn3.computations.partial_correlations import (#pylint: disable=[C0415, R0401] + correlations_of_all_tissue_traits) # This import above is necessary within the function to avoid # circular-imports. # diff --git a/gn3/db/traits.py b/gn3/db/traits.py index 75de4f4..d4a96f0 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -1,9 +1,10 @@ """This class contains functions relating to trait data manipulation""" import os -import MySQLdb from functools import reduce from typing import Any, Dict, Union, Sequence +import MySQLdb + from gn3.settings import TMPDIR from gn3.random import random_string from gn3.function_helpers import compose @@ -81,10 +82,10 @@ def export_trait_data( def get_trait_csv_sample_data(conn: Any, trait_name: int, phenotype_id: int): """Fetch a trait and return it as a csv string""" - def __float_strip(n): - if str(n)[-2:] == ".0": - return str(int(n)) - return str(n) + def __float_strip(num_str): + if str(num_str)[-2:] == ".0": + return str(int(num_str)) + return str(num_str) sql = ("SELECT DISTINCT Strain.Name, PublishData.value, " "PublishSE.error, NStrain.count FROM " "(PublishData, Strain, PublishXRef, PublishFreeze) " @@ -108,7 +109,7 @@ def get_trait_csv_sample_data(conn: Any, return "\n".join(csv_data) -def update_sample_data(conn: Any, +def update_sample_data(conn: Any, #pylint: disable=[R0913] trait_name: str, strain_name: str, phenotype_id: int, @@ -219,7 +220,7 @@ def delete_sample_data(conn: Any, "WHERE StrainId = %s AND DataId = %s" % (strain_id, data_id))) deleted_n_strains = cursor.rowcount - except Exception as e: + except Exception as e: #pylint: disable=[C0103, W0612] conn.rollback() raise MySQLdb.Error conn.commit() @@ -230,7 +231,7 @@ def delete_sample_data(conn: Any, deleted_se_data, deleted_n_strains) -def insert_sample_data(conn: Any, +def insert_sample_data(conn: Any, #pylint: disable=[R0913] trait_name: str, strain_name: str, phenotype_id: int, @@ -272,7 +273,7 @@ def insert_sample_data(conn: Any, "VALUES (%s, %s, %s)") % (strain_id, data_id, count)) inserted_n_strains = cursor.rowcount - except Exception as e: + except Exception as e: #pylint: disable=[C0103, W0612] conn.rollback() raise MySQLdb.Error return (inserted_published_data, @@ -450,7 +451,7 @@ def set_homologene_id_field(trait_type, trait_info, conn): Common postprocessing function for all trait types. Sets the value for the 'homologene' key.""" - def set_to_null(ti): return {**ti, "homologeneid": None} + def set_to_null(ti): return {**ti, "homologeneid": None} # pylint: disable=[C0103, C0321] functions_table = { "Temp": set_to_null, "Geno": set_to_null, @@ -656,8 +657,9 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any): query, {"trait_name": trait_info["trait_name"]}) return [dict(zip( - ["sample_name", "value", "se_error", "nstrain", "id"], row)) - for row in cursor.fetchall()] + ["sample_name", "value", "se_error", "nstrain", "id"], + row)) + for row in cursor.fetchall()] return [] @@ -696,8 +698,10 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any): "dataset_name": trait_info["db"]["dataset_name"], "species_id": retrieve_species_id( trait_info["db"]["group"], conn)}) - return [dict(zip( - ["sample_name", "value", "se_error", "id"], row)) + return [ + dict(zip( + ["sample_name", "value", "se_error", "id"], + row)) for row in cursor.fetchall()] return [] @@ -728,8 +732,9 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any): query, {"trait_name": trait_info["trait_name"], "dataset_id": trait_info["db"]["dataset_id"]}) - return [dict(zip( - ["sample_name", "value", "se_error", "nstrain", "id"], row)) + return [ + dict(zip( + ["sample_name", "value", "se_error", "nstrain", "id"], row)) for row in cursor.fetchall()] return [] @@ -762,8 +767,9 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any): {"cellid": trait_info["cellid"], "trait_name": trait_info["trait_name"], "dataset_id": trait_info["db"]["dataset_id"]}) - return [dict(zip( - ["sample_name", "value", "se_error", "id"], row)) + return [ + dict(zip( + ["sample_name", "value", "se_error", "id"], row)) for row in cursor.fetchall()] return [] @@ -792,8 +798,9 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any): query, {"trait_name": trait_info["trait_name"], "dataset_name": trait_info["db"]["dataset_name"]}) - return [dict(zip( - ["sample_name", "value", "se_error", "id"], row)) + return [ + dict(zip( + ["sample_name", "value", "se_error", "id"], row)) for row in cursor.fetchall()] return [] |