about summary refs log tree commit diff
diff options
context:
space:
mode:
authorBonfaceKilz2021-10-26 08:48:30 +0300
committerGitHub2021-10-26 08:48:30 +0300
commit0cfff99e22155b6b15e23cbeff596f5f8f08709c (patch)
treec831f28534ebf6432972e107dbd6da5daef81088
parent5440bfcd6940db08c4479a39ba66dbc802b2c426 (diff)
parentc13afb3af166d2b01e4f9fd9b09bb231f0a63cb1 (diff)
downloadgenenetwork3-0cfff99e22155b6b15e23cbeff596f5f8f08709c.tar.gz
Merge pull request #46 from genenetwork/partial-correlations
Partial correlations
-rw-r--r--gn3/data_helpers.py25
-rw-r--r--gn3/db/correlations.py318
-rw-r--r--gn3/db/species.py31
-rw-r--r--gn3/partial_correlations.py38
-rw-r--r--tests/unit/test_data_helpers.py37
-rw-r--r--tests/unit/test_partial_correlations.py64
6 files changed, 510 insertions, 3 deletions
diff --git a/gn3/data_helpers.py b/gn3/data_helpers.py
new file mode 100644
index 0000000..f0d971e
--- /dev/null
+++ b/gn3/data_helpers.py
@@ -0,0 +1,25 @@
+"""
+This module will hold generic functions that can operate on a wide-array of
+data structures.
+"""
+
+from math import ceil
+from functools import reduce
+from typing import Any, Tuple, Sequence
+
+def partition_all(num: int, items: Sequence[Any]) -> Tuple[Tuple[Any, ...], ...]:
+    """
+    Given a sequence `items`, return a new sequence of the same type as `items`
+    with the data partitioned into sections of `n` items per partition.
+
+    This is an approximation of clojure's `partition-all` function.
+    """
+    def __compute_start_stop__(acc, iteration):
+        start = iteration * num
+        return acc + ((start, start + num),)
+
+    iterations = range(ceil(len(items) / num))
+    return tuple([# type: ignore[misc]
+        tuple(items[start:stop]) for start, stop # type: ignore[has-type]
+        in reduce(
+            __compute_start_stop__, iterations, tuple())])
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
new file mode 100644
index 0000000..87ab082
--- /dev/null
+++ b/gn3/db/correlations.py
@@ -0,0 +1,318 @@
+"""
+This module will hold functions that are used in the (partial) correlations
+feature to access the database to retrieve data needed for computations.
+"""
+
+from functools import reduce
+from typing import Any, Dict, Tuple
+
+from gn3.random import random_string
+from gn3.data_helpers import partition_all
+from gn3.db.species import translate_to_mouse_gene_id
+
+def get_filename(target_db_name: str, conn: Any) -> str:
+    """
+    Retrieve the name of the reference database file with which correlations are
+    computed.
+
+    This is a migration of the
+    `web.webqtl.correlation.CorrelationPage.getFileName` function in
+    GeneNetwork1.
+    """
+    with conn.cursor() as cursor:
+        cursor.execute(
+            "SELECT Id, FullName from ProbeSetFreeze WHERE Name-%s",
+            target_db_name)
+        result = cursor.fetchone()
+        if result:
+            return "ProbeSetFreezeId_{tid}_FullName_{fname}.txt".format(
+                tid=result[0],
+                fname=result[1].replace(' ', '_').replace('/', '_'))
+
+    return ""
+
+def build_temporary_literature_table(
+        species: str, gene_id: int, return_number: int, conn: Any) -> str:
+    """
+    Build and populate a temporary table to hold the literature correlation data
+    to be used in computations.
+
+    "This is a migration of the
+    `web.webqtl.correlation.CorrelationPage.getTempLiteratureTable` function in
+    GeneNetwork1.
+    """
+    def __translated_species_id(row, cursor):
+        if species == "mouse":
+            return row[1]
+        query = {
+            "rat": "SELECT rat FROM GeneIDXRef WHERE mouse=%s",
+            "human": "SELECT human FROM GeneIDXRef WHERE mouse=%d"}
+        if species in query.keys():
+            cursor.execute(query[species], row[1])
+            record = cursor.fetchone()
+            if record:
+                return record[0]
+            return None
+        return None
+
+    temp_table_name = f"TOPLITERATURE{random_string(8)}"
+    with conn.cursor as cursor:
+        mouse_geneid = translate_to_mouse_gene_id(species, gene_id, conn)
+        data_query = (
+            "SELECT GeneId1, GeneId2, value FROM LCorrRamin3 "
+            "WHERE GeneId1 = %(mouse_gene_id)s "
+            "UNION ALL "
+            "SELECT GeneId2, GeneId1, value FROM LCorrRamin3 "
+            "WHERE GeneId2 = %(mouse_gene_id)s "
+            "AND GeneId1 != %(mouse_gene_id)s")
+        cursor.execute(
+            (f"CREATE TEMPORARY TABLE {temp_table_name} ("
+             "GeneId1 int(12) unsigned, "
+             "GeneId2 int(12) unsigned PRIMARY KEY, "
+             "value double)"))
+        cursor.execute(data_query, mouse_gene_id=mouse_geneid)
+        literature_data = [
+            {"GeneId1": row[0], "GeneId2": row[1], "value": row[2]}
+            for row in cursor.fetchall()
+            if __translated_species_id(row, cursor)]
+
+        cursor.execute(
+            (f"INSERT INTO {temp_table_name} "
+             "VALUES (%(GeneId1)s, %(GeneId2)s, %(value)s)"),
+            literature_data[0:(2 * return_number)])
+
+    return temp_table_name
+
+def fetch_geno_literature_correlations(temp_table: str) -> str:
+    """
+    Helper function for `fetch_literature_correlations` below, to build query
+    for `Geno*` tables.
+    """
+    return (
+        f"SELECT Geno.Name, {temp_table}.value "
+        "FROM Geno, GenoXRef, GenoFreeze "
+        f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId "
+        "WHERE ProbeSet.GeneId IS NOT NULL "
+        f"AND {temp_table}.value IS NOT NULL "
+        "AND GenoXRef.GenoFreezeId = GenoFreeze.Id "
+        "AND GenoFreeze.Name = %(db_name)s "
+        "AND Geno.Id=GenoXRef.GenoId "
+        "ORDER BY Geno.Id")
+
+def fetch_probeset_literature_correlations(temp_table: str) -> str:
+    """
+    Helper function for `fetch_literature_correlations` below, to build query
+    for `ProbeSet*` tables.
+    """
+    return (
+        f"SELECT ProbeSet.Name, {temp_table}.value "
+        "FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze "
+        "LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId "
+        "WHERE ProbeSet.GeneId IS NOT NULL "
+        "AND {temp_table}.value IS NOT NULL "
+        "AND ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id "
+        "AND ProbeSetFreeze.Name = %(db_name)s "
+        "AND ProbeSet.Id=ProbeSetXRef.ProbeSetId "
+        "ORDER BY ProbeSet.Id")
+
+def fetch_literature_correlations(
+        species: str, gene_id: int, dataset: dict, return_number: int,
+        conn: Any) -> dict:
+    """
+    Gather the literature correlation data and pair it with trait id string(s).
+
+    This is a migration of the
+    `web.webqtl.correlation.CorrelationPage.fetchLitCorrelations` function in
+    GeneNetwork1.
+    """
+    temp_table = build_temporary_literature_table(
+        species, gene_id, return_number, conn)
+    query_fns = {
+        "Geno": fetch_geno_literature_correlations,
+        # "Temp": fetch_temp_literature_correlations,
+        # "Publish": fetch_publish_literature_correlations,
+        "ProbeSet": fetch_probeset_literature_correlations}
+    with conn.cursor as cursor:
+        cursor.execute(
+            query_fns[dataset["dataset_type"]](temp_table),
+            db_name=dataset["dataset_name"])
+        results = cursor.fetchall()
+        cursor.execute("DROP TEMPORARY TABLE %s", temp_table)
+        return dict(results)
+
+def compare_tissue_correlation_absolute_values(val1, val2):
+    """
+    Comparison function for use when sorting tissue correlation values.
+
+    This is a partial migration of the
+    `web.webqtl.correlation.CorrelationPage.getTempTissueCorrTable` function in
+    GeneNetwork1."""
+    try:
+        if abs(val1) < abs(val2):
+            return 1
+        if abs(val1) == abs(val2):
+            return 0
+        return -1
+    except TypeError:
+        return 0
+
+def fetch_symbol_value_pair_dict(
+        symbol_list: Tuple[str, ...], data_id_dict: dict,
+        conn: Any) -> Dict[str, Tuple[float, ...]]:
+    """
+    Map each gene symbols to the corresponding tissue expression data.
+
+    This is a migration of the
+    `web.webqtl.correlation.correlationFunction.getSymbolValuePairDict` function
+    in GeneNetwork1.
+    """
+    data_ids = {
+        symbol: data_id_dict.get(symbol) for symbol in symbol_list
+        if data_id_dict.get(symbol) is not None
+    }
+    query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN %(data_ids)s"
+    with conn.cursor() as cursor:
+        cursor.execute(
+            query,
+            data_ids=tuple(data_ids.values()))
+        value_results = cursor.fetchall()
+        return {
+            key: tuple(row[1] for row in value_results if row[0] == key)
+            for key in data_ids.keys()
+        }
+
+    return {}
+
+def fetch_gene_symbol_tissue_value_dict(
+        symbol_list: Tuple[str, ...], data_id_dict: dict, conn: Any,
+        limit_num: int = 1000) -> dict:#getGeneSymbolTissueValueDict
+    """
+    Wrapper function for `gn3.db.correlations.fetch_symbol_value_pair_dict`.
+
+    This is a migrations of the
+    `web.webqtl.correlation.correlationFunction.getGeneSymbolTissueValueDict` in
+    GeneNetwork1.
+    """
+    count = len(symbol_list)
+    if count != 0 and count <= limit_num:
+        return fetch_symbol_value_pair_dict(symbol_list, data_id_dict, conn)
+
+    if count > limit_num:
+        return {
+            key: value for dct in [
+                fetch_symbol_value_pair_dict(sl, data_id_dict, conn)
+                for sl in partition_all(limit_num, symbol_list)]
+            for key, value in dct.items()
+        }
+
+    return {}
+
+def fetch_tissue_probeset_xref_info(
+        gene_name_list: Tuple[str, ...], probeset_freeze_id: int,
+        conn: Any) -> Tuple[tuple, dict, dict, dict, dict, dict, dict]:
+    """
+    Retrieve the ProbeSet XRef information for tissues.
+
+    This is a migration of the
+    `web.webqtl.correlation.correlationFunction.getTissueProbeSetXRefInfo`
+    function in GeneNetwork1."""
+    with conn.cursor() as cursor:
+        if len(gene_name_list) == 0:
+            query = (
+                "SELECT t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, "
+                "t.description, t.Probe_Target_Description "
+                "FROM "
+                "("
+                "  SELECT Symbol, max(Mean) AS maxmean "
+                "  FROM TissueProbeSetXRef "
+                "  WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s "
+                "  AND Symbol != '' "
+                "  AND Symbol IS NOT NULL "
+                "  GROUP BY Symbol"
+                ") AS x "
+                "INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol "
+                "AND t.Mean = x.maxmean")
+            cursor.execute(query, probeset_freeze_id=probeset_freeze_id)
+        else:
+            query = (
+                "SELECT t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, "
+                "t.description, t.Probe_Target_Description "
+                "FROM "
+                "("
+                "  SELECT Symbol, max(Mean) AS maxmean "
+                "  FROM TissueProbeSetXRef "
+                "  WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s "
+                "  AND Symbol in %(symbols)s "
+                "  GROUP BY Symbol"
+                ") AS x "
+                "INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol "
+                "AND t.Mean = x.maxmean")
+            cursor.execute(
+                query, probeset_freeze_id=probeset_freeze_id,
+                symbols=tuple(gene_name_list))
+
+        results = cursor.fetchall()
+
+    return reduce(
+        lambda acc, item: (
+            acc[0] + (item[0],),
+            {**acc[1], item[0].lower(): item[1]},
+            {**acc[1], item[0].lower(): item[2]},
+            {**acc[1], item[0].lower(): item[3]},
+            {**acc[1], item[0].lower(): item[4]},
+            {**acc[1], item[0].lower(): item[5]},
+            {**acc[1], item[0].lower(): item[6]}),
+        results or tuple(),
+        (tuple(), {}, {}, {}, {}, {}, {}))
+
+def correlations_of_all_tissue_traits() -> Tuple[dict, dict]:
+    """
+    This is a migration of the
+    `web.webqtl.correlation.CorrelationPage.calculateCorrOfAllTissueTrait`
+    function in GeneNetwork1.
+    """
+    raise Exception("Unimplemented!!!")
+    return ({}, {})
+
+def build_temporary_tissue_correlations_table(
+        trait_symbol: str, probeset_freeze_id: int, method: str,
+        return_number: int, conn: Any) -> str:
+    """
+    Build a temporary table to hold the tissue correlations data.
+
+    This is a migration of the
+    `web.webqtl.correlation.CorrelationPage.getTempTissueCorrTable` function in
+    GeneNetwork1."""
+    raise Exception("Unimplemented!!!")
+    return ""
+
+def fetch_tissue_correlations(
+        dataset: dict, trait_symbol: str, probeset_freeze_id: int, method: str,
+        return_number: int, conn: Any) -> dict:
+    """
+    Pair tissue correlations data with a trait id string.
+
+    This is a migration of the
+    `web.webqtl.correlation.CorrelationPage.fetchTissueCorrelations` function in
+    GeneNetwork1.
+    """
+    temp_table = build_temporary_tissue_correlations_table(
+        trait_symbol, probeset_freeze_id, method, return_number, conn)
+    with conn.cursor() as cursor:
+        cursor.execute(
+            (
+                f"SELECT ProbeSet.Name, {temp_table}.Correlation, "
+                f"{temp_table}.PValue "
+                "FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) "
+                "LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol "
+                "WHERE ProbeSetFreeze.Name = %(db_name) "
+                "AND ProbeSetFreeze.Id=ProbeSetXRef.ProbeSetFreezeId "
+                "AND ProbeSet.Id = ProbeSetXRef.ProbeSetId "
+                "AND ProbeSet.Symbol IS NOT NULL "
+                "AND %s.Correlation IS NOT NULL"),
+            db_name=dataset["dataset_name"])
+        results = cursor.fetchall()
+        cursor.execute("DROP TEMPORARY TABLE %s", temp_table)
+        return {
+            trait_name: (tiss_corr, tiss_p_val)
+            for trait_name, tiss_corr, tiss_p_val in results}
diff --git a/gn3/db/species.py b/gn3/db/species.py
index 0deae4e..1e5015f 100644
--- a/gn3/db/species.py
+++ b/gn3/db/species.py
@@ -30,3 +30,34 @@ def get_chromosome(name: str, is_species: bool, conn: Any) -> Optional[Tuple]:
     with conn.cursor() as cursor:
         cursor.execute(_sql)
         return cursor.fetchall()
