about summary refs log tree commit diff
path: root/gn3/db/traits.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/db/traits.py')
-rw-r--r--gn3/db/traits.py361
1 files changed, 308 insertions, 53 deletions
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index f2673c8..4098b08 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -1,17 +1,93 @@
 """This class contains functions relating to trait data manipulation"""
 import os
+import MySQLdb
+from functools import reduce
 from typing import Any, Dict, Union, Sequence
+
+import MySQLdb
+
 from gn3.settings import TMPDIR
 from gn3.random import random_string
 from gn3.function_helpers import compose
 from gn3.db.datasets import retrieve_trait_dataset
 
 
+def export_trait_data(
+        trait_data: dict, samplelist: Sequence[str], dtype: str = "val",
+        var_exists: bool = False, n_exists: bool = False):
+    """
+    Export data according to `samplelist`. Mostly used in calculating
+    correlations.
+
+    DESCRIPTION:
+    Migrated from
+    https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211
+
+    PARAMETERS
+    trait: (dict)
+      The dictionary of key-value pairs representing a trait
+    samplelist: (list)
+      A list of sample names
+    dtype: (str)
+      ... verify what this is ...
+    var_exists: (bool)
+      A flag indicating existence of variance
+    n_exists: (bool)
+      A flag indicating existence of ndata
+    """
+    def __export_all_types(tdata, sample):
+        sample_data = []
+        if tdata[sample]["value"]:
+            sample_data.append(tdata[sample]["value"])
+            if var_exists:
+                if tdata[sample]["variance"]:
+                    sample_data.append(tdata[sample]["variance"])
+                else:
+                    sample_data.append(None)
+            if n_exists:
+                if tdata[sample]["ndata"]:
+                    sample_data.append(tdata[sample]["ndata"])
+                else:
+                    sample_data.append(None)
+        else:
+            if var_exists and n_exists:
+                sample_data += [None, None, None]
+            elif var_exists or n_exists:
+                sample_data += [None, None]
+            else:
+                sample_data.append(None)
+
+        return tuple(sample_data)
+
+    def __exporter(accumulator, sample):
+        # pylint: disable=[R0911]
+        if sample in trait_data["data"]:
+            if dtype == "val":
+                return accumulator + (trait_data["data"][sample]["value"], )
+            if dtype == "var":
+                return accumulator + (trait_data["data"][sample]["variance"], )
+            if dtype == "N":
+                return accumulator + (trait_data["data"][sample]["ndata"], )
+            if dtype == "all":
+                return accumulator + __export_all_types(trait_data["data"], sample)
+            raise KeyError("Type `%s` is incorrect" % dtype)
+        if var_exists and n_exists:
+            return accumulator + (None, None, None)
+        if var_exists or n_exists:
+            return accumulator + (None, None)
+        return accumulator + (None,)
+
+    return reduce(__exporter, samplelist, tuple())
+
+
 def get_trait_csv_sample_data(conn: Any,
                               trait_name: int, phenotype_id: int):
     """Fetch a trait and return it as a csv string"""
