diff options
author | BonfaceKilz | 2021-10-26 08:48:30 +0300 |
---|---|---|
committer | GitHub | 2021-10-26 08:48:30 +0300 |
commit | 0cfff99e22155b6b15e23cbeff596f5f8f08709c (patch) | |
tree | c831f28534ebf6432972e107dbd6da5daef81088 | |
parent | 5440bfcd6940db08c4479a39ba66dbc802b2c426 (diff) | |
parent | c13afb3af166d2b01e4f9fd9b09bb231f0a63cb1 (diff) | |
download | genenetwork3-0cfff99e22155b6b15e23cbeff596f5f8f08709c.tar.gz |
Merge pull request #46 from genenetwork/partial-correlations
Partial correlations
-rw-r--r-- | gn3/data_helpers.py | 25 | ||||
-rw-r--r-- | gn3/db/correlations.py | 318 | ||||
-rw-r--r-- | gn3/db/species.py | 31 | ||||
-rw-r--r-- | gn3/partial_correlations.py | 38 | ||||
-rw-r--r-- | tests/unit/test_data_helpers.py | 37 | ||||
-rw-r--r-- | tests/unit/test_partial_correlations.py | 64 |
6 files changed, 510 insertions, 3 deletions
diff --git a/gn3/data_helpers.py b/gn3/data_helpers.py new file mode 100644 index 0000000..f0d971e --- /dev/null +++ b/gn3/data_helpers.py @@ -0,0 +1,25 @@ +""" +This module will hold generic functions that can operate on a wide-array of +data structures. +""" + +from math import ceil +from functools import reduce +from typing import Any, Tuple, Sequence + +def partition_all(num: int, items: Sequence[Any]) -> Tuple[Tuple[Any, ...], ...]: + """ + Given a sequence `items`, return a new sequence of the same type as `items` + with the data partitioned into sections of `n` items per partition. + + This is an approximation of clojure's `partition-all` function. + """ + def __compute_start_stop__(acc, iteration): + start = iteration * num + return acc + ((start, start + num),) + + iterations = range(ceil(len(items) / num)) + return tuple([# type: ignore[misc] + tuple(items[start:stop]) for start, stop # type: ignore[has-type] + in reduce( + __compute_start_stop__, iterations, tuple())]) diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py new file mode 100644 index 0000000..87ab082 --- /dev/null +++ b/gn3/db/correlations.py @@ -0,0 +1,318 @@ +""" +This module will hold functions that are used in the (partial) correlations +feature to access the database to retrieve data needed for computations. +""" + +from functools import reduce +from typing import Any, Dict, Tuple + +from gn3.random import random_string +from gn3.data_helpers import partition_all +from gn3.db.species import translate_to_mouse_gene_id + +def get_filename(target_db_name: str, conn: Any) -> str: + """ + Retrieve the name of the reference database file with which correlations are + computed. + + This is a migration of the + `web.webqtl.correlation.CorrelationPage.getFileName` function in + GeneNetwork1. + """ + with conn.cursor() as cursor: + cursor.execute( + "SELECT Id, FullName from ProbeSetFreeze WHERE Name-%s", + target_db_name) + result = cursor.fetchone() + if result: + return "ProbeSetFreezeId_{tid}_FullName_{fname}.txt".format( + tid=result[0], + fname=result[1].replace(' ', '_').replace('/', '_')) + + return "" + +def build_temporary_literature_table( + species: str, gene_id: int, return_number: int, conn: Any) -> str: + """ + Build and populate a temporary table to hold the literature correlation data + to be used in computations. + + "This is a migration of the + `web.webqtl.correlation.CorrelationPage.getTempLiteratureTable` function in + GeneNetwork1. + """ + def __translated_species_id(row, cursor): + if species == "mouse": + return row[1] + query = { + "rat": "SELECT rat FROM GeneIDXRef WHERE mouse=%s", + "human": "SELECT human FROM GeneIDXRef WHERE mouse=%d"} + if species in query.keys(): + cursor.execute(query[species], row[1]) + record = cursor.fetchone() + if record: + return record[0] + return None + return None + + temp_table_name = f"TOPLITERATURE{random_string(8)}" + with conn.cursor as cursor: + mouse_geneid = translate_to_mouse_gene_id(species, gene_id, conn) + data_query = ( + "SELECT GeneId1, GeneId2, value FROM LCorrRamin3 " + "WHERE GeneId1 = %(mouse_gene_id)s " + "UNION ALL " + "SELECT GeneId2, GeneId1, value FROM LCorrRamin3 " + "WHERE GeneId2 = %(mouse_gene_id)s " + "AND GeneId1 != %(mouse_gene_id)s") + cursor.execute( + (f"CREATE TEMPORARY TABLE {temp_table_name} (" + "GeneId1 int(12) unsigned, " + "GeneId2 int(12) unsigned PRIMARY KEY, " + "value double)")) + cursor.execute(data_query, mouse_gene_id=mouse_geneid) + literature_data = [ + {"GeneId1": row[0], "GeneId2": row[1], "value": row[2]} + for row in cursor.fetchall() + if __translated_species_id(row, cursor)] + + cursor.execute( + (f"INSERT INTO {temp_table_name} " + "VALUES (%(GeneId1)s, %(GeneId2)s, %(value)s)"), + literature_data[0:(2 * return_number)]) + + return temp_table_name + +def fetch_geno_literature_correlations(temp_table: str) -> str: + """ + Helper function for `fetch_literature_correlations` below, to build query + for `Geno*` tables. + """ + return ( + f"SELECT Geno.Name, {temp_table}.value " + "FROM Geno, GenoXRef, GenoFreeze " + f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " + "WHERE ProbeSet.GeneId IS NOT NULL " + f"AND {temp_table}.value IS NOT NULL " + "AND GenoXRef.GenoFreezeId = GenoFreeze.Id " + "AND GenoFreeze.Name = %(db_name)s " + "AND Geno.Id=GenoXRef.GenoId " + "ORDER BY Geno.Id") + +def fetch_probeset_literature_correlations(temp_table: str) -> str: + """ + Helper function for `fetch_literature_correlations` below, to build query + for `ProbeSet*` tables. + """ + return ( + f"SELECT ProbeSet.Name, {temp_table}.value " + "FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze " + "LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " + "WHERE ProbeSet.GeneId IS NOT NULL " + "AND {temp_table}.value IS NOT NULL " + "AND ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id " + "AND ProbeSetFreeze.Name = %(db_name)s " + "AND ProbeSet.Id=ProbeSetXRef.ProbeSetId " + "ORDER BY ProbeSet.Id") + +def fetch_literature_correlations( + species: str, gene_id: int, dataset: dict, return_number: int, + conn: Any) -> dict: + """ + Gather the literature correlation data and pair it with trait id string(s). + + This is a migration of the + `web.webqtl.correlation.CorrelationPage.fetchLitCorrelations` function in + GeneNetwork1. + """ + temp_table = build_temporary_literature_table( + species, gene_id, return_number, conn) + query_fns = { + "Geno": fetch_geno_literature_correlations, + # "Temp": fetch_temp_literature_correlations, + # "Publish": fetch_publish_literature_correlations, + "ProbeSet": fetch_probeset_literature_correlations} + with conn.cursor as cursor: + cursor.execute( + query_fns[dataset["dataset_type"]](temp_table), + db_name=dataset["dataset_name"]) + results = cursor.fetchall() + cursor.execute("DROP TEMPORARY TABLE %s", temp_table) + return dict(results) + +def compare_tissue_correlation_absolute_values(val1, val2): + """ + Comparison function for use when sorting tissue correlation values. + + This is a partial migration of the + `web.webqtl.correlation.CorrelationPage.getTempTissueCorrTable` function in + GeneNetwork1.""" + try: + if abs(val1) < abs(val2): + return 1 + if abs(val1) == abs(val2): + return 0 + return -1 + except TypeError: + return 0 + +def fetch_symbol_value_pair_dict( + symbol_list: Tuple[str, ...], data_id_dict: dict, + conn: Any) -> Dict[str, Tuple[float, ...]]: + """ + Map each gene symbols to the corresponding tissue expression data. + + This is a migration of the + `web.webqtl.correlation.correlationFunction.getSymbolValuePairDict` function + in GeneNetwork1. + """ + data_ids = { + symbol: data_id_dict.get(symbol) for symbol in symbol_list + if data_id_dict.get(symbol) is not None + } + query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN %(data_ids)s" + with conn.cursor() as cursor: + cursor.execute( + query, + data_ids=tuple(data_ids.values())) + value_results = cursor.fetchall() + return { + key: tuple(row[1] for row in value_results if row[0] == key) + for key in data_ids.keys() + } + + return {} + +def fetch_gene_symbol_tissue_value_dict( + symbol_list: Tuple[str, ...], data_id_dict: dict, conn: Any, + limit_num: int = 1000) -> dict:#getGeneSymbolTissueValueDict + """ + Wrapper function for `gn3.db.correlations.fetch_symbol_value_pair_dict`. + + This is a migrations of the + `web.webqtl.correlation.correlationFunction.getGeneSymbolTissueValueDict` in + GeneNetwork1. + """ + count = len(symbol_list) + if count != 0 and count <= limit_num: + return fetch_symbol_value_pair_dict(symbol_list, data_id_dict, conn) + + if count > limit_num: + return { + key: value for dct in [ + fetch_symbol_value_pair_dict(sl, data_id_dict, conn) + for sl in partition_all(limit_num, symbol_list)] + for key, value in dct.items() + } + + return {} + +def fetch_tissue_probeset_xref_info( + gene_name_list: Tuple[str, ...], probeset_freeze_id: int, + conn: Any) -> Tuple[tuple, dict, dict, dict, dict, dict, dict]: + """ + Retrieve the ProbeSet XRef information for tissues. + + This is a migration of the + `web.webqtl.correlation.correlationFunction.getTissueProbeSetXRefInfo` + function in GeneNetwork1.""" + with conn.cursor() as cursor: + if len(gene_name_list) == 0: + query = ( + "SELECT t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, " + "t.description, t.Probe_Target_Description " + "FROM " + "(" + " SELECT Symbol, max(Mean) AS maxmean " + " FROM TissueProbeSetXRef " + " WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s " + " AND Symbol != '' " + " AND Symbol IS NOT NULL " + " GROUP BY Symbol" + ") AS x " + "INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol " + "AND t.Mean = x.maxmean") + cursor.execute(query, probeset_freeze_id=probeset_freeze_id) + else: + query = ( + "SELECT t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, " + "t.description, t.Probe_Target_Description " + "FROM " + "(" + " SELECT Symbol, max(Mean) AS maxmean " + " FROM TissueProbeSetXRef " + " WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s " + " AND Symbol in %(symbols)s " + " GROUP BY Symbol" + ") AS x " + "INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol " + "AND t.Mean = x.maxmean") + cursor.execute( + query, probeset_freeze_id=probeset_freeze_id, + symbols=tuple(gene_name_list)) + + results = cursor.fetchall() + + return reduce( + lambda acc, item: ( + acc[0] + (item[0],), + {**acc[1], item[0].lower(): item[1]}, + {**acc[1], item[0].lower(): item[2]}, + {**acc[1], item[0].lower(): item[3]}, + {**acc[1], item[0].lower(): item[4]}, + {**acc[1], item[0].lower(): item[5]}, + {**acc[1], item[0].lower(): item[6]}), + results or tuple(), + (tuple(), {}, {}, {}, {}, {}, {})) + +def correlations_of_all_tissue_traits() -> Tuple[dict, dict]: + """ + This is a migration of the + `web.webqtl.correlation.CorrelationPage.calculateCorrOfAllTissueTrait` + function in GeneNetwork1. + """ + raise Exception("Unimplemented!!!") + return ({}, {}) + +def build_temporary_tissue_correlations_table( + trait_symbol: str, probeset_freeze_id: int, method: str, + return_number: int, conn: Any) -> str: + """ + Build a temporary table to hold the tissue correlations data. + + This is a migration of the + `web.webqtl.correlation.CorrelationPage.getTempTissueCorrTable` function in + GeneNetwork1.""" + raise Exception("Unimplemented!!!") + return "" + +def fetch_tissue_correlations( + dataset: dict, trait_symbol: str, probeset_freeze_id: int, method: str, + return_number: int, conn: Any) -> dict: + """ + Pair tissue correlations data with a trait id string. + + This is a migration of the + `web.webqtl.correlation.CorrelationPage.fetchTissueCorrelations` function in + GeneNetwork1. + """ + temp_table = build_temporary_tissue_correlations_table( + trait_symbol, probeset_freeze_id, method, return_number, conn) + with conn.cursor() as cursor: + cursor.execute( + ( + f"SELECT ProbeSet.Name, {temp_table}.Correlation, " + f"{temp_table}.PValue " + "FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) " + "LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol " + "WHERE ProbeSetFreeze.Name = %(db_name) " + "AND ProbeSetFreeze.Id=ProbeSetXRef.ProbeSetFreezeId " + "AND ProbeSet.Id = ProbeSetXRef.ProbeSetId " + "AND ProbeSet.Symbol IS NOT NULL " + "AND %s.Correlation IS NOT NULL"), + db_name=dataset["dataset_name"]) + results = cursor.fetchall() + cursor.execute("DROP TEMPORARY TABLE %s", temp_table) + return { + trait_name: (tiss_corr, tiss_p_val) + for trait_name, tiss_corr, tiss_p_val in results} diff --git a/gn3/db/species.py b/gn3/db/species.py index 0deae4e..1e5015f 100644 --- a/gn3/db/species.py +++ b/gn3/db/species.py @@ -30,3 +30,34 @@ def get_chromosome(name: str, is_species: bool, conn: Any) -> Optional[Tuple]: with conn.cursor() as cursor: cursor.execute(_sql) return cursor.fetchall() + +def translate_to_mouse_gene_id(species: str, geneid: int, conn: Any) -> int: + """ + Translate rat or human geneid to mouse geneid + + This is a migration of the + `web.webqtl.correlation/CorrelationPage.translateToMouseGeneID` function in + GN1 + """ + assert species in ("rat", "mouse", "human"), "Invalid species" + if geneid is None: + return 0 + + if species == "mouse": + return geneid + + with conn.cursor as cursor: + if species == "rat": + cursor.execute( + "SELECT mouse FROM GeneIDXRef WHERE rat = %s", geneid) + rat_geneid = cursor.fetchone() + if rat_geneid: + return rat_geneid[0] + + cursor.execute( + "SELECT mouse FROM GeneIDXRef WHERE human = %s", geneid) + human_geneid = cursor.fetchone() + if human_geneid: + return human_geneid[0] + + return 0 # default if all else fails diff --git a/gn3/partial_correlations.py b/gn3/partial_correlations.py index c556d10..1fb0ccc 100644 --- a/gn3/partial_correlations.py +++ b/gn3/partial_correlations.py @@ -6,7 +6,7 @@ GeneNetwork1. """ from functools import reduce -from typing import Any, Sequence +from typing import Any, Tuple, Sequence def control_samples(controls: Sequence[dict], sampleslist: Sequence[str]): """ @@ -86,3 +86,39 @@ def fix_samples(primary_trait: dict, control_traits: Sequence[dict]) -> Sequence control_vals_vars[0], tuple(primary_trait[sample]["variance"] for sample in primary_samples), control_vals_vars[1]) + +def find_identical_traits( + primary_name: str, primary_value: float, control_names: Tuple[str, ...], + control_values: Tuple[float, ...]) -> Tuple[str, ...]: + """ + Find traits that have the same value when the values are considered to + 3 decimal places. + + This is a migration of the + `web.webqtl.correlation.correlationFunction.findIdenticalTraits` function in + GN1. + """ + def __merge_identicals__( + acc: Tuple[str, ...], + ident: Tuple[str, Tuple[str, ...]]) -> Tuple[str, ...]: + return acc + ident[1] + + def __dictify_controls__(acc, control_item): + ckey = "{:.3f}".format(control_item[0]) + return {**acc, ckey: acc.get(ckey, tuple()) + (control_item[1],)} + + return (reduce(## for identical control traits + __merge_identicals__, + (item for item in reduce(# type: ignore[var-annotated] + __dictify_controls__, zip(control_values, control_names), + {}).items() if len(item[1]) > 1), + tuple()) + or + reduce(## If no identical control traits, try primary and controls + __merge_identicals__, + (item for item in reduce(# type: ignore[var-annotated] + __dictify_controls__, + zip((primary_value,) + control_values, + (primary_name,) + control_names), {}).items() + if len(item[1]) > 1), + tuple())) diff --git a/tests/unit/test_data_helpers.py b/tests/unit/test_data_helpers.py new file mode 100644 index 0000000..1eec3cc --- /dev/null +++ b/tests/unit/test_data_helpers.py @@ -0,0 +1,37 @@ +""" +Test functions in gn3.data_helpers +""" + +from unittest import TestCase + +from gn3.data_helpers import partition_all + +class TestDataHelpers(TestCase): + """ + Test functions in gn3.data_helpers + """ + + def test_partition_all(self): + """ + Test that `gn3.data_helpers.partition_all` partitions sequences as expected. + + Given: + - `num`: The number of items per partition + - `items`: A sequence of items + When: + - The arguments above are passed to the `gn3.data_helpers.partition_all` + Then: + - Return a new sequence with partitions, each of which has `num` + items in the same order as those in `items`, save for the last + partition which might have fewer items than `num`. + """ + for count, items, expected in ( + (1, [0, 1, 2, 3], ((0,), (1,), (2,), (3,))), + (3, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + ((0, 1, 2), (3, 4, 5), (6, 7, 8), (9, ))), + (4, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + ((0, 1, 2, 3), (4, 5, 6, 7), (8, 9))), + (13, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + ((0, 1, 2, 3, 4, 5, 6, 7, 8, 9), ))): + with self.subTest(n=count, items=items): + self.assertEqual(partition_all(count, items), expected) diff --git a/tests/unit/test_partial_correlations.py b/tests/unit/test_partial_correlations.py index 7631a71..60e54c1 100644 --- a/tests/unit/test_partial_correlations.py +++ b/tests/unit/test_partial_correlations.py @@ -4,7 +4,8 @@ from unittest import TestCase from gn3.partial_correlations import ( fix_samples, control_samples, - dictify_by_samples) + dictify_by_samples, + find_identical_traits) sampleslist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"] control_traits = ( @@ -106,6 +107,8 @@ class TestPartialCorrelations(TestCase): def test_dictify_by_samples(self): """ + Test that `dictify_by_samples` generates the appropriate dict + Given: a sequence of sequences with sample names, values and variances, as in the output of `gn3.partial_correlations.control_samples` or @@ -133,7 +136,34 @@ class TestPartialCorrelations(TestCase): dictified_control_samples) def test_fix_samples(self): - """Test that fix_samples fixes the values""" + """ + Test that `fix_samples` returns only the common samples + + Given: + - A primary trait + - A sequence of control samples + When: + - The two arguments are passed to `fix_samples` + Then: + - Only the names of the samples present in the primary trait that + are also present in ALL the control traits are present in the + return value + - Only the values of the samples present in the primary trait that + are also present in ALL the control traits are present in the + return value + - ALL the values for ALL the control traits are present in the + return value + - Only the variances of the samples present in the primary trait + that are also present in ALL the control traits are present in the + return value + - ALL the variances for ALL the control traits are present in the + return value + - The return value is a tuple of the above items, in the following + order: + ((sample_names, ...), (primary_trait_values, ...), + (control_traits_values, ...), (primary_trait_variances, ...) + (control_traits_variances, ...)) + """ self.assertEqual( fix_samples( {"B6cC3-1": {"sample_name": "B6cC3-1", "value": 7.51879, @@ -149,3 +179,33 @@ class TestPartialCorrelations(TestCase): (None,), (None, None, None, None, None, None, None, None, None, None, None, None, None))) + + def test_find_identical_traits(self): + """ + Test `gn3.partial_correlations.find_identical_traits`. + + Given: + - the name of a primary trait + - the value of a primary trait + - a sequence of names of control traits + - a sequence of values of control traits + When: + - the arguments above are passed to the `find_identical_traits` + function + Then: + - Return ALL trait names that have the same value when up to three + decimal places are considered + """ + for primn, primv, contn, contv, expected in ( + ("pt", 12.98395, ("ct0", "ct1", "ct2"), + (0.1234, 2.3456, 3.4567), tuple()), + ("pt", 12.98395, ("ct0", "ct1", "ct2"), + (12.98354, 2.3456, 3.4567), ("pt", "ct0")), + ("pt", 12.98395, ("ct0", "ct1", "ct2", "ct3"), + (0.1234, 2.3456, 0.1233, 4.5678), ("ct0", "ct2")) + ): + with self.subTest( + primary_name=primn, primary_value=primv, + control_names=contn, control_values=contv): + self.assertEqual( + find_identical_traits(primn, primv, contn, contv), expected) |