about summary refs log tree commit diff
path: root/gn3/db
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/db')
-rw-r--r--gn3/db/correlations.py211
-rw-r--r--gn3/db/species.py17
-rw-r--r--gn3/db/traits.py47
3 files changed, 242 insertions, 33 deletions
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
index 06b3310..3d12019 100644
--- a/gn3/db/correlations.py
+++ b/gn3/db/correlations.py
@@ -2,17 +2,16 @@
 This module will hold functions that are used in the (partial) correlations
 feature to access the database to retrieve data needed for computations.
 """
-
+import os
 from functools import reduce
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Tuple, Union
 
 from gn3.random import random_string
 from gn3.data_helpers import partition_all
 from gn3.db.species import translate_to_mouse_gene_id
 
-from gn3.computations.partial_correlations import correlations_of_all_tissue_traits
-
-def get_filename(target_db_name: str, conn: Any) -> str:
+def get_filename(conn: Any, target_db_name: str, text_files_dir: str) -> Union[
+        str, bool]:
     """
     Retrieve the name of the reference database file with which correlations are
     computed.
@@ -23,18 +22,20 @@ def get_filename(target_db_name: str, conn: Any) -> str:
     """
     with conn.cursor() as cursor:
         cursor.execute(
-            "SELECT Id, FullName from ProbeSetFreeze WHERE Name-%s",
-            target_db_name)
+            "SELECT Id, FullName from ProbeSetFreeze WHERE Name=%s",
+            (target_db_name,))
         result = cursor.fetchone()
         if result:
-            return "ProbeSetFreezeId_{tid}_FullName_{fname}.txt".format(
+            filename = "ProbeSetFreezeId_{tid}_FullName_{fname}.txt".format(
                 tid=result[0],
                 fname=result[1].replace(' ', '_').replace('/', '_'))
+            return ((filename in os.listdir(text_files_dir))
+                    and f"{text_files_dir}/{filename}")
 
-    return ""
+    return False
 
 def build_temporary_literature_table(
-        species: str, gene_id: int, return_number: int, conn: Any) -> str:
+        conn: Any, species: str, gene_id: int, return_number: int) -> str:
     """
     Build and populate a temporary table to hold the literature correlation data
     to be used in computations.
