about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/db/correlations.py183
1 files changed, 181 insertions, 2 deletions
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
index 67cfef9..87ab082 100644
--- a/gn3/db/correlations.py
+++ b/gn3/db/correlations.py
@@ -3,9 +3,11 @@ This module will hold functions that are used in the (partial) correlations
 feature to access the database to retrieve data needed for computations.
 """
 
-from typing import Any
+from functools import reduce
+from typing import Any, Dict, Tuple
 
 from gn3.random import random_string
+from gn3.data_helpers import partition_all
 from gn3.db.species import translate_to_mouse_gene_id
 
 def get_filename(target_db_name: str, conn: Any) -> str:
@@ -136,4 +138,181 @@ def fetch_literature_correlations(
             db_name=dataset["dataset_name"])
         results = cursor.fetchall()
         cursor.execute("DROP TEMPORARY TABLE %s", temp_table)
-        return dict(results) # {trait_name: lit_corr for trait_name, lit_corr in results}
+        return dict(results)
+
+def compare_tissue_correlation_absolute_values(val1, val2):
+    """
+    Comparison function for use when sorting tissue correlation values.
+
+    This is a partial migration of the
+    `web.webqtl.correlation.CorrelationPage.getTempTissueCorrTable` function in
+    GeneNetwork1."""
+    try:
+        if abs(val1) < abs(val2):
+            return 1
+        if abs(val1) == abs(val2):
+            return 0
+        return -1
+    except TypeError:
+        return 0
+
+def fetch_symbol_value_pair_dict(
+        symbol_list: Tuple[str, ...], data_id_dict: dict,
+        conn: Any) -> Dict[str, Tuple[float, ...]]:
+    """
+    Map each gene symbols to the corresponding tissue expression data.
+
+    This is a migration of the
+    `web.webqtl.correlation.correlationFunction.getSymbolValuePairDict` function
+    in GeneNetwork1.
+    """
+    data_ids = {
+        symbol: data_id_dict.get(symbol) for symbol in symbol_list
+        if data_id_dict.get(symbol) is not None
+    }
+    query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN %(data_ids)s"
+    with conn.cursor() as cursor:
+        cursor.execute(
+            query,
+            data_ids=tuple(data_ids.values()))
+        value_results = cursor.fetchall()
+        return {
+            key: tuple(row[1] for row in value_results if row[0] == key)
+            for key in data_ids.keys()
+        }
+
+    return {}
+
+def fetch_gene_symbol_tissue_value_dict(
+        symbol_list: Tuple[str, ...], data_id_dict: dict, conn: Any,
+        limit_num: int = 1000) -> dict:#getGeneSymbolTissueValueDict
+    """
+    Wrapper function for `gn3.db.correlations.fetch_symbol_value_pair_dict`.
+
+    This is a migrations of the
+    `web.webqtl.correlation.correlationFunction.getGeneSymbolTissueValueDict` in
+    GeneNetwork1.
+    """
+    count = len(symbol_list)
+    if count != 0 and count <= limit_num:
+        return fetch_symbol_value_pair_dict(symbol_list, data_id_dict, conn)
+
+    if count > limit_num:
+        return {
+            key: value for dct in [
+                fetch_symbol_value_pair_dict(sl, data_id_dict, conn)
+                for sl in partition_all(limit_num, symbol_list)]
+            for key, value in dct.items()
+        }
+
+    return {}
+
+def fetch_tissue_probeset_xref_info(
+        gene_name_list: Tuple[str, ...], probeset_freeze_id: int,
+        conn: Any) -> Tuple[tuple, dict, dict, dict, dict, dict, dict]:
+    """
+    Retrieve the ProbeSet XRef information for tissues.
+
+    This is a migration of the
+    `web.webqtl.correlation.correlationFunction.getTissueProbeSetXRefInfo`
+    function in GeneNetwork1."""
+    with conn.cursor() as cursor:
+        if len(gene_name_list) == 0:
+            query = (
+                "SELECT t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, "
+                "t.description, t.Probe_Target_Description "
+                "FROM "
+                "("
+                "  SELECT Symbol, max(Mean) AS maxmean "
+                "  FROM TissueProbeSetXRef "
+                "  WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s "
+                "  AND Symbol != '' "
+                "  AND Symbol IS NOT NULL "
+                "  GROUP BY Symbol"
+                ") AS x "
+                "INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol "
+                "AND t.Mean = x.maxmean")
+            cursor.execute(query, probeset_freeze_id=probeset_freeze_id)
+        else:
+            query = (
+                "SELECT t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, "
+                "t.description, t.Probe_Target_Description "
+                "FROM "
+                "("
+                "  SELECT Symbol, max(Mean) AS maxmean "
+                "  FROM TissueProbeSetXRef "
+                "  WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s "
+                "  AND Symbol in %(symbols)s "
+                "  GROUP BY Symbol"
+                ") AS x "
+                "INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol "
+                "AND t.Mean = x.maxmean")
+            cursor.execute(
+                query, probeset_freeze_id=probeset_freeze_id,
+                symbols=tuple(gene_name_list))
+
+        results = cursor.fetchall()
+
+    return reduce(
+        lambda acc, item: (
+            acc[0] + (item[0],),
+            {**acc[1], item[0].lower(): item[1]},
+            {**acc[1], item[0].lower(): item[2]},
+            {**acc[1], item[0].lower(): item[3]},
+            {**acc[1], item[0].lower(): item[4]},
+            {**acc[1], item[0].lower(): item[5]},
+            {**acc[1], item[0].lower(): item[6]}),
+        results or tuple(),
+        (tuple(), {}, {}, {}, {}, {}, {}))
+
+def correlations_of_all_tissue_traits() -> Tuple[dict, dict]:
+    """
+    This is a migration of the
+    `web.webqtl.correlation.CorrelationPage.calculateCorrOfAllTissueTrait`
+    function in GeneNetwork1.
+    """
+    raise Exception("Unimplemented!!!")
+    return ({}, {})
+
+def build_temporary_tissue_correlations_table(
+        trait_symbol: str, probeset_freeze_id: int, method: str,
+        return_number: int, conn: Any) -> str:
+    """
+    Build a temporary table to hold the tissue correlations data.
+
+    This is a migration of the
+    `web.webqtl.correlation.CorrelationPage.getTempTissueCorrTable` function in
+    GeneNetwork1."""
+    raise Exception("Unimplemented!!!")
+    return ""
+
+def fetch_tissue_correlations(
+        dataset: dict, trait_symbol: str, probeset_freeze_id: int, method: str,
+        return_number: int, conn: Any) -> dict:
+    """
+    Pair tissue correlations data with a trait id string.
+
+    This is a migration of the
+    `web.webqtl.correlation.CorrelationPage.fetchTissueCorrelations` function in
+    GeneNetwork1.
+    """
+    temp_table = build_temporary_tissue_correlations_table(
+        trait_symbol, probeset_freeze_id, method, return_number, conn)
+    with conn.cursor() as cursor:
+        cursor.execute(
+            (
+                f"SELECT ProbeSet.Name, {temp_table}.Correlation, "
+                f"{temp_table}.PValue "
+                "FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) "
+                "LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol "
+                "WHERE ProbeSetFreeze.Name = %(db_name) "
+                "AND ProbeSetFreeze.Id=ProbeSetXRef.ProbeSetFreezeId "
+                "AND ProbeSet.Id = ProbeSetXRef.ProbeSetId "
+                "AND ProbeSet.Symbol IS NOT NULL "
+                "AND %s.Correlation IS NOT NULL"),
+            db_name=dataset["dataset_name"])
+        results = cursor.fetchall()
+        cursor.execute("DROP TEMPORARY TABLE %s", temp_table)
+        return {
+            trait_name: (tiss_corr, tiss_p_val)
+            for trait_name, tiss_corr, tiss_p_val in results}