-    sql = ("SELECT DISTINCT Strain.Id, PublishData.Id, Strain.Name, "
-           "PublishData.value, "
+    def __float_strip(num_str):
+        if str(num_str)[-2:] == ".0":
+            return str(int(num_str))
+        return str(num_str)
+    sql = ("SELECT DISTINCT Strain.Name, PublishData.value, "
            "PublishSE.error, NStrain.count FROM "
            "(PublishData, Strain, PublishXRef, PublishFreeze) "
            "LEFT JOIN PublishSE ON "
@@ -23,65 +99,189 @@ def get_trait_csv_sample_data(conn: Any,
            "PublishData.Id = PublishXRef.DataId AND "
            "PublishXRef.Id = %s AND PublishXRef.PhenotypeId = %s "
            "AND PublishData.StrainId = Strain.Id Order BY Strain.Name")
-    csv_data = ["Strain Id,Strain Name,Value,SE,Count"]
-    publishdata_id = ""
+    csv_data = ["Strain Name,Value,SE,Count"]
     with conn.cursor() as cursor:
         cursor.execute(sql, (trait_name, phenotype_id,))
         for record in cursor.fetchall():
-            (strain_id, publishdata_id,
-             strain_name, value, error, count) = record
+            (strain_name, value, error, count) = record
             csv_data.append(
-                ",".join([str(val) if val else "x"
-                          for val in (strain_id, strain_name,
-                                      value, error, count)]))
-    return f"# Publish Data Id: {publishdata_id}\n\n" + "\n".join(csv_data)
+                ",".join([__float_strip(val) if val else "x"
+                          for val in (strain_name, value, error, count)]))
+    return "\n".join(csv_data)
+
 
+def update_sample_data(conn: Any, #pylint: disable=[R0913]
 
-def update_sample_data(conn: Any,
+                       trait_name: str,
                        strain_name: str,
-                       strain_id: int,
-                       publish_data_id: int,
+                       phenotype_id: int,
                        value: Union[int, float, str],
                        error: Union[int, float, str],
                        count: Union[int, str]):
     """Given the right parameters, update sample-data from the relevant
     table."""
-    # pylint: disable=[R0913, R0914, C0103]
-    STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
-    PUBLISH_DATA_SQL: str = ("UPDATE PublishData SET value = %s "
-                             "WHERE StrainId = %s AND Id = %s")
-    PUBLISH_SE_SQL: str = ("UPDATE PublishSE SET error = %s "
-                           "WHERE StrainId = %s AND DataId = %s")
-    N_STRAIN_SQL: str = ("UPDATE NStrain SET count = %s "
-                         "WHERE StrainId = %s AND DataId = %s")
-
-    updated_strains: int = 0
+    strain_id, data_id = "", ""
+
+    with conn.cursor() as cursor:
+        cursor.execute(
+            ("SELECT Strain.Id, PublishData.Id FROM "
+             "(PublishData, Strain, PublishXRef, PublishFreeze) "
+             "LEFT JOIN PublishSE ON "
+             "(PublishSE.DataId = PublishData.Id AND "
+             "PublishSE.StrainId = PublishData.StrainId) "
+             "LEFT JOIN NStrain ON "
+             "(NStrain.DataId = PublishData.Id AND "
+             "NStrain.StrainId = PublishData.StrainId) "
+             "WHERE PublishXRef.InbredSetId = "
+             "PublishFreeze.InbredSetId AND "
+             "PublishData.Id = PublishXRef.DataId AND "
+             "PublishXRef.Id = %s AND "
+             "PublishXRef.PhenotypeId = %s "
+             "AND PublishData.StrainId = Strain.Id "
+             "AND Strain.Name = \"%s\"") % (trait_name,
+                                            phenotype_id,
+                                            str(strain_name)))
+        strain_id, data_id = cursor.fetchone()
     updated_published_data: int = 0
     updated_se_data: int = 0
     updated_n_strains: int = 0
 
     with conn.cursor() as cursor:
-        # Update the Strains table
-        cursor.execute(STRAIN_ID_SQL, (strain_name, strain_id))
-        updated_strains = cursor.rowcount
         # Update the PublishData table
-        cursor.execute(PUBLISH_DATA_SQL,
+        cursor.execute(("UPDATE PublishData SET value = %s "
+                        "WHERE StrainId = %s AND Id = %s"),
                        (None if value == "x" else value,
-                        strain_id, publish_data_id))
+                        strain_id, data_id))
         updated_published_data = cursor.rowcount
+
         # Update the PublishSE table
-        cursor.execute(PUBLISH_SE_SQL,
+        cursor.execute(("UPDATE PublishSE SET error = %s "
+                        "WHERE StrainId = %s AND DataId = %s"),
                        (None if error == "x" else error,
-                        strain_id, publish_data_id))
+                        strain_id, data_id))
         updated_se_data = cursor.rowcount