@@ -128,7 +129,7 @@ def fetch_literature_correlations(
     GeneNetwork1.
     """
     temp_table = build_temporary_literature_table(
-        species, gene_id, return_number, conn)
+        conn, species, gene_id, return_number)
     query_fns = {
         "Geno": fetch_geno_literature_correlations,
         # "Temp": fetch_temp_literature_correlations,
@@ -268,8 +269,8 @@ def fetch_gene_symbol_tissue_value_dict_for_trait(
     return {}
 
 def build_temporary_tissue_correlations_table(
-        trait_symbol: str, probeset_freeze_id: int, method: str,
-        return_number: int, conn: Any) -> str:
+        conn: Any, trait_symbol: str, probeset_freeze_id: int, method: str,
+        return_number: int) -> str:
     """
     Build a temporary table to hold the tissue correlations data.
 
@@ -279,6 +280,16 @@ def build_temporary_tissue_correlations_table(
     # We should probably pass the `correlations_of_all_tissue_traits` function
     # as an argument to this function and get rid of the one call immediately
     # following this comment.
+    from gn3.computations.partial_correlations import (#pylint: disable=[C0415, R0401]
+        correlations_of_all_tissue_traits)
+    # This import above is necessary within the function to avoid
+    # circular-imports.
+    #
+    #
+    # This import above is indicative of convoluted code, with the computation
+    # being interwoven with the data retrieval. This needs to be changed, such
+    # that the function being imported here is no longer necessary, or have the
+    # imported function passed to this function as an argument.
     symbol_corr_dict, symbol_p_value_dict = correlations_of_all_tissue_traits(
         fetch_gene_symbol_tissue_value_dict_for_trait(
             (trait_symbol,), probeset_freeze_id, conn),
@@ -320,7 +331,7 @@ def fetch_tissue_correlations(# pylint: disable=R0913
     GeneNetwork1.
     """
     temp_table = build_temporary_tissue_correlations_table(
-        trait_symbol, probeset_freeze_id, method, return_number, conn)
+        conn, trait_symbol, probeset_freeze_id, method, return_number)
     with conn.cursor() as cursor:
         cursor.execute(
             (
@@ -379,3 +390,175 @@ def check_symbol_for_tissue_correlation(
             return True
 
     return False
+
+def fetch_sample_ids(
+        conn: Any, sample_names: Tuple[str, ...], species_name: str) -> Tuple[
+            int, ...]:
+    """
+    Given a sequence of sample names, and a species name, return the sample ids
+    that correspond to both.
+
+    This is a partial migration of the
+    `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+    GeneNetwork1.
+    """
+    query = (
+        "SELECT Strain.Id FROM Strain, Species "
+        "WHERE Strain.Name IN %(samples_names)s "
+        "AND Strain.SpeciesId=Species.Id "
+        "AND Species.name=%(species_name)s")
+    with conn.cursor() as cursor:
+        cursor.execute(
+            query,
+            {
+                "samples_names": tuple(sample_names),
+                "species_name": species_name
+            })
+        return tuple(row[0] for row in cursor.fetchall())
+
+def build_query_sgo_lit_corr(
+        db_type: str, temp_table: str, sample_id_columns: str,
+        joins: Tuple[str, ...]) -> str:
+    """
+    Build query for `SGO Literature Correlation` data, when querying the given
+    `temp_table` temporary table.
+
+    This is a partial migration of the
+    `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+    GeneNetwork1.
+    """
+    return (
+        (f"SELECT {db_type}.Name, {temp_table}.value, " +
+         sample_id_columns +
+         f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+         f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " +
+         " ".join(joins) +
+         " WHERE ProbeSet.GeneId IS NOT NULL " +
+         f"AND {temp_table}.value IS NOT NULL " +
+         f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+         f"AND {db_type}Freeze.Name = %(db_name)s " +
+         f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " +
+         f"ORDER BY {db_type}.Id"),
+        2)
+
+def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins):
+    """
+    Build query for `Tissue Correlation` data, when querying the given
+    `temp_table` temporary table.
+
+    This is a partial migration of the
+    `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+    GeneNetwork1.
+    """
+    return (
+        (f"SELECT {db_type}.Name, {temp_table}.Correlation, " +
+         f"{temp_table}.PValue, " +
+         sample_id_columns +
+         f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+         f"LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol " +
+         " ".join(joins) +
+         " WHERE ProbeSet.Symbol IS NOT NULL " +
+         f"AND {temp_table}.Correlation IS NOT NULL " +
+         f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+         f"AND {db_type}Freeze.Name = %(db_name)s " +
+         f"AND {db_type}.Id = {db_type}XRef.{db_type}Id "
+         f"ORDER BY {db_type}.Id"),
+        3)
+
+def fetch_all_database_data(# pylint: disable=[R0913, R0914]
+        conn: Any, species: str, gene_id: int, trait_symbol: str,
+        samples: Tuple[str, ...], dataset: dict, method: str,
+        return_number: int, probeset_freeze_id: int) -> Tuple[
+            Tuple[float], int]:
+    """
+    This is a migration of the
+    `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+    GeneNetwork1.
+    """
+    db_type = dataset["dataset_type"]
+    db_name = dataset["dataset_name"]
+    def __build_query__(sample_ids, temp_table):
+        sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
+        if db_type == "Publish":
+            joins = tuple(
+                ("LEFT JOIN PublishData AS T{item} "
+                 "ON T{item}.Id = PublishXRef.DataId "
+                 "AND T{item}.StrainId = %(T{item}_sample_id)s")
+                for item in sample_ids)
+            return (
+                ("SELECT PublishXRef.Id, " +
+                 sample_id_columns +
+                 "FROM (PublishXRef, PublishFreeze) " +
+                 " ".join(joins) +
+                 " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
+                 "AND PublishFreeze.Name = %(db_name)s"),
+                1)
+        if temp_table is not None:
+            joins = tuple(
+                (f"LEFT JOIN {db_type}Data AS T{item} "
+                 f"ON T{item}.Id = {db_type}XRef.DataId "
+                 f"AND T{item}.StrainId=%(T{item}_sample_id)s")
+                for item in sample_ids)
+            if method.lower() == "sgo literature correlation":
+                return build_query_sgo_lit_corr(
+                    sample_ids, temp_table, sample_id_columns, joins)
+            if method.lower() in (
+                    "tissue correlation, pearson's r",
+                    "tissue correlation, spearman's rho"):
+                return build_query_tissue_corr(
+                    sample_ids, temp_table, sample_id_columns, joins)
+        joins = tuple(
+            (f"LEFT JOIN {db_type}Data AS T{item} "
+             f"ON T{item}.Id = {db_type}XRef.DataId "
+             f"AND T{item}.StrainId = %(T{item}_sample_id)s")
+            for item in sample_ids)
+        return (
+            (
+                f"SELECT {db_type}.Name, " +
+                sample_id_columns +
+                f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+                " ".join(joins) +
+                f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+                f"AND {db_type}Freeze.Name = %(db_name)s " +
+                f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " +
+                f"ORDER BY {db_type}.Id"),
+            1)
+
+    def __fetch_data__(sample_ids, temp_table):
+        query, data_start_pos = __build_query__(sample_ids, temp_table)
+        with conn.cursor() as cursor:
+            cursor.execute(
+                query,
+                {"db_name": db_name,
+                 **{f"T{item}_sample_id": item for item in sample_ids}})
+            return (cursor.fetchall(), data_start_pos)
+
+    sample_ids = tuple(
+        # look into graduating this to an argument and removing the `samples`
+        # and `species` argument: function currying and compositions might help
+        # with this
+        f"{sample_id}" for sample_id in
+        fetch_sample_ids(conn, samples, species))
+
+    temp_table = None
+    if gene_id and db_type == "probeset":
+        if method.lower() == "sgo literature correlation":
+            temp_table = build_temporary_literature_table(
+                conn, species, gene_id, return_number)
+        if method.lower() in (
+                "tissue correlation, pearson's r",
+                "tissue correlation, spearman's rho"):
+            temp_table = build_temporary_tissue_correlations_table(
+                conn, trait_symbol, probeset_freeze_id, method, return_number)
+
+    trait_database = tuple(
+        item for sublist in
+        (__fetch_data__(ssample_ids, temp_table)
+         for ssample_ids in partition_all(25, sample_ids))
+        for item in sublist)
+
+    if temp_table:
+        with conn.cursor() as cursor:
+            cursor.execute(f"DROP TEMPORARY TABLE {temp_table}")
+
+    return (trait_database[0], trait_database[1])
diff --git a/gn3/db/species.py b/gn3/db/species.py
index 702a9a8..5b8e096 100644
--- a/gn3/db/species.py
+++ b/gn3/db/species.py
@@ -57,3 +57,20 @@ def translate_to_mouse_gene_id(species: str, geneid: int, conn: Any) -> int:
             return translated_gene_id[0]
 
     return 0 # default if all else fails
+
+def species_name(conn: Any, group: str) -> str:
+    """
+    Retrieve the name of the species, given the group (RISet).
+
+    This is a migration of the
+    `web.webqtl.dbFunction.webqtlDatabaseFunction.retrieveSpecies` function in
+    GeneNetwork1.
+    """
+    with conn.cursor() as cursor:
+        cursor.execute(
+            ("SELECT Species.Name FROM Species, InbredSet "
+             "WHERE InbredSet.Name = %(group_name)s "
+             "AND InbredSet.SpeciesId = Species.Id"),
+            {"group_name": group})
+        return cursor.fetchone()[0]
+    return None
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index 75de4f4..4098b08 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -4,6 +4,8 @@ import MySQLdb
 from functools import reduce
 from typing import Any, Dict, Union, Sequence
 
+import MySQLdb
+
 from gn3.settings import TMPDIR
 from gn3.random import random_string
 from gn3.function_helpers import compose
@@ -81,10 +83,10 @@ def export_trait_data(
 def get_trait_csv_sample_data(conn: Any,
                               trait_name: int, phenotype_id: int):
     """Fetch a trait and return it as a csv string"""
-    def __float_strip(n):
-        if str(n)[-2:] == ".0":
-            return str(int(n))
-        return str(n)
+    def __float_strip(num_str):
+        if str(num_str)[-2:] == ".0":
+            return str(int(num_str))
+        return str(num_str)
     sql = ("SELECT DISTINCT Strain.Name, PublishData.value, "
            "PublishSE.error, NStrain.count FROM "
            "(PublishData, Strain, PublishXRef, PublishFreeze) "
@@ -108,7 +110,8 @@ def get_trait_csv_sample_data(conn: Any,
     return "\n".join(csv_data)
 
 
-def update_sample_data(conn: Any,
+def update_sample_data(conn: Any, #pylint: disable=[R0913]
+
                        trait_name: str,
                        strain_name: str,
                        phenotype_id: int,
@@ -219,7 +222,7 @@ def delete_sample_data(conn: Any,
                             "WHERE StrainId = %s AND DataId = %s" %
                             (strain_id, data_id)))
             deleted_n_strains = cursor.rowcount
-        except Exception as e:
+        except Exception as e: #pylint: disable=[C0103, W0612]
             conn.rollback()
             raise MySQLdb.Error
         conn.commit()
@@ -230,7 +233,7 @@ def delete_sample_data(conn: Any,
             deleted_se_data, deleted_n_strains)
 
 
-def insert_sample_data(conn: Any,
+def insert_sample_data(conn: Any, #pylint: disable=[R0913]
                        trait_name: str,
                        strain_name: str,
                        phenotype_id: int,
@@ -272,7 +275,7 @@ def insert_sample_data(conn: Any,
                                 "VALUES (%s, %s, %s)") %
                                (strain_id, data_id, count))
             inserted_n_strains = cursor.rowcount
-        except Exception as e:
+        except Exception as e: #pylint: disable=[C0103, W0612]
             conn.rollback()
             raise MySQLdb.Error
     return (inserted_published_data,
@@ -450,7 +453,7 @@ def set_homologene_id_field(trait_type, trait_info, conn):
     Common postprocessing function for all trait types.
 
     Sets the value for the 'homologene' key."""
-    def set_to_null(ti): return {**ti, "homologeneid": None}
+    def set_to_null(ti): return {**ti, "homologeneid": None} # pylint: disable=[C0103, C0321]
     functions_table = {
         "Temp": set_to_null,
         "Geno": set_to_null,
@@ -656,8 +659,9 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any):
             query,
             {"trait_name": trait_info["trait_name"]})
         return [dict(zip(
-            ["sample_name", "value", "se_error", "nstrain", "id"], row))
-            for row in cursor.fetchall()]
+            ["sample_name", "value", "se_error", "nstrain", "id"],
+            row))
+                for row in cursor.fetchall()]
     return []
 
 