+
+def translate_to_mouse_gene_id(species: str, geneid: int, conn: Any) -> int:
+    """
+    Translate rat or human geneid to mouse geneid
+
+    This is a migration of the
+    `web.webqtl.correlation/CorrelationPage.translateToMouseGeneID` function in
+    GN1
+    """
+    assert species in ("rat", "mouse", "human"), "Invalid species"
+    if geneid is None:
+        return 0
+
+    if species == "mouse":
+        return geneid
+
+    with conn.cursor as cursor:
+        if species == "rat":
+            cursor.execute(
+                "SELECT mouse FROM GeneIDXRef WHERE rat = %s", geneid)
+            rat_geneid = cursor.fetchone()
+            if rat_geneid:
+                return rat_geneid[0]
+
+        cursor.execute(
+            "SELECT mouse FROM GeneIDXRef WHERE human = %s", geneid)
+        human_geneid = cursor.fetchone()
+        if human_geneid:
+            return human_geneid[0]
+
+    return 0 # default if all else fails
diff --git a/gn3/partial_correlations.py b/gn3/partial_correlations.py
index c556d10..1fb0ccc 100644
--- a/gn3/partial_correlations.py
+++ b/gn3/partial_correlations.py
@@ -6,7 +6,7 @@ GeneNetwork1.
 """
 
 from functools import reduce
-from typing import Any, Sequence
+from typing import Any, Tuple, Sequence
 
 def control_samples(controls: Sequence[dict], sampleslist: Sequence[str]):
     """