+
         # Update the NStrain table
-        cursor.execute(N_STRAIN_SQL,
+        cursor.execute(("UPDATE NStrain SET count = %s "
+                        "WHERE StrainId = %s AND DataId = %s"),
                        (None if count == "x" else count,
-                        strain_id, publish_data_id))
+                        strain_id, data_id))
         updated_n_strains = cursor.rowcount
-    return (updated_strains, updated_published_data,
+    return (updated_published_data,
             updated_se_data, updated_n_strains)
 
+
+def delete_sample_data(conn: Any,
+                       trait_name: str,
+                       strain_name: str,
+                       phenotype_id: int):
+    """Given the right parameters, delete sample-data from the relevant
+    table."""
+    strain_id, data_id = "", ""
+
+    deleted_published_data: int = 0
+    deleted_se_data: int = 0
+    deleted_n_strains: int = 0
+
+    with conn.cursor() as cursor:
+        # Delete the PublishData table
+        try:
+            cursor.execute(
+                ("SELECT Strain.Id, PublishData.Id FROM "
+                 "(PublishData, Strain, PublishXRef, PublishFreeze) "
+                 "LEFT JOIN PublishSE ON "
+                 "(PublishSE.DataId = PublishData.Id AND "
+                 "PublishSE.StrainId = PublishData.StrainId) "
+                 "LEFT JOIN NStrain ON "
+                 "(NStrain.DataId = PublishData.Id AND "
+                 "NStrain.StrainId = PublishData.StrainId) "
+                 "WHERE PublishXRef.InbredSetId = "
+                 "PublishFreeze.InbredSetId AND "
+                 "PublishData.Id = PublishXRef.DataId AND "
+                 "PublishXRef.Id = %s AND "
+                 "PublishXRef.PhenotypeId = %s "
+                 "AND PublishData.StrainId = Strain.Id "
+                 "AND Strain.Name = \"%s\"") % (trait_name,
+                                                phenotype_id,
+                                                str(strain_name)))
+            strain_id, data_id = cursor.fetchone()
+
+            cursor.execute(("DELETE FROM PublishData "
+                            "WHERE StrainId = %s AND Id = %s")
+                           % (strain_id, data_id))
+            deleted_published_data = cursor.rowcount
+
+            # Delete the PublishSE table
+            cursor.execute(("DELETE FROM PublishSE "
+                            "WHERE StrainId = %s AND DataId = %s") %
+                           (strain_id, data_id))
+            deleted_se_data = cursor.rowcount
+
+            # Delete the NStrain table
+            cursor.execute(("DELETE FROM NStrain "
+                            "WHERE StrainId = %s AND DataId = %s" %
+                            (strain_id, data_id)))
+            deleted_n_strains = cursor.rowcount
+        except Exception as e: #pylint: disable=[C0103, W0612]
+            conn.rollback()
+            raise MySQLdb.Error
+        conn.commit()
+        cursor.close()
+        cursor.close()
+
+    return (deleted_published_data,
+            deleted_se_data, deleted_n_strains)
+
+
+def insert_sample_data(conn: Any, #pylint: disable=[R0913]
+                       trait_name: str,
+                       strain_name: str,
+                       phenotype_id: int,
+                       value: Union[int, float, str],
+                       error: Union[int, float, str],
+                       count: Union[int, str]):
+    """Given the right parameters, insert sample-data to the relevant table.
+
+    """
+
+    inserted_published_data, inserted_se_data, inserted_n_strains = 0, 0, 0
+    with conn.cursor() as cursor:
+        try:
+            cursor.execute("SELECT DataId FROM PublishXRef WHERE Id = %s AND "
+                           "PhenotypeId = %s", (trait_name, phenotype_id))
+            data_id = cursor.fetchone()
+
+            cursor.execute("SELECT Id FROM Strain WHERE Name = %s",
+                           (strain_name,))
+            strain_id = cursor.fetchone()
+
+            # Insert the PublishData table
+            cursor.execute(("INSERT INTO PublishData (Id, StrainId, value)"
+                            "VALUES (%s, %s, %s)"),
+                           (data_id, strain_id, value))
+            inserted_published_data = cursor.rowcount
+
+            # Insert into the PublishSE table if error is specified
+            if error and error != "x":
+                cursor.execute(("INSERT INTO PublishSE (StrainId, DataId, "
+                                " error) VALUES (%s, %s, %s)") %
+                               (strain_id, data_id, error))
+            inserted_se_data = cursor.rowcount
+
+            # Insert into the NStrain table
+            if count and count != "x":
+                cursor.execute(("INSERT INTO NStrain "
+                                "(StrainId, DataId, error) "
+                                "VALUES (%s, %s, %s)") %
+                               (strain_id, data_id, count))
+            inserted_n_strains = cursor.rowcount
+        except Exception as e: #pylint: disable=[C0103, W0612]
+            conn.rollback()
+            raise MySQLdb.Error
+    return (inserted_published_data,
+            inserted_se_data, inserted_n_strains)
+
+
 def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any):
     """Retrieve trait information for type `Publish` traits.
 
@@ -121,11 +321,12 @@ def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any):
         cursor.execute(
             query,
             {
-                k:v for k, v in trait_data_source.items()
+                k: v for k, v in trait_data_source.items()
                 if k in ["trait_name", "trait_dataset_id"]
             })
         return dict(zip([k.lower() for k in keys], cursor.fetchone()))
 
+
 def set_confidential_field(trait_type, trait_info):
     """Post processing function for 'Publish' trait types.
 