@@ -696,8 +700,10 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
              "dataset_name": trait_info["db"]["dataset_name"],
              "species_id": retrieve_species_id(
                  trait_info["db"]["group"], conn)})
-        return [dict(zip(
-            ["sample_name", "value", "se_error", "id"], row))
+        return [
+            dict(zip(
+                ["sample_name", "value", "se_error", "id"],
+                row))
             for row in cursor.fetchall()]
     return []
 
@@ -728,8 +734,9 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
             query,
             {"trait_name": trait_info["trait_name"],
              "dataset_id": trait_info["db"]["dataset_id"]})
-        return [dict(zip(
-            ["sample_name", "value", "se_error", "nstrain", "id"], row))
+        return [
+            dict(zip(
+                ["sample_name", "value", "se_error", "nstrain", "id"], row))
             for row in cursor.fetchall()]
     return []
 
@@ -762,8 +769,9 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
             {"cellid": trait_info["cellid"],
              "trait_name": trait_info["trait_name"],
              "dataset_id": trait_info["db"]["dataset_id"]})
-        return [dict(zip(
-            ["sample_name", "value", "se_error", "id"], row))
+        return [
+            dict(zip(
+                ["sample_name", "value", "se_error", "id"], row))
             for row in cursor.fetchall()]
     return []
 
@@ -792,8 +800,9 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
             query,
             {"trait_name": trait_info["trait_name"],
              "dataset_name": trait_info["db"]["dataset_name"]})
-        return [dict(zip(
-            ["sample_name", "value", "se_error", "id"], row))
+        return [
+            dict(zip(
+                ["sample_name", "value", "se_error", "id"], row))
             for row in cursor.fetchall()]
     return []