@@ -86,3 +86,39 @@ def fix_samples(primary_trait: dict, control_traits: Sequence[dict]) -> Sequence
         control_vals_vars[0],
         tuple(primary_trait[sample]["variance"] for sample in primary_samples),
         control_vals_vars[1])
+
+def find_identical_traits(
+        primary_name: str, primary_value: float, control_names: Tuple[str, ...],
+        control_values: Tuple[float, ...]) -> Tuple[str, ...]:
+    """
+    Find traits that have the same value when the values are considered to
+    3 decimal places.
+
+    This is a migration of the
+    `web.webqtl.correlation.correlationFunction.findIdenticalTraits` function in
+    GN1.
+    """
+    def __merge_identicals__(
+            acc: Tuple[str, ...],
+            ident: Tuple[str, Tuple[str, ...]]) -> Tuple[str, ...]:
+        return acc + ident[1]
+
+    def __dictify_controls__(acc, control_item):
+        ckey = "{:.3f}".format(control_item[0])
+        return {**acc, ckey: acc.get(ckey, tuple()) + (control_item[1],)}
+
+    return (reduce(## for identical control traits
+        __merge_identicals__,
+        (item for item in reduce(# type: ignore[var-annotated]
+            __dictify_controls__, zip(control_values, control_names),
+            {}).items() if len(item[1]) > 1),
+        tuple())
+            or
+            reduce(## If no identical control traits, try primary and controls
+                __merge_identicals__,
+                (item for item in reduce(# type: ignore[var-annotated]
+                    __dictify_controls__,
+                    zip((primary_value,) + control_values,
+                        (primary_name,) + control_names), {}).items()
+                 if len(item[1]) > 1),
+                tuple()))
diff --git a/tests/unit/test_data_helpers.py b/tests/unit/test_data_helpers.py
new file mode 100644
index 0000000..1eec3cc
--- /dev/null
+++ b/tests/unit/test_data_helpers.py
@@ -0,0 +1,37 @@
+"""
+Test functions in gn3.data_helpers
+"""
+
+from unittest import TestCase
+
+from gn3.data_helpers import partition_all
+
+class TestDataHelpers(TestCase):
+    """
+    Test functions in gn3.data_helpers
+    """
+
+    def test_partition_all(self):
+        """
+        Test that `gn3.data_helpers.partition_all` partitions sequences as expected.
+
+        Given:
+            - `num`: The number of items per partition
+            - `items`: A sequence of items
+        When:
+            - The arguments above are passed to the `gn3.data_helpers.partition_all`
+        Then:
+            - Return a new sequence with partitions, each of which has `num`
+              items in the same order as those in `items`, save for the last
+              partition which might have fewer items than `num`.
+        """
+        for count, items, expected in (
+                (1, [0, 1, 2, 3], ((0,), (1,), (2,), (3,))),
+                (3, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
+                 ((0, 1, 2), (3, 4, 5), (6, 7, 8), (9, ))),
+                (4, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+                 ((0, 1, 2, 3), (4, 5, 6, 7), (8, 9))),
+                (13, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+                 ((0, 1, 2, 3, 4, 5, 6, 7, 8, 9), ))):
+            with self.subTest(n=count, items=items):
+                self.assertEqual(partition_all(count, items), expected)
diff --git a/tests/unit/test_partial_correlations.py b/tests/unit/test_partial_correlations.py
index 7631a71..60e54c1 100644
--- a/tests/unit/test_partial_correlations.py
+++ b/tests/unit/test_partial_correlations.py
@@ -4,7 +4,8 @@ from unittest import TestCase
 from gn3.partial_correlations import (
     fix_samples,
     control_samples,
-    dictify_by_samples)
+    dictify_by_samples,
+    find_identical_traits)
 
 sampleslist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"]
 control_traits = (
@@ -106,6 +107,8 @@ class TestPartialCorrelations(TestCase):
 
     def test_dictify_by_samples(self):
         """
+        Test that `dictify_by_samples` generates the appropriate dict
+
         Given:
             a sequence of sequences with sample names, values and variances, as
             in the output of `gn3.partial_correlations.control_samples` or
@@ -133,7 +136,34 @@ class TestPartialCorrelations(TestCase):
             dictified_control_samples)
 
     def test_fix_samples(self):
-        """Test that fix_samples fixes the values"""
+        """
+        Test that `fix_samples` returns only the common samples
+
+        Given:
+            - A primary trait
+            - A sequence of control samples
+        When:
+            - The two arguments are passed to `fix_samples`
+        Then:
+            - Only the names of the samples present in the primary trait that
+              are also present in ALL the control traits are present in the
+              return value
+            - Only the values of the samples present in the primary trait that
+              are also present in ALL the control traits are present in the
+              return value
+            - ALL the values for ALL the control traits are present in the
+              return value
+            - Only the variances of the samples present in the primary trait
+              that are also present in ALL the control traits are present in the
+              return value
+            - ALL the variances for ALL the control traits are present in the
+              return value
+            - The return value is a tuple of the above items, in the following
+              order:
+                ((sample_names, ...), (primary_trait_values, ...),
+                 (control_traits_values, ...), (primary_trait_variances, ...)
+                 (control_traits_variances, ...))
+        """
         self.assertEqual(
             fix_samples(
                 {"B6cC3-1": {"sample_name": "B6cC3-1", "value": 7.51879,
@@ -149,3 +179,33 @@ class TestPartialCorrelations(TestCase):
              (None,),
              (None, None, None, None, None, None, None, None, None, None, None,
               None, None)))
+
+    def test_find_identical_traits(self):
+        """
+        Test `gn3.partial_correlations.find_identical_traits`.
+
+        Given:
+            - the name of a primary trait
+            - the value of a primary trait
+            - a sequence of names of control traits
+            - a sequence of values of control traits
+        When:
+            - the arguments above are passed to the `find_identical_traits`
+              function
+        Then:
+            - Return ALL trait names that have the same value when up to three
+              decimal places are considered
+        """
+        for primn, primv, contn, contv, expected in (
+                ("pt", 12.98395, ("ct0", "ct1", "ct2"),
+                 (0.1234, 2.3456, 3.4567), tuple()),
+                ("pt", 12.98395, ("ct0", "ct1", "ct2"),
+                 (12.98354, 2.3456, 3.4567), ("pt", "ct0")),
+                ("pt", 12.98395, ("ct0", "ct1", "ct2", "ct3"),
+                 (0.1234, 2.3456, 0.1233, 4.5678), ("ct0", "ct2"))
+        ):
+            with self.subTest(
+                    primary_name=primn, primary_value=primv,
+                    control_names=contn, control_values=contv):
+                self.assertEqual(
+                    find_identical_traits(primn, primv, contn, contv), expected)