@@ -138,6 +339,7 @@ def set_confidential_field(trait_type, trait_info):
                 and not trait_info.get("pubmed_id", None)) else 0}
     return trait_info
 
+
 def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any):
     """Retrieve trait information for type `ProbeSet` traits.
 
@@ -165,11 +367,12 @@ def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any):
         cursor.execute(
             query,
             {
-                k:v for k, v in trait_data_source.items()
+                k: v for k, v in trait_data_source.items()
                 if k in ["trait_name", "trait_dataset_name"]
             })
         return dict(zip(keys, cursor.fetchone()))
 
+
 def retrieve_geno_trait_info(trait_data_source: Dict[str, Any], conn: Any):
     """Retrieve trait information for type `Geno` traits.
 
@@ -189,11 +392,12 @@ def retrieve_geno_trait_info(trait_data_source: Dict[str, Any], conn: Any):
         cursor.execute(
             query,
             {
-                k:v for k, v in trait_data_source.items()
+                k: v for k, v in trait_data_source.items()
                 if k in ["trait_name", "trait_dataset_name"]
             })
         return dict(zip(keys, cursor.fetchone()))
 
+
 def retrieve_temp_trait_info(trait_data_source: Dict[str, Any], conn: Any):
     """Retrieve trait information for type `Temp` traits.
 
@@ -206,11 +410,12 @@ def retrieve_temp_trait_info(trait_data_source: Dict[str, Any], conn: Any):
         cursor.execute(
             query,
             {
-                k:v for k, v in trait_data_source.items()
+                k: v for k, v in trait_data_source.items()
                 if k in ["trait_name"]
             })
         return dict(zip(keys, cursor.fetchone()))
 
+
 def set_haveinfo_field(trait_info):
     """
     Common postprocessing function for all trait types.
@@ -218,6 +423,7 @@ def set_haveinfo_field(trait_info):
     Sets the value for the 'haveinfo' field."""
     return {**trait_info, "haveinfo": 1 if trait_info else 0}
 
+
 def set_homologene_id_field_probeset(trait_info, conn):
     """
     Postprocessing function for 'ProbeSet' traits.
@@ -233,7 +439,7 @@ def set_homologene_id_field_probeset(trait_info, conn):
         cursor.execute(
             query,
             {
-                k:v for k, v in trait_info.items()
+                k: v for k, v in trait_info.items()
                 if k in ["geneid", "group"]
             })
         res = cursor.fetchone()
@@ -241,12 +447,13 @@ def set_homologene_id_field_probeset(trait_info, conn):
             return {**trait_info, "homologeneid": res[0]}
     return {**trait_info, "homologeneid": None}
 
+
 def set_homologene_id_field(trait_type, trait_info, conn):
     """
     Common postprocessing function for all trait types.
 
     Sets the value for the 'homologene' key."""
-    set_to_null = lambda ti: {**ti, "homologeneid": None}
+    def set_to_null(ti): return {**ti, "homologeneid": None} # pylint: disable=[C0103, C0321]
     functions_table = {
         "Temp": set_to_null,
         "Geno": set_to_null,
@@ -255,6 +462,7 @@ def set_homologene_id_field(trait_type, trait_info, conn):
     }
     return functions_table[trait_type](trait_info)
 
+
 def load_publish_qtl_info(trait_info, conn):
     """
     Load extra QTL information for `Publish` traits
@@ -275,6 +483,7 @@ def load_publish_qtl_info(trait_info, conn):
         return dict(zip(["locus", "lrs", "additive"], cursor.fetchone()))
     return {"locus": "", "lrs": "", "additive": ""}
 
+
 def load_probeset_qtl_info(trait_info, conn):
     """
     Load extra QTL information for `ProbeSet` traits
@@ -297,6 +506,7 @@ def load_probeset_qtl_info(trait_info, conn):
             ["locus", "lrs", "pvalue", "mean", "additive"], cursor.fetchone()))
     return {"locus": "", "lrs": "", "pvalue": "", "mean": "", "additive": ""}
 
+
 def load_qtl_info(qtl, trait_type, trait_info, conn):
     """
     Load extra QTL information for traits
@@ -325,6 +535,7 @@ def load_qtl_info(qtl, trait_type, trait_info, conn):
 
     return qtl_info_functions[trait_type](trait_info, conn)
 
+
 def build_trait_name(trait_fullname):
     """
     Initialises the trait's name, and other values from the search data provided
@@ -351,6 +562,7 @@ def build_trait_name(trait_fullname):
         "cellid": name_parts[2] if len(name_parts) == 3 else ""
     }
 
+
 def retrieve_probeset_sequence(trait, conn):
     """
     Retrieve a 'ProbeSet' trait's sequence information
@@ -372,6 +584,7 @@ def retrieve_probeset_sequence(trait, conn):
         seq = cursor.fetchone()
         return {**trait, "sequence": seq[0] if seq else ""}
 
+
 def retrieve_trait_info(
         threshold: int, trait_full_name: str, conn: Any,
         qtl=None):
@@ -427,6 +640,7 @@ def retrieve_trait_info(
         }
     return trait_info
 
+
 def retrieve_temp_trait_data(trait_info: dict, conn: Any):
     """
     Retrieve trait data for `Temp` traits.
@@ -445,10 +659,12 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any):
             query,
             {"trait_name": trait_info["trait_name"]})
         return [dict(zip(
-            ["sample_name", "value", "se_error", "nstrain", "id"], row))
+            ["sample_name", "value", "se_error", "nstrain", "id"],
+            row))
                 for row in cursor.fetchall()]
     return []
 
+
 def retrieve_species_id(group, conn: Any):
     """
     Retrieve a species id given the Group value
@@ -460,6 +676,7 @@ def retrieve_species_id(group, conn: Any):
         return cursor.fetchone()[0]
     return None
 
+
 def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
     """
     Retrieve trait data for `Geno` traits.
@@ -483,11 +700,14 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
              "dataset_name": trait_info["db"]["dataset_name"],
              "species_id": retrieve_species_id(
                  trait_info["db"]["group"], conn)})
-        return [dict(zip(
-            ["sample_name", "value", "se_error", "id"], row))
-                for row in cursor.fetchall()]
+        return [
+            dict(zip(
+                ["sample_name", "value", "se_error", "id"],
+                row))
+            for row in cursor.fetchall()]
     return []
 
+
 def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
     """
     Retrieve trait data for `Publish` traits.
@@ -514,11 +734,13 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
             query,
             {"trait_name": trait_info["trait_name"],
              "dataset_id": trait_info["db"]["dataset_id"]})
-        return [dict(zip(
-            ["sample_name", "value", "se_error", "nstrain", "id"], row))
-                for row in cursor.fetchall()]
+        return [
+            dict(zip(
+                ["sample_name", "value", "se_error", "nstrain", "id"], row))
+            for row in cursor.fetchall()]
     return []
 
+
 def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
     """
     Retrieve trait data for `Probe Data` types.
@@ -547,11 +769,13 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
             {"cellid": trait_info["cellid"],
              "trait_name": trait_info["trait_name"],
              "dataset_id": trait_info["db"]["dataset_id"]})
-        return [dict(zip(
-            ["sample_name", "value", "se_error", "id"], row))
-                for row in cursor.fetchall()]
+        return [
+            dict(zip(
+                ["sample_name", "value", "se_error", "id"], row))
+            for row in cursor.fetchall()]
     return []
 
+
 def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
     """
     Retrieve trait data for `ProbeSet` traits.
@@ -576,11 +800,13 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
             query,
             {"trait_name": trait_info["trait_name"],
              "dataset_name": trait_info["db"]["dataset_name"]})
-        return [dict(zip(
-            ["sample_name", "value", "se_error", "id"], row))
-                for row in cursor.fetchall()]
+        return [
+            dict(zip(
+                ["sample_name", "value", "se_error", "id"], row))
+            for row in cursor.fetchall()]
     return []
 
+
 def with_samplelist_data_setup(samplelist: Sequence[str]):
     """
     Build function that computes the trait data from provided list of samples.
@@ -607,6 +833,7 @@ def with_samplelist_data_setup(samplelist: Sequence[str]):
         return None
     return setup_fn
 
+
 def without_samplelist_data_setup():
     """
     Build function that computes the trait data.
@@ -627,6 +854,7 @@ def without_samplelist_data_setup():
         return None
     return setup_fn
 
+
 def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tuple()):
     """
     Retrieve trait data
@@ -666,11 +894,38 @@ def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tupl
             "data": dict(map(
                 lambda x: (
                     x["sample_name"],
-                    {k:v for k, v in x.items() if x != "sample_name"}),
+                    {k: v for k, v in x.items() if x != "sample_name"}),
                 data))}
     return {}
 
+
 def generate_traits_filename(base_path: str = TMPDIR):
     """Generate a unique filename for use with generated traits files."""
     return "{}/traits_test_file_{}.txt".format(
         os.path.abspath(base_path), random_string(10))
+
+
+def export_informative(trait_data: dict, inc_var: bool = False) -> tuple:
+    """
+    Export informative strain
+
+    This is a migration of the `exportInformative` function in
+    web/webqtl/base/webqtlTrait.py module in GeneNetwork1.
+
+    There is a chance that the original implementation has a bug, especially
+    dealing with the `inc_var` value. It the `inc_var` value is meant to control
+    the inclusion of the `variance` value, then the current implementation, and
+    that one in GN1 have a bug.
+    """
+    def __exporter__(acc, data_item):
+        if not inc_var or data_item["variance"] is not None:
+            return (
+                acc[0] + (data_item["sample_name"],),
+                acc[1] + (data_item["value"],),
+                acc[2] + (data_item["variance"],))
+        return acc
+    return reduce(
+        __exporter__,
+        filter(lambda td: td["value"] is not None,
+               trait_data["data"].values()),
+        (tuple(), tuple(), tuple()))