diff options
-rw-r--r-- | gn3/computations/partial_correlations.py (renamed from gn3/partial_correlations.py) | 70 | ||||
-rw-r--r-- | gn3/data_helpers.py | 14 | ||||
-rw-r--r-- | gn3/db/correlations.py | 109 | ||||
-rw-r--r-- | gn3/db/species.py | 20 | ||||
-rw-r--r-- | tests/unit/computations/test_partial_correlations.py (renamed from tests/unit/test_partial_correlations.py) | 66 | ||||
-rw-r--r-- | tests/unit/test_data_helpers.py | 26 |
6 files changed, 266 insertions, 39 deletions
diff --git a/gn3/partial_correlations.py b/gn3/computations/partial_correlations.py index 1fb0ccc..ba4de9e 100644 --- a/gn3/partial_correlations.py +++ b/gn3/computations/partial_correlations.py @@ -7,6 +7,7 @@ GeneNetwork1. from functools import reduce from typing import Any, Tuple, Sequence +from scipy.stats import pearsonr, spearmanr def control_samples(controls: Sequence[dict], sampleslist: Sequence[str]): """ @@ -122,3 +123,72 @@ def find_identical_traits( (primary_name,) + control_names), {}).items() if len(item[1]) > 1), tuple())) + +def tissue_correlation( + primary_trait_values: Tuple[float, ...], + target_trait_values: Tuple[float, ...], + method: str) -> Tuple[float, float]: + """ + Compute the correlation between the primary trait values, and the values of + a single target value. + + This migrates the `cal_tissue_corr` function embedded in the larger + `web.webqtl.correlation.correlationFunction.batchCalTissueCorr` function in + GeneNetwork1. + """ + def spearman_corr(*args): + result = spearmanr(*args) + return (result.correlation, result.pvalue) + + method_fns = {"pearson": pearsonr, "spearman": spearman_corr} + + assert len(primary_trait_values) == len(target_trait_values), ( + "The lengths of the `primary_trait_values` and `target_trait_values` " + "must be equal") + assert method in method_fns.keys(), ( + "Method must be one of: {}".format(",".join(method_fns.keys()))) + + corr, pvalue = method_fns[method](primary_trait_values, target_trait_values) + return (round(corr, 10), round(pvalue, 10)) + +def batch_computed_tissue_correlation( + primary_trait_values: Tuple[float, ...], target_traits_dict: dict, + method: str) -> Tuple[dict, dict]: + """ + This is a migration of the + `web.webqtl.correlation.correlationFunction.batchCalTissueCorr` function in + GeneNetwork1 + """ + def __corr__(acc, target): + corr = tissue_correlation(primary_trait_values, target[1], method) + return ({**acc[0], target[0]: corr[0]}, {**acc[0], target[1]: corr[1]}) + return reduce(__corr__, target_traits_dict.items(), ({}, {})) + +def correlations_of_all_tissue_traits( + primary_trait_symbol_value_dict: dict, symbol_value_dict: dict, + method: str) -> Tuple[dict, dict]: + """ + Computes and returns the correlation of all tissue traits. + + This is a migration of the + `web.webqtl.correlation.correlationFunction.calculateCorrOfAllTissueTrait` + function in GeneNetwork1. + """ + primary_trait_values = tuple(primary_trait_symbol_value_dict.values())[0] + return batch_computed_tissue_correlation( + primary_trait_values, symbol_value_dict, method) + +def good_dataset_samples_indexes( + samples: Tuple[str, ...], + samples_from_file: Tuple[str, ...]) -> Tuple[int, ...]: + """ + Return the indexes of the items in `samples_from_files` that are also found + in `samples`. + + This is a partial migration of the + `web.webqtl.correlation.PartialCorrDBPage.getPartialCorrelationsFast` + function in GeneNetwork1. + """ + return tuple(sorted( + samples_from_file.index(good) for good in + set(samples).intersection(set(samples_from_file)))) diff --git a/gn3/data_helpers.py b/gn3/data_helpers.py index f0d971e..d3f942b 100644 --- a/gn3/data_helpers.py +++ b/gn3/data_helpers.py @@ -5,7 +5,7 @@ data structures. from math import ceil from functools import reduce -from typing import Any, Tuple, Sequence +from typing import Any, Tuple, Sequence, Optional def partition_all(num: int, items: Sequence[Any]) -> Tuple[Tuple[Any, ...], ...]: """ @@ -23,3 +23,15 @@ def partition_all(num: int, items: Sequence[Any]) -> Tuple[Tuple[Any, ...], ...] tuple(items[start:stop]) for start, stop # type: ignore[has-type] in reduce( __compute_start_stop__, iterations, tuple())]) + +def parse_csv_line( + line: str, delimiter: str = ",", + quoting: Optional[str] = '"') -> Tuple[str, ...]: + """ + Parses a line from a CSV file into a tuple of strings. + + This is a migration of the `web.webqtl.utility.webqtlUtil.readLineCSV` + function in GeneNetwork1. + """ + return tuple( + col.strip("{} \t\n".format(quoting)) for col in line.split(delimiter)) diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py index 87ab082..06b3310 100644 --- a/gn3/db/correlations.py +++ b/gn3/db/correlations.py @@ -10,6 +10,8 @@ from gn3.random import random_string from gn3.data_helpers import partition_all from gn3.db.species import translate_to_mouse_gene_id +from gn3.computations.partial_correlations import correlations_of_all_tissue_traits + def get_filename(target_db_name: str, conn: Any) -> str: """ Retrieve the name of the reference database file with which correlations are @@ -140,22 +142,6 @@ def fetch_literature_correlations( cursor.execute("DROP TEMPORARY TABLE %s", temp_table) return dict(results) -def compare_tissue_correlation_absolute_values(val1, val2): - """ - Comparison function for use when sorting tissue correlation values. - - This is a partial migration of the - `web.webqtl.correlation.CorrelationPage.getTempTissueCorrTable` function in - GeneNetwork1.""" - try: - if abs(val1) < abs(val2): - return 1 - if abs(val1) == abs(val2): - return 0 - return -1 - except TypeError: - return 0 - def fetch_symbol_value_pair_dict( symbol_list: Tuple[str, ...], data_id_dict: dict, conn: Any) -> Dict[str, Tuple[float, ...]]: @@ -265,14 +251,21 @@ def fetch_tissue_probeset_xref_info( results or tuple(), (tuple(), {}, {}, {}, {}, {}, {})) -def correlations_of_all_tissue_traits() -> Tuple[dict, dict]: +def fetch_gene_symbol_tissue_value_dict_for_trait( + gene_name_list: Tuple[str, ...], probeset_freeze_id: int, + conn: Any) -> dict: """ + Fetches a map of the gene symbols to the tissue values. + This is a migration of the - `web.webqtl.correlation.CorrelationPage.calculateCorrOfAllTissueTrait` + `web.webqtl.correlation.correlationFunction.getGeneSymbolTissueValueDictForTrait` function in GeneNetwork1. """ - raise Exception("Unimplemented!!!") - return ({}, {}) + xref_info = fetch_tissue_probeset_xref_info( + gene_name_list, probeset_freeze_id, conn) + if xref_info[0]: + return fetch_gene_symbol_tissue_value_dict(xref_info[0], xref_info[2], conn) + return {} def build_temporary_tissue_correlations_table( trait_symbol: str, probeset_freeze_id: int, method: str, @@ -283,10 +276,40 @@ def build_temporary_tissue_correlations_table( This is a migration of the `web.webqtl.correlation.CorrelationPage.getTempTissueCorrTable` function in GeneNetwork1.""" - raise Exception("Unimplemented!!!") - return "" + # We should probably pass the `correlations_of_all_tissue_traits` function + # as an argument to this function and get rid of the one call immediately + # following this comment. + symbol_corr_dict, symbol_p_value_dict = correlations_of_all_tissue_traits( + fetch_gene_symbol_tissue_value_dict_for_trait( + (trait_symbol,), probeset_freeze_id, conn), + fetch_gene_symbol_tissue_value_dict_for_trait( + tuple(), probeset_freeze_id, conn), + method) + + symbol_corr_list = sorted( + symbol_corr_dict.items(), key=lambda key_val: key_val[1]) + + temp_table_name = f"TOPTISSUE{random_string(8)}" + create_query = ( + "CREATE TEMPORARY TABLE {temp_table_name}" + "(Symbol varchar(100) PRIMARY KEY, Correlation float, PValue float)") + insert_query = ( + f"INSERT INTO {temp_table_name}(Symbol, Correlation, PValue) " + " VALUES (%(symbol)s, %(correlation)s, %(pvalue)s)") -def fetch_tissue_correlations( + with conn.cursor() as cursor: + cursor.execute(create_query) + cursor.execute( + insert_query, + tuple({ + "symbol": symbol, + "correlation": corr, + "pvalue": symbol_p_value_dict[symbol] + } for symbol, corr in symbol_corr_list[0: 2 * return_number])) + + return temp_table_name + +def fetch_tissue_correlations(# pylint: disable=R0913 dataset: dict, trait_symbol: str, probeset_freeze_id: int, method: str, return_number: int, conn: Any) -> dict: """ @@ -316,3 +339,43 @@ def fetch_tissue_correlations( return { trait_name: (tiss_corr, tiss_p_val) for trait_name, tiss_corr, tiss_p_val in results} + +def check_for_literature_info(conn: Any, geneid: int) -> bool: + """ + Checks the database to find out whether the trait with `geneid` has any + associated literature. + + This is a migration of the + `web.webqtl.correlation.CorrelationPage.checkForLitInfo` function in + GeneNetwork1. + """ + query = "SELECT 1 FROM LCorrRamin3 WHERE GeneId1=%s LIMIT 1" + with conn.cursor() as cursor: + cursor.execute(query, geneid) + result = cursor.fetchone() + if result: + return True + + return False + +def check_symbol_for_tissue_correlation( + conn: Any, tissue_probeset_freeze_id: int, symbol: str = "") -> bool: + """ + Checks whether a symbol has any associated tissue correlations. + + This is a migration of the + `web.webqtl.correlation.CorrelationPage.checkSymbolForTissueCorr` function + in GeneNetwork1. + """ + query = ( + "SELECT 1 FROM TissueProbeSetXRef " + "WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s " + "AND Symbol=%(symbol)s LIMIT 1") + with conn.cursor() as cursor: + cursor.execute( + query, probeset_freeze_id=tissue_probeset_freeze_id, symbol=symbol) + result = cursor.fetchone() + if result: + return True + + return False diff --git a/gn3/db/species.py b/gn3/db/species.py index 1e5015f..702a9a8 100644 --- a/gn3/db/species.py +++ b/gn3/db/species.py @@ -47,17 +47,13 @@ def translate_to_mouse_gene_id(species: str, geneid: int, conn: Any) -> int: return geneid with conn.cursor as cursor: - if species == "rat": - cursor.execute( - "SELECT mouse FROM GeneIDXRef WHERE rat = %s", geneid) - rat_geneid = cursor.fetchone() - if rat_geneid: - return rat_geneid[0] - - cursor.execute( - "SELECT mouse FROM GeneIDXRef WHERE human = %s", geneid) - human_geneid = cursor.fetchone() - if human_geneid: - return human_geneid[0] + query = { + "rat": "SELECT mouse FROM GeneIDXRef WHERE rat = %s", + "human": "SELECT mouse FROM GeneIDXRef WHERE human = %s" + } + cursor.execute(query[species], geneid) + translated_gene_id = cursor.fetchone() + if translated_gene_id: + return translated_gene_id[0] return 0 # default if all else fails diff --git a/tests/unit/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py index 60e54c1..f7217a9 100644 --- a/tests/unit/test_partial_correlations.py +++ b/tests/unit/computations/test_partial_correlations.py @@ -1,11 +1,13 @@ """Module contains tests for gn3.partial_correlations""" from unittest import TestCase -from gn3.partial_correlations import ( +from gn3.computations.partial_correlations import ( fix_samples, control_samples, dictify_by_samples, - find_identical_traits) + tissue_correlation, + find_identical_traits, + good_dataset_samples_indexes) sampleslist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"] control_traits = ( @@ -209,3 +211,63 @@ class TestPartialCorrelations(TestCase): control_names=contn, control_values=contv): self.assertEqual( find_identical_traits(primn, primv, contn, contv), expected) + + def test_tissue_correlation_error(self): + """ + Test that `tissue_correlation` raises specific exceptions for particular + error conditions. + """ + for primary, target, method, error, error_msg in ( + ((1, 2, 3), (4, 5, 6, 7), "pearson", + AssertionError, + ( + "The lengths of the `primary_trait_values` and " + "`target_trait_values` must be equal")), + ((1, 2, 3), (4, 5, 6, 7), "spearman", + AssertionError, + ( + "The lengths of the `primary_trait_values` and " + "`target_trait_values` must be equal")), + ((1, 2, 3, 4), (5, 6, 7), "pearson", + AssertionError, + ( + "The lengths of the `primary_trait_values` and " + "`target_trait_values` must be equal")), + ((1, 2, 3, 4), (5, 6, 7), "spearman", + AssertionError, + ( + "The lengths of the `primary_trait_values` and " + "`target_trait_values` must be equal")), + ((1, 2, 3), (4, 5, 6), "nonexistentmethod", + AssertionError, + ( + "Method must be one of: pearson, spearman"))): + with self.subTest(primary=primary, target=target, method=method): + with self.assertRaises(error, msg=error_msg): + tissue_correlation(primary, target, method) + + def test_tissue_correlation(self): + """ + Test that the correct correlation values are computed for the given: + - primary trait + - target trait + - method + """ + for primary, target, method, expected in ( + ((12.34, 18.36, 42.51), (37.25, 46.25, 46.56), "pearson", + (0.6761779253, 0.5272701134)), + ((1, 2, 3, 4, 5), (5, 6, 7, 8, 7), "spearman", + (0.8207826817, 0.0885870053))): + with self.subTest(primary=primary, target=target, method=method): + self.assertEqual( + tissue_correlation(primary, target, method), expected) + + def test_good_dataset_samples_indexes(self): + """ + Test that `good_dataset_samples_indexes` returns correct indices. + """ + self.assertEqual( + good_dataset_samples_indexes( + ("a", "e", "i", "k"), + ("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l")), + (0, 4, 8, 10)) diff --git a/tests/unit/test_data_helpers.py b/tests/unit/test_data_helpers.py index 1eec3cc..39aea45 100644 --- a/tests/unit/test_data_helpers.py +++ b/tests/unit/test_data_helpers.py @@ -4,7 +4,7 @@ Test functions in gn3.data_helpers from unittest import TestCase -from gn3.data_helpers import partition_all +from gn3.data_helpers import partition_all, parse_csv_line class TestDataHelpers(TestCase): """ @@ -35,3 +35,27 @@ class TestDataHelpers(TestCase): ((0, 1, 2, 3, 4, 5, 6, 7, 8, 9), ))): with self.subTest(n=count, items=items): self.assertEqual(partition_all(count, items), expected) + + def test_parse_csv_line(self): + """ + Test parsing a single line from a CSV file + + Given: + - `line`: a line read from a csv file + - `delimiter`: the expected delimiter in the csv file + - `quoting`: the quoting enclosing each column in the csv file + When: + - `line` is parsed with the `parse_csv_file` with the given + parameters + Then: + - return a tuple of the columns in the CSV file, without the + delimiter and quoting + """ + for line, delimiter, quoting, expected in ( + ('"this","is","a","test"', ",", '"', ("this", "is", "a", "test")), + ('"this","is","a","test"', ",", None, ('"this"', '"is"', '"a"', '"test"'))): + with self.subTest(line=line, delimiter=delimiter, quoting=quoting): + self.assertEqual( + parse_csv_line( + line=line, delimiter=delimiter, quoting=quoting), + expected) |