diff options
author | Frederick Muriuki Muriithi | 2021-12-06 14:04:59 +0300 |
---|---|---|
committer | Frederick Muriuki Muriithi | 2021-12-06 14:04:59 +0300 |
commit | 66406115f41594ba40e3fbbc6f69aace2d11800f (patch) | |
tree | 0f3de09b74a3f47918dd4a192665c8a06c508144 /gn3/db | |
parent | 77099cac68e8f4792bf54d8e1f7ce6f315bedfa7 (diff) | |
parent | 5d2248f1dabbc7dd04f48aafcc9f327817a9c92c (diff) | |
download | genenetwork3-66406115f41594ba40e3fbbc6f69aace2d11800f.tar.gz |
Merge branch 'partial-correlations'
Diffstat (limited to 'gn3/db')
-rw-r--r-- | gn3/db/correlations.py | 564 | ||||
-rw-r--r-- | gn3/db/species.py | 44 | ||||
-rw-r--r-- | gn3/db/traits.py | 361 |
3 files changed, 916 insertions, 53 deletions
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py new file mode 100644 index 0000000..3d12019 --- /dev/null +++ b/gn3/db/correlations.py @@ -0,0 +1,564 @@ +""" +This module will hold functions that are used in the (partial) correlations +feature to access the database to retrieve data needed for computations. +""" +import os +from functools import reduce +from typing import Any, Dict, Tuple, Union + +from gn3.random import random_string +from gn3.data_helpers import partition_all +from gn3.db.species import translate_to_mouse_gene_id + +def get_filename(conn: Any, target_db_name: str, text_files_dir: str) -> Union[ + str, bool]: + """ + Retrieve the name of the reference database file with which correlations are + computed. + + This is a migration of the + `web.webqtl.correlation.CorrelationPage.getFileName` function in + GeneNetwork1. + """ + with conn.cursor() as cursor: + cursor.execute( + "SELECT Id, FullName from ProbeSetFreeze WHERE Name=%s", + (target_db_name,)) + result = cursor.fetchone() + if result: + filename = "ProbeSetFreezeId_{tid}_FullName_{fname}.txt".format( + tid=result[0], + fname=result[1].replace(' ', '_').replace('/', '_')) + return ((filename in os.listdir(text_files_dir)) + and f"{text_files_dir}/{filename}") + + return False + +def build_temporary_literature_table( + conn: Any, species: str, gene_id: int, return_number: int) -> str: + """ + Build and populate a temporary table to hold the literature correlation data + to be used in computations. + + "This is a migration of the + `web.webqtl.correlation.CorrelationPage.getTempLiteratureTable` function in + GeneNetwork1. + """ + def __translated_species_id(row, cursor): + if species == "mouse": + return row[1] + query = { + "rat": "SELECT rat FROM GeneIDXRef WHERE mouse=%s", + "human": "SELECT human FROM GeneIDXRef WHERE mouse=%d"} + if species in query.keys(): + cursor.execute(query[species], row[1]) + record = cursor.fetchone() + if record: + return record[0] + return None + return None + + temp_table_name = f"TOPLITERATURE{random_string(8)}" + with conn.cursor as cursor: + mouse_geneid = translate_to_mouse_gene_id(species, gene_id, conn) + data_query = ( + "SELECT GeneId1, GeneId2, value FROM LCorrRamin3 " + "WHERE GeneId1 = %(mouse_gene_id)s " + "UNION ALL " + "SELECT GeneId2, GeneId1, value FROM LCorrRamin3 " + "WHERE GeneId2 = %(mouse_gene_id)s " + "AND GeneId1 != %(mouse_gene_id)s") + cursor.execute( + (f"CREATE TEMPORARY TABLE {temp_table_name} (" + "GeneId1 int(12) unsigned, " + "GeneId2 int(12) unsigned PRIMARY KEY, " + "value double)")) + cursor.execute(data_query, mouse_gene_id=mouse_geneid) + literature_data = [ + {"GeneId1": row[0], "GeneId2": row[1], "value": row[2]} + for row in cursor.fetchall() + if __translated_species_id(row, cursor)] + + cursor.execute( + (f"INSERT INTO {temp_table_name} " + "VALUES (%(GeneId1)s, %(GeneId2)s, %(value)s)"), + literature_data[0:(2 * return_number)]) + + return temp_table_name + +def fetch_geno_literature_correlations(temp_table: str) -> str: + """ + Helper function for `fetch_literature_correlations` below, to build query + for `Geno*` tables. + """ + return ( + f"SELECT Geno.Name, {temp_table}.value " + "FROM Geno, GenoXRef, GenoFreeze " + f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " + "WHERE ProbeSet.GeneId IS NOT NULL " + f"AND {temp_table}.value IS NOT NULL " + "AND GenoXRef.GenoFreezeId = GenoFreeze.Id " + "AND GenoFreeze.Name = %(db_name)s " + "AND Geno.Id=GenoXRef.GenoId " + "ORDER BY Geno.Id") + +def fetch_probeset_literature_correlations(temp_table: str) -> str: + """ + Helper function for `fetch_literature_correlations` below, to build query + for `ProbeSet*` tables. + """ + return ( + f"SELECT ProbeSet.Name, {temp_table}.value " + "FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze " + "LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " + "WHERE ProbeSet.GeneId IS NOT NULL " + "AND {temp_table}.value IS NOT NULL " + "AND ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id " + "AND ProbeSetFreeze.Name = %(db_name)s " + "AND ProbeSet.Id=ProbeSetXRef.ProbeSetId " + "ORDER BY ProbeSet.Id") + +def fetch_literature_correlations( + species: str, gene_id: int, dataset: dict, return_number: int, + conn: Any) -> dict: + """ + Gather the literature correlation data and pair it with trait id string(s). + + This is a migration of the + `web.webqtl.correlation.CorrelationPage.fetchLitCorrelations` function in + GeneNetwork1. + """ + temp_table = build_temporary_literature_table( + conn, species, gene_id, return_number) + query_fns = { + "Geno": fetch_geno_literature_correlations, + # "Temp": fetch_temp_literature_correlations, + # "Publish": fetch_publish_literature_correlations, + "ProbeSet": fetch_probeset_literature_correlations} + with conn.cursor as cursor: + cursor.execute( + query_fns[dataset["dataset_type"]](temp_table), + db_name=dataset["dataset_name"]) + results = cursor.fetchall() + cursor.execute("DROP TEMPORARY TABLE %s", temp_table) + return dict(results) + +def fetch_symbol_value_pair_dict( + symbol_list: Tuple[str, ...], data_id_dict: dict, + conn: Any) -> Dict[str, Tuple[float, ...]]: + """ + Map each gene symbols to the corresponding tissue expression data. + + This is a migration of the + `web.webqtl.correlation.correlationFunction.getSymbolValuePairDict` function + in GeneNetwork1. + """ + data_ids = { + symbol: data_id_dict.get(symbol) for symbol in symbol_list + if data_id_dict.get(symbol) is not None + } + query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN %(data_ids)s" + with conn.cursor() as cursor: + cursor.execute( + query, + data_ids=tuple(data_ids.values())) + value_results = cursor.fetchall() + return { + key: tuple(row[1] for row in value_results if row[0] == key) + for key in data_ids.keys() + } + + return {} + +def fetch_gene_symbol_tissue_value_dict( + symbol_list: Tuple[str, ...], data_id_dict: dict, conn: Any, + limit_num: int = 1000) -> dict:#getGeneSymbolTissueValueDict + """ + Wrapper function for `gn3.db.correlations.fetch_symbol_value_pair_dict`. + + This is a migrations of the + `web.webqtl.correlation.correlationFunction.getGeneSymbolTissueValueDict` in + GeneNetwork1. + """ + count = len(symbol_list) + if count != 0 and count <= limit_num: + return fetch_symbol_value_pair_dict(symbol_list, data_id_dict, conn) + + if count > limit_num: + return { + key: value for dct in [ + fetch_symbol_value_pair_dict(sl, data_id_dict, conn) + for sl in partition_all(limit_num, symbol_list)] + for key, value in dct.items() + } + + return {} + +def fetch_tissue_probeset_xref_info( + gene_name_list: Tuple[str, ...], probeset_freeze_id: int, + conn: Any) -> Tuple[tuple, dict, dict, dict, dict, dict, dict]: + """ + Retrieve the ProbeSet XRef information for tissues. + + This is a migration of the + `web.webqtl.correlation.correlationFunction.getTissueProbeSetXRefInfo` + function in GeneNetwork1.""" + with conn.cursor() as cursor: + if len(gene_name_list) == 0: + query = ( + "SELECT t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, " + "t.description, t.Probe_Target_Description " + "FROM " + "(" + " SELECT Symbol, max(Mean) AS maxmean " + " FROM TissueProbeSetXRef " + " WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s " + " AND Symbol != '' " + " AND Symbol IS NOT NULL " + " GROUP BY Symbol" + ") AS x " + "INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol " + "AND t.Mean = x.maxmean") + cursor.execute(query, probeset_freeze_id=probeset_freeze_id) + else: + query = ( + "SELECT t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, " + "t.description, t.Probe_Target_Description " + "FROM " + "(" + " SELECT Symbol, max(Mean) AS maxmean " + " FROM TissueProbeSetXRef " + " WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s " + " AND Symbol in %(symbols)s " + " GROUP BY Symbol" + ") AS x " + "INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol " + "AND t.Mean = x.maxmean") + cursor.execute( + query, probeset_freeze_id=probeset_freeze_id, + symbols=tuple(gene_name_list)) + + results = cursor.fetchall() + + return reduce( + lambda acc, item: ( + acc[0] + (item[0],), + {**acc[1], item[0].lower(): item[1]}, + {**acc[1], item[0].lower(): item[2]}, + {**acc[1], item[0].lower(): item[3]}, + {**acc[1], item[0].lower(): item[4]}, + {**acc[1], item[0].lower(): item[5]}, + {**acc[1], item[0].lower(): item[6]}), + results or tuple(), + (tuple(), {}, {}, {}, {}, {}, {})) + +def fetch_gene_symbol_tissue_value_dict_for_trait( + gene_name_list: Tuple[str, ...], probeset_freeze_id: int, + conn: Any) -> dict: + """ + Fetches a map of the gene symbols to the tissue values. + + This is a migration of the + `web.webqtl.correlation.correlationFunction.getGeneSymbolTissueValueDictForTrait` + function in GeneNetwork1. + """ + xref_info = fetch_tissue_probeset_xref_info( + gene_name_list, probeset_freeze_id, conn) + if xref_info[0]: + return fetch_gene_symbol_tissue_value_dict(xref_info[0], xref_info[2], conn) + return {} + +def build_temporary_tissue_correlations_table( + conn: Any, trait_symbol: str, probeset_freeze_id: int, method: str, + return_number: int) -> str: + """ + Build a temporary table to hold the tissue correlations data. + + This is a migration of the + `web.webqtl.correlation.CorrelationPage.getTempTissueCorrTable` function in + GeneNetwork1.""" + # We should probably pass the `correlations_of_all_tissue_traits` function + # as an argument to this function and get rid of the one call immediately + # following this comment. + from gn3.computations.partial_correlations import (#pylint: disable=[C0415, R0401] + correlations_of_all_tissue_traits) + # This import above is necessary within the function to avoid + # circular-imports. + # + # + # This import above is indicative of convoluted code, with the computation + # being interwoven with the data retrieval. This needs to be changed, such + # that the function being imported here is no longer necessary, or have the + # imported function passed to this function as an argument. + symbol_corr_dict, symbol_p_value_dict = correlations_of_all_tissue_traits( + fetch_gene_symbol_tissue_value_dict_for_trait( + (trait_symbol,), probeset_freeze_id, conn), + fetch_gene_symbol_tissue_value_dict_for_trait( + tuple(), probeset_freeze_id, conn), + method) + + symbol_corr_list = sorted( + symbol_corr_dict.items(), key=lambda key_val: key_val[1]) + + temp_table_name = f"TOPTISSUE{random_string(8)}" + create_query = ( + "CREATE TEMPORARY TABLE {temp_table_name}" + "(Symbol varchar(100) PRIMARY KEY, Correlation float, PValue float)") + insert_query = ( + f"INSERT INTO {temp_table_name}(Symbol, Correlation, PValue) " + " VALUES (%(symbol)s, %(correlation)s, %(pvalue)s)") + + with conn.cursor() as cursor: + cursor.execute(create_query) + cursor.execute( + insert_query, + tuple({ + "symbol": symbol, + "correlation": corr, + "pvalue": symbol_p_value_dict[symbol] + } for symbol, corr in symbol_corr_list[0: 2 * return_number])) + + return temp_table_name + +def fetch_tissue_correlations(# pylint: disable=R0913 + dataset: dict, trait_symbol: str, probeset_freeze_id: int, method: str, + return_number: int, conn: Any) -> dict: + """ + Pair tissue correlations data with a trait id string. + + This is a migration of the + `web.webqtl.correlation.CorrelationPage.fetchTissueCorrelations` function in + GeneNetwork1. + """ + temp_table = build_temporary_tissue_correlations_table( + conn, trait_symbol, probeset_freeze_id, method, return_number) + with conn.cursor() as cursor: + cursor.execute( + ( + f"SELECT ProbeSet.Name, {temp_table}.Correlation, " + f"{temp_table}.PValue " + "FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) " + "LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol " + "WHERE ProbeSetFreeze.Name = %(db_name) " + "AND ProbeSetFreeze.Id=ProbeSetXRef.ProbeSetFreezeId " + "AND ProbeSet.Id = ProbeSetXRef.ProbeSetId " + "AND ProbeSet.Symbol IS NOT NULL " + "AND %s.Correlation IS NOT NULL"), + db_name=dataset["dataset_name"]) + results = cursor.fetchall() + cursor.execute("DROP TEMPORARY TABLE %s", temp_table) + return { + trait_name: (tiss_corr, tiss_p_val) + for trait_name, tiss_corr, tiss_p_val in results} + +def check_for_literature_info(conn: Any, geneid: int) -> bool: + """ + Checks the database to find out whether the trait with `geneid` has any + associated literature. + + This is a migration of the + `web.webqtl.correlation.CorrelationPage.checkForLitInfo` function in + GeneNetwork1. + """ + query = "SELECT 1 FROM LCorrRamin3 WHERE GeneId1=%s LIMIT 1" + with conn.cursor() as cursor: + cursor.execute(query, geneid) + result = cursor.fetchone() + if result: + return True + + return False + +def check_symbol_for_tissue_correlation( + conn: Any, tissue_probeset_freeze_id: int, symbol: str = "") -> bool: + """ + Checks whether a symbol has any associated tissue correlations. + + This is a migration of the + `web.webqtl.correlation.CorrelationPage.checkSymbolForTissueCorr` function + in GeneNetwork1. + """ + query = ( + "SELECT 1 FROM TissueProbeSetXRef " + "WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s " + "AND Symbol=%(symbol)s LIMIT 1") + with conn.cursor() as cursor: + cursor.execute( + query, probeset_freeze_id=tissue_probeset_freeze_id, symbol=symbol) + result = cursor.fetchone() + if result: + return True + + return False + +def fetch_sample_ids( + conn: Any, sample_names: Tuple[str, ...], species_name: str) -> Tuple[ + int, ...]: + """ + Given a sequence of sample names, and a species name, return the sample ids + that correspond to both. + + This is a partial migration of the + `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in + GeneNetwork1. + """ + query = ( + "SELECT Strain.Id FROM Strain, Species " + "WHERE Strain.Name IN %(samples_names)s " + "AND Strain.SpeciesId=Species.Id " + "AND Species.name=%(species_name)s") + with conn.cursor() as cursor: + cursor.execute( + query, + { + "samples_names": tuple(sample_names), + "species_name": species_name + }) + return tuple(row[0] for row in cursor.fetchall()) + +def build_query_sgo_lit_corr( + db_type: str, temp_table: str, sample_id_columns: str, + joins: Tuple[str, ...]) -> str: + """ + Build query for `SGO Literature Correlation` data, when querying the given + `temp_table` temporary table. + + This is a partial migration of the + `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in + GeneNetwork1. + """ + return ( + (f"SELECT {db_type}.Name, {temp_table}.value, " + + sample_id_columns + + f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " + + f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " + + " ".join(joins) + + " WHERE ProbeSet.GeneId IS NOT NULL " + + f"AND {temp_table}.value IS NOT NULL " + + f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " + + f"AND {db_type}Freeze.Name = %(db_name)s " + + f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " + + f"ORDER BY {db_type}.Id"), + 2) + +def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins): + """ + Build query for `Tissue Correlation` data, when querying the given + `temp_table` temporary table. + + This is a partial migration of the + `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in + GeneNetwork1. + """ + return ( + (f"SELECT {db_type}.Name, {temp_table}.Correlation, " + + f"{temp_table}.PValue, " + + sample_id_columns + + f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " + + f"LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol " + + " ".join(joins) + + " WHERE ProbeSet.Symbol IS NOT NULL " + + f"AND {temp_table}.Correlation IS NOT NULL " + + f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " + + f"AND {db_type}Freeze.Name = %(db_name)s " + + f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " + f"ORDER BY {db_type}.Id"), + 3) + +def fetch_all_database_data(# pylint: disable=[R0913, R0914] + conn: Any, species: str, gene_id: int, trait_symbol: str, + samples: Tuple[str, ...], dataset: dict, method: str, + return_number: int, probeset_freeze_id: int) -> Tuple[ + Tuple[float], int]: + """ + This is a migration of the + `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in + GeneNetwork1. + """ + db_type = dataset["dataset_type"] + db_name = dataset["dataset_name"] + def __build_query__(sample_ids, temp_table): + sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids) + if db_type == "Publish": + joins = tuple( + ("LEFT JOIN PublishData AS T{item} " + "ON T{item}.Id = PublishXRef.DataId " + "AND T{item}.StrainId = %(T{item}_sample_id)s") + for item in sample_ids) + return ( + ("SELECT PublishXRef.Id, " + + sample_id_columns + + "FROM (PublishXRef, PublishFreeze) " + + " ".join(joins) + + " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId " + "AND PublishFreeze.Name = %(db_name)s"), + 1) + if temp_table is not None: + joins = tuple( + (f"LEFT JOIN {db_type}Data AS T{item} " + f"ON T{item}.Id = {db_type}XRef.DataId " + f"AND T{item}.StrainId=%(T{item}_sample_id)s") + for item in sample_ids) + if method.lower() == "sgo literature correlation": + return build_query_sgo_lit_corr( + sample_ids, temp_table, sample_id_columns, joins) + if method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho"): + return build_query_tissue_corr( + sample_ids, temp_table, sample_id_columns, joins) + joins = tuple( + (f"LEFT JOIN {db_type}Data AS T{item} " + f"ON T{item}.Id = {db_type}XRef.DataId " + f"AND T{item}.StrainId = %(T{item}_sample_id)s") + for item in sample_ids) + return ( + ( + f"SELECT {db_type}.Name, " + + sample_id_columns + + f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " + + " ".join(joins) + + f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " + + f"AND {db_type}Freeze.Name = %(db_name)s " + + f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " + + f"ORDER BY {db_type}.Id"), + 1) + + def __fetch_data__(sample_ids, temp_table): + query, data_start_pos = __build_query__(sample_ids, temp_table) + with conn.cursor() as cursor: + cursor.execute( + query, + {"db_name": db_name, + **{f"T{item}_sample_id": item for item in sample_ids}}) + return (cursor.fetchall(), data_start_pos) + + sample_ids = tuple( + # look into graduating this to an argument and removing the `samples` + # and `species` argument: function currying and compositions might help + # with this + f"{sample_id}" for sample_id in + fetch_sample_ids(conn, samples, species)) + + temp_table = None + if gene_id and db_type == "probeset": + if method.lower() == "sgo literature correlation": + temp_table = build_temporary_literature_table( + conn, species, gene_id, return_number) + if method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho"): + temp_table = build_temporary_tissue_correlations_table( + conn, trait_symbol, probeset_freeze_id, method, return_number) + + trait_database = tuple( + item for sublist in + (__fetch_data__(ssample_ids, temp_table) + for ssample_ids in partition_all(25, sample_ids)) + for item in sublist) + + if temp_table: + with conn.cursor() as cursor: + cursor.execute(f"DROP TEMPORARY TABLE {temp_table}") + + return (trait_database[0], trait_database[1]) diff --git a/gn3/db/species.py b/gn3/db/species.py index 0deae4e..5b8e096 100644 --- a/gn3/db/species.py +++ b/gn3/db/species.py @@ -30,3 +30,47 @@ def get_chromosome(name: str, is_species: bool, conn: Any) -> Optional[Tuple]: with conn.cursor() as cursor: cursor.execute(_sql) return cursor.fetchall() + +def translate_to_mouse_gene_id(species: str, geneid: int, conn: Any) -> int: + """ + Translate rat or human geneid to mouse geneid + + This is a migration of the + `web.webqtl.correlation/CorrelationPage.translateToMouseGeneID` function in + GN1 + """ + assert species in ("rat", "mouse", "human"), "Invalid species" + if geneid is None: + return 0 + + if species == "mouse": + return geneid + + with conn.cursor as cursor: + query = { + "rat": "SELECT mouse FROM GeneIDXRef WHERE rat = %s", + "human": "SELECT mouse FROM GeneIDXRef WHERE human = %s" + } + cursor.execute(query[species], geneid) + translated_gene_id = cursor.fetchone() + if translated_gene_id: + return translated_gene_id[0] + + return 0 # default if all else fails + +def species_name(conn: Any, group: str) -> str: + """ + Retrieve the name of the species, given the group (RISet). + + This is a migration of the + `web.webqtl.dbFunction.webqtlDatabaseFunction.retrieveSpecies` function in + GeneNetwork1. + """ + with conn.cursor() as cursor: + cursor.execute( + ("SELECT Species.Name FROM Species, InbredSet " + "WHERE InbredSet.Name = %(group_name)s " + "AND InbredSet.SpeciesId = Species.Id"), + {"group_name": group}) + return cursor.fetchone()[0] + return None diff --git a/gn3/db/traits.py b/gn3/db/traits.py index f2673c8..4098b08 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -1,17 +1,93 @@ """This class contains functions relating to trait data manipulation""" import os +import MySQLdb +from functools import reduce from typing import Any, Dict, Union, Sequence + +import MySQLdb + from gn3.settings import TMPDIR from gn3.random import random_string from gn3.function_helpers import compose from gn3.db.datasets import retrieve_trait_dataset +def export_trait_data( + trait_data: dict, samplelist: Sequence[str], dtype: str = "val", + var_exists: bool = False, n_exists: bool = False): + """ + Export data according to `samplelist`. Mostly used in calculating + correlations. + + DESCRIPTION: + Migrated from + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211 + + PARAMETERS + trait: (dict) + The dictionary of key-value pairs representing a trait + samplelist: (list) + A list of sample names + dtype: (str) + ... verify what this is ... + var_exists: (bool) + A flag indicating existence of variance + n_exists: (bool) + A flag indicating existence of ndata + """ + def __export_all_types(tdata, sample): + sample_data = [] + if tdata[sample]["value"]: + sample_data.append(tdata[sample]["value"]) + if var_exists: + if tdata[sample]["variance"]: + sample_data.append(tdata[sample]["variance"]) + else: + sample_data.append(None) + if n_exists: + if tdata[sample]["ndata"]: + sample_data.append(tdata[sample]["ndata"]) + else: + sample_data.append(None) + else: + if var_exists and n_exists: + sample_data += [None, None, None] + elif var_exists or n_exists: + sample_data += [None, None] + else: + sample_data.append(None) + + return tuple(sample_data) + + def __exporter(accumulator, sample): + # pylint: disable=[R0911] + if sample in trait_data["data"]: + if dtype == "val": + return accumulator + (trait_data["data"][sample]["value"], ) + if dtype == "var": + return accumulator + (trait_data["data"][sample]["variance"], ) + if dtype == "N": + return accumulator + (trait_data["data"][sample]["ndata"], ) + if dtype == "all": + return accumulator + __export_all_types(trait_data["data"], sample) + raise KeyError("Type `%s` is incorrect" % dtype) + if var_exists and n_exists: + return accumulator + (None, None, None) + if var_exists or n_exists: + return accumulator + (None, None) + return accumulator + (None,) + + return reduce(__exporter, samplelist, tuple()) + + def get_trait_csv_sample_data(conn: Any, trait_name: int, phenotype_id: int): """Fetch a trait and return it as a csv string""" - sql = ("SELECT DISTINCT Strain.Id, PublishData.Id, Strain.Name, " - "PublishData.value, " + def __float_strip(num_str): + if str(num_str)[-2:] == ".0": + return str(int(num_str)) + return str(num_str) + sql = ("SELECT DISTINCT Strain.Name, PublishData.value, " "PublishSE.error, NStrain.count FROM " "(PublishData, Strain, PublishXRef, PublishFreeze) " "LEFT JOIN PublishSE ON " @@ -23,65 +99,189 @@ def get_trait_csv_sample_data(conn: Any, "PublishData.Id = PublishXRef.DataId AND " "PublishXRef.Id = %s AND PublishXRef.PhenotypeId = %s " "AND PublishData.StrainId = Strain.Id Order BY Strain.Name") - csv_data = ["Strain Id,Strain Name,Value,SE,Count"] - publishdata_id = "" + csv_data = ["Strain Name,Value,SE,Count"] with conn.cursor() as cursor: cursor.execute(sql, (trait_name, phenotype_id,)) for record in cursor.fetchall(): - (strain_id, publishdata_id, - strain_name, value, error, count) = record + (strain_name, value, error, count) = record csv_data.append( - ",".join([str(val) if val else "x" - for val in (strain_id, strain_name, - value, error, count)])) - return f"# Publish Data Id: {publishdata_id}\n\n" + "\n".join(csv_data) + ",".join([__float_strip(val) if val else "x" + for val in (strain_name, value, error, count)])) + return "\n".join(csv_data) + +def update_sample_data(conn: Any, #pylint: disable=[R0913] -def update_sample_data(conn: Any, + trait_name: str, strain_name: str, - strain_id: int, - publish_data_id: int, + phenotype_id: int, value: Union[int, float, str], error: Union[int, float, str], count: Union[int, str]): """Given the right parameters, update sample-data from the relevant table.""" - # pylint: disable=[R0913, R0914, C0103] - STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s" - PUBLISH_DATA_SQL: str = ("UPDATE PublishData SET value = %s " - "WHERE StrainId = %s AND Id = %s") - PUBLISH_SE_SQL: str = ("UPDATE PublishSE SET error = %s " - "WHERE StrainId = %s AND DataId = %s") - N_STRAIN_SQL: str = ("UPDATE NStrain SET count = %s " - "WHERE StrainId = %s AND DataId = %s") - - updated_strains: int = 0 + strain_id, data_id = "", "" + + with conn.cursor() as cursor: + cursor.execute( + ("SELECT Strain.Id, PublishData.Id FROM " + "(PublishData, Strain, PublishXRef, PublishFreeze) " + "LEFT JOIN PublishSE ON " + "(PublishSE.DataId = PublishData.Id AND " + "PublishSE.StrainId = PublishData.StrainId) " + "LEFT JOIN NStrain ON " + "(NStrain.DataId = PublishData.Id AND " + "NStrain.StrainId = PublishData.StrainId) " + "WHERE PublishXRef.InbredSetId = " + "PublishFreeze.InbredSetId AND " + "PublishData.Id = PublishXRef.DataId AND " + "PublishXRef.Id = %s AND " + "PublishXRef.PhenotypeId = %s " + "AND PublishData.StrainId = Strain.Id " + "AND Strain.Name = \"%s\"") % (trait_name, + phenotype_id, + str(strain_name))) + strain_id, data_id = cursor.fetchone() updated_published_data: int = 0 updated_se_data: int = 0 updated_n_strains: int = 0 with conn.cursor() as cursor: - # Update the Strains table - cursor.execute(STRAIN_ID_SQL, (strain_name, strain_id)) - updated_strains = cursor.rowcount # Update the PublishData table - cursor.execute(PUBLISH_DATA_SQL, + cursor.execute(("UPDATE PublishData SET value = %s " + "WHERE StrainId = %s AND Id = %s"), (None if value == "x" else value, - strain_id, publish_data_id)) + strain_id, data_id)) updated_published_data = cursor.rowcount + # Update the PublishSE table - cursor.execute(PUBLISH_SE_SQL, + cursor.execute(("UPDATE PublishSE SET error = %s " + "WHERE StrainId = %s AND DataId = %s"), (None if error == "x" else error, - strain_id, publish_data_id)) + strain_id, data_id)) updated_se_data = cursor.rowcount + # Update the NStrain table - cursor.execute(N_STRAIN_SQL, + cursor.execute(("UPDATE NStrain SET count = %s " + "WHERE StrainId = %s AND DataId = %s"), (None if count == "x" else count, - strain_id, publish_data_id)) + strain_id, data_id)) updated_n_strains = cursor.rowcount - return (updated_strains, updated_published_data, + return (updated_published_data, updated_se_data, updated_n_strains) + +def delete_sample_data(conn: Any, + trait_name: str, + strain_name: str, + phenotype_id: int): + """Given the right parameters, delete sample-data from the relevant + table.""" + strain_id, data_id = "", "" + + deleted_published_data: int = 0 + deleted_se_data: int = 0 + deleted_n_strains: int = 0 + + with conn.cursor() as cursor: + # Delete the PublishData table + try: + cursor.execute( + ("SELECT Strain.Id, PublishData.Id FROM " + "(PublishData, Strain, PublishXRef, PublishFreeze) " + "LEFT JOIN PublishSE ON " + "(PublishSE.DataId = PublishData.Id AND " + "PublishSE.StrainId = PublishData.StrainId) " + "LEFT JOIN NStrain ON " + "(NStrain.DataId = PublishData.Id AND " + "NStrain.StrainId = PublishData.StrainId) " + "WHERE PublishXRef.InbredSetId = " + "PublishFreeze.InbredSetId AND " + "PublishData.Id = PublishXRef.DataId AND " + "PublishXRef.Id = %s AND " + "PublishXRef.PhenotypeId = %s " + "AND PublishData.StrainId = Strain.Id " + "AND Strain.Name = \"%s\"") % (trait_name, + phenotype_id, + str(strain_name))) + strain_id, data_id = cursor.fetchone() + + cursor.execute(("DELETE FROM PublishData " + "WHERE StrainId = %s AND Id = %s") + % (strain_id, data_id)) + deleted_published_data = cursor.rowcount + + # Delete the PublishSE table + cursor.execute(("DELETE FROM PublishSE " + "WHERE StrainId = %s AND DataId = %s") % + (strain_id, data_id)) + deleted_se_data = cursor.rowcount + + # Delete the NStrain table + cursor.execute(("DELETE FROM NStrain " + "WHERE StrainId = %s AND DataId = %s" % + (strain_id, data_id))) + deleted_n_strains = cursor.rowcount + except Exception as e: #pylint: disable=[C0103, W0612] + conn.rollback() + raise MySQLdb.Error + conn.commit() + cursor.close() + cursor.close() + + return (deleted_published_data, + deleted_se_data, deleted_n_strains) + + +def insert_sample_data(conn: Any, #pylint: disable=[R0913] + trait_name: str, + strain_name: str, + phenotype_id: int, + value: Union[int, float, str], + error: Union[int, float, str], + count: Union[int, str]): + """Given the right parameters, insert sample-data to the relevant table. + + """ + + inserted_published_data, inserted_se_data, inserted_n_strains = 0, 0, 0 + with conn.cursor() as cursor: + try: + cursor.execute("SELECT DataId FROM PublishXRef WHERE Id = %s AND " + "PhenotypeId = %s", (trait_name, phenotype_id)) + data_id = cursor.fetchone() + + cursor.execute("SELECT Id FROM Strain WHERE Name = %s", + (strain_name,)) + strain_id = cursor.fetchone() + + # Insert the PublishData table + cursor.execute(("INSERT INTO PublishData (Id, StrainId, value)" + "VALUES (%s, %s, %s)"), + (data_id, strain_id, value)) + inserted_published_data = cursor.rowcount + + # Insert into the PublishSE table if error is specified + if error and error != "x": + cursor.execute(("INSERT INTO PublishSE (StrainId, DataId, " + " error) VALUES (%s, %s, %s)") % + (strain_id, data_id, error)) + inserted_se_data = cursor.rowcount + + # Insert into the NStrain table + if count and count != "x": + cursor.execute(("INSERT INTO NStrain " + "(StrainId, DataId, error) " + "VALUES (%s, %s, %s)") % + (strain_id, data_id, count)) + inserted_n_strains = cursor.rowcount + except Exception as e: #pylint: disable=[C0103, W0612] + conn.rollback() + raise MySQLdb.Error + return (inserted_published_data, + inserted_se_data, inserted_n_strains) + + def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any): """Retrieve trait information for type `Publish` traits. @@ -121,11 +321,12 @@ def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any): cursor.execute( query, { - k:v for k, v in trait_data_source.items() + k: v for k, v in trait_data_source.items() if k in ["trait_name", "trait_dataset_id"] }) return dict(zip([k.lower() for k in keys], cursor.fetchone())) + def set_confidential_field(trait_type, trait_info): """Post processing function for 'Publish' trait types. @@ -138,6 +339,7 @@ def set_confidential_field(trait_type, trait_info): and not trait_info.get("pubmed_id", None)) else 0} return trait_info + def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any): """Retrieve trait information for type `ProbeSet` traits. @@ -165,11 +367,12 @@ def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any): cursor.execute( query, { - k:v for k, v in trait_data_source.items() + k: v for k, v in trait_data_source.items() if k in ["trait_name", "trait_dataset_name"] }) return dict(zip(keys, cursor.fetchone())) + def retrieve_geno_trait_info(trait_data_source: Dict[str, Any], conn: Any): """Retrieve trait information for type `Geno` traits. @@ -189,11 +392,12 @@ def retrieve_geno_trait_info(trait_data_source: Dict[str, Any], conn: Any): cursor.execute( query, { - k:v for k, v in trait_data_source.items() + k: v for k, v in trait_data_source.items() if k in ["trait_name", "trait_dataset_name"] }) return dict(zip(keys, cursor.fetchone())) + def retrieve_temp_trait_info(trait_data_source: Dict[str, Any], conn: Any): """Retrieve trait information for type `Temp` traits. @@ -206,11 +410,12 @@ def retrieve_temp_trait_info(trait_data_source: Dict[str, Any], conn: Any): cursor.execute( query, { - k:v for k, v in trait_data_source.items() + k: v for k, v in trait_data_source.items() if k in ["trait_name"] }) return dict(zip(keys, cursor.fetchone())) + def set_haveinfo_field(trait_info): """ Common postprocessing function for all trait types. @@ -218,6 +423,7 @@ def set_haveinfo_field(trait_info): Sets the value for the 'haveinfo' field.""" return {**trait_info, "haveinfo": 1 if trait_info else 0} + def set_homologene_id_field_probeset(trait_info, conn): """ Postprocessing function for 'ProbeSet' traits. @@ -233,7 +439,7 @@ def set_homologene_id_field_probeset(trait_info, conn): cursor.execute( query, { - k:v for k, v in trait_info.items() + k: v for k, v in trait_info.items() if k in ["geneid", "group"] }) res = cursor.fetchone() @@ -241,12 +447,13 @@ def set_homologene_id_field_probeset(trait_info, conn): return {**trait_info, "homologeneid": res[0]} return {**trait_info, "homologeneid": None} + def set_homologene_id_field(trait_type, trait_info, conn): """ Common postprocessing function for all trait types. Sets the value for the 'homologene' key.""" - set_to_null = lambda ti: {**ti, "homologeneid": None} + def set_to_null(ti): return {**ti, "homologeneid": None} # pylint: disable=[C0103, C0321] functions_table = { "Temp": set_to_null, "Geno": set_to_null, @@ -255,6 +462,7 @@ def set_homologene_id_field(trait_type, trait_info, conn): } return functions_table[trait_type](trait_info) + def load_publish_qtl_info(trait_info, conn): """ Load extra QTL information for `Publish` traits @@ -275,6 +483,7 @@ def load_publish_qtl_info(trait_info, conn): return dict(zip(["locus", "lrs", "additive"], cursor.fetchone())) return {"locus": "", "lrs": "", "additive": ""} + def load_probeset_qtl_info(trait_info, conn): """ Load extra QTL information for `ProbeSet` traits @@ -297,6 +506,7 @@ def load_probeset_qtl_info(trait_info, conn): ["locus", "lrs", "pvalue", "mean", "additive"], cursor.fetchone())) return {"locus": "", "lrs": "", "pvalue": "", "mean": "", "additive": ""} + def load_qtl_info(qtl, trait_type, trait_info, conn): """ Load extra QTL information for traits @@ -325,6 +535,7 @@ def load_qtl_info(qtl, trait_type, trait_info, conn): return qtl_info_functions[trait_type](trait_info, conn) + def build_trait_name(trait_fullname): """ Initialises the trait's name, and other values from the search data provided @@ -351,6 +562,7 @@ def build_trait_name(trait_fullname): "cellid": name_parts[2] if len(name_parts) == 3 else "" } + def retrieve_probeset_sequence(trait, conn): """ Retrieve a 'ProbeSet' trait's sequence information @@ -372,6 +584,7 @@ def retrieve_probeset_sequence(trait, conn): seq = cursor.fetchone() return {**trait, "sequence": seq[0] if seq else ""} + def retrieve_trait_info( threshold: int, trait_full_name: str, conn: Any, qtl=None): @@ -427,6 +640,7 @@ def retrieve_trait_info( } return trait_info + def retrieve_temp_trait_data(trait_info: dict, conn: Any): """ Retrieve trait data for `Temp` traits. @@ -445,10 +659,12 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any): query, {"trait_name": trait_info["trait_name"]}) return [dict(zip( - ["sample_name", "value", "se_error", "nstrain", "id"], row)) + ["sample_name", "value", "se_error", "nstrain", "id"], + row)) for row in cursor.fetchall()] return [] + def retrieve_species_id(group, conn: Any): """ Retrieve a species id given the Group value @@ -460,6 +676,7 @@ def retrieve_species_id(group, conn: Any): return cursor.fetchone()[0] return None + def retrieve_geno_trait_data(trait_info: Dict, conn: Any): """ Retrieve trait data for `Geno` traits. @@ -483,11 +700,14 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any): "dataset_name": trait_info["db"]["dataset_name"], "species_id": retrieve_species_id( trait_info["db"]["group"], conn)}) - return [dict(zip( - ["sample_name", "value", "se_error", "id"], row)) - for row in cursor.fetchall()] + return [ + dict(zip( + ["sample_name", "value", "se_error", "id"], + row)) + for row in cursor.fetchall()] return [] + def retrieve_publish_trait_data(trait_info: Dict, conn: Any): """ Retrieve trait data for `Publish` traits. @@ -514,11 +734,13 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any): query, {"trait_name": trait_info["trait_name"], "dataset_id": trait_info["db"]["dataset_id"]}) - return [dict(zip( - ["sample_name", "value", "se_error", "nstrain", "id"], row)) - for row in cursor.fetchall()] + return [ + dict(zip( + ["sample_name", "value", "se_error", "nstrain", "id"], row)) + for row in cursor.fetchall()] return [] + def retrieve_cellid_trait_data(trait_info: Dict, conn: Any): """ Retrieve trait data for `Probe Data` types. @@ -547,11 +769,13 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any): {"cellid": trait_info["cellid"], "trait_name": trait_info["trait_name"], "dataset_id": trait_info["db"]["dataset_id"]}) - return [dict(zip( - ["sample_name", "value", "se_error", "id"], row)) - for row in cursor.fetchall()] + return [ + dict(zip( + ["sample_name", "value", "se_error", "id"], row)) + for row in cursor.fetchall()] return [] + def retrieve_probeset_trait_data(trait_info: Dict, conn: Any): """ Retrieve trait data for `ProbeSet` traits. @@ -576,11 +800,13 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any): query, {"trait_name": trait_info["trait_name"], "dataset_name": trait_info["db"]["dataset_name"]}) - return [dict(zip( - ["sample_name", "value", "se_error", "id"], row)) - for row in cursor.fetchall()] + return [ + dict(zip( + ["sample_name", "value", "se_error", "id"], row)) + for row in cursor.fetchall()] return [] + def with_samplelist_data_setup(samplelist: Sequence[str]): """ Build function that computes the trait data from provided list of samples. @@ -607,6 +833,7 @@ def with_samplelist_data_setup(samplelist: Sequence[str]): return None return setup_fn + def without_samplelist_data_setup(): """ Build function that computes the trait data. @@ -627,6 +854,7 @@ def without_samplelist_data_setup(): return None return setup_fn + def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tuple()): """ Retrieve trait data @@ -666,11 +894,38 @@ def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tupl "data": dict(map( lambda x: ( x["sample_name"], - {k:v for k, v in x.items() if x != "sample_name"}), + {k: v for k, v in x.items() if x != "sample_name"}), data))} return {} + def generate_traits_filename(base_path: str = TMPDIR): """Generate a unique filename for use with generated traits files.""" return "{}/traits_test_file_{}.txt".format( os.path.abspath(base_path), random_string(10)) + + +def export_informative(trait_data: dict, inc_var: bool = False) -> tuple: + """ + Export informative strain + + This is a migration of the `exportInformative` function in + web/webqtl/base/webqtlTrait.py module in GeneNetwork1. + + There is a chance that the original implementation has a bug, especially + dealing with the `inc_var` value. It the `inc_var` value is meant to control + the inclusion of the `variance` value, then the current implementation, and + that one in GN1 have a bug. + """ + def __exporter__(acc, data_item): + if not inc_var or data_item["variance"] is not None: + return ( + acc[0] + (data_item["sample_name"],), + acc[1] + (data_item["value"],), + acc[2] + (data_item["variance"],)) + return acc + return reduce( + __exporter__, + filter(lambda td: td["value"] is not None, + trait_data["data"].values()), + (tuple(), tuple(), tuple())) |