about summary refs log tree commit diff
path: root/gn3/db
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/db')
-rw-r--r--gn3/db/traits.py238
1 files changed, 193 insertions, 45 deletions
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index ebb7e3c..75de4f4 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -1,5 +1,6 @@
 """This class contains functions relating to trait data manipulation"""
 import os
+import MySQLdb
 from functools import reduce
 from typing import Any, Dict, Union, Sequence
 
@@ -76,16 +77,15 @@ def export_trait_data(
 
     return reduce(__exporter, samplelist, tuple())
 
+
 def get_trait_csv_sample_data(conn: Any,
                               trait_name: int, phenotype_id: int):
     """Fetch a trait and return it as a csv string"""
-
     def __float_strip(n):
         if str(n)[-2:] == ".0":
             return str(int(n))
         return str(n)
-    sql = ("SELECT DISTINCT Strain.Id, PublishData.Id, Strain.Name, "
-           "PublishData.value, "
+    sql = ("SELECT DISTINCT Strain.Name, PublishData.value, "
            "PublishSE.error, NStrain.count FROM "
            "(PublishData, Strain, PublishXRef, PublishFreeze) "
            "LEFT JOIN PublishSE ON "
@@ -97,65 +97,188 @@ def get_trait_csv_sample_data(conn: Any,
            "PublishData.Id = PublishXRef.DataId AND "
            "PublishXRef.Id = %s AND PublishXRef.PhenotypeId = %s "
            "AND PublishData.StrainId = Strain.Id Order BY Strain.Name")
-    csv_data = ["Strain Id,Strain Name,Value,SE,Count"]
-    publishdata_id = ""
+    csv_data = ["Strain Name,Value,SE,Count"]
     with conn.cursor() as cursor:
         cursor.execute(sql, (trait_name, phenotype_id,))
         for record in cursor.fetchall():
-            (strain_id, publishdata_id,
-             strain_name, value, error, count) = record
+            (strain_name, value, error, count) = record
             csv_data.append(
                 ",".join([__float_strip(val) if val else "x"
-                          for val in (strain_id, strain_name,
-                                      value, error, count)]))
-    return f"# Publish Data Id: {publishdata_id}\n" + "\n".join(csv_data)
+                          for val in (strain_name, value, error, count)]))
+    return "\n".join(csv_data)
 
 
 def update_sample_data(conn: Any,
+                       trait_name: str,
                        strain_name: str,
-                       strain_id: int,
-                       publish_data_id: int,
+                       phenotype_id: int,
                        value: Union[int, float, str],
                        error: Union[int, float, str],
                        count: Union[int, str]):
     """Given the right parameters, update sample-data from the relevant
     table."""
-    # pylint: disable=[R0913, R0914, C0103]
-    STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
-    PUBLISH_DATA_SQL: str = ("UPDATE PublishData SET value = %s "
-                             "WHERE StrainId = %s AND Id = %s")
-    PUBLISH_SE_SQL: str = ("UPDATE PublishSE SET error = %s "
-                           "WHERE StrainId = %s AND DataId = %s")
-    N_STRAIN_SQL: str = ("UPDATE NStrain SET count = %s "
-                         "WHERE StrainId = %s AND DataId = %s")
-
-    updated_strains: int = 0
+    strain_id, data_id = "", ""
+
+    with conn.cursor() as cursor:
+        cursor.execute(
+            ("SELECT Strain.Id, PublishData.Id FROM "
+             "(PublishData, Strain, PublishXRef, PublishFreeze) "
+             "LEFT JOIN PublishSE ON "
+             "(PublishSE.DataId = PublishData.Id AND "
+             "PublishSE.StrainId = PublishData.StrainId) "
+             "LEFT JOIN NStrain ON "
+             "(NStrain.DataId = PublishData.Id AND "
+             "NStrain.StrainId = PublishData.StrainId) "
+             "WHERE PublishXRef.InbredSetId = "
+             "PublishFreeze.InbredSetId AND "
+             "PublishData.Id = PublishXRef.DataId AND "
+             "PublishXRef.Id = %s AND "
+             "PublishXRef.PhenotypeId = %s "
+             "AND PublishData.StrainId = Strain.Id "
+             "AND Strain.Name = \"%s\"") % (trait_name,
+                                            phenotype_id,
+                                            str(strain_name)))
+        strain_id, data_id = cursor.fetchone()
     updated_published_data: int = 0
     updated_se_data: int = 0
     updated_n_strains: int = 0
 
     with conn.cursor() as cursor:
-        # Update the Strains table
-        cursor.execute(STRAIN_ID_SQL, (strain_name, strain_id))
-        updated_strains = cursor.rowcount
         # Update the PublishData table
-        cursor.execute(PUBLISH_DATA_SQL,
+        cursor.execute(("UPDATE PublishData SET value = %s "
+                        "WHERE StrainId = %s AND Id = %s"),
                        (None if value == "x" else value,
-                        strain_id, publish_data_id))
+                        strain_id, data_id))
         updated_published_data = cursor.rowcount
+
         # Update the PublishSE table
-        cursor.execute(PUBLISH_SE_SQL,
+        cursor.execute(("UPDATE PublishSE SET error = %s "
+                        "WHERE StrainId = %s AND DataId = %s"),
                        (None if error == "x" else error,
-                        strain_id, publish_data_id))
+                        strain_id, data_id))
         updated_se_data = cursor.rowcount
+
         # Update the NStrain table
-        cursor.execute(N_STRAIN_SQL,
+        cursor.execute(("UPDATE NStrain SET count = %s "
+                        "WHERE StrainId = %s AND DataId = %s"),
                        (None if count == "x" else count,
-                        strain_id, publish_data_id))
+                        strain_id, data_id))
         updated_n_strains = cursor.rowcount
-    return (updated_strains, updated_published_data,
+    return (updated_published_data,
             updated_se_data, updated_n_strains)
 
+
+def delete_sample_data(conn: Any,
+                       trait_name: str,
+                       strain_name: str,
+                       phenotype_id: int):
+    """Given the right parameters, delete sample-data from the relevant
+    table."""
+    strain_id, data_id = "", ""
+
+    deleted_published_data: int = 0
+    deleted_se_data: int = 0
+    deleted_n_strains: int = 0
+
+    with conn.cursor() as cursor:
+        # Delete the PublishData table
+        try:
+            cursor.execute(
+                ("SELECT Strain.Id, PublishData.Id FROM "
+                 "(PublishData, Strain, PublishXRef, PublishFreeze) "
+                 "LEFT JOIN PublishSE ON "
+                 "(PublishSE.DataId = PublishData.Id AND "
+                 "PublishSE.StrainId = PublishData.StrainId) "
+                 "LEFT JOIN NStrain ON "
+                 "(NStrain.DataId = PublishData.Id AND "
+                 "NStrain.StrainId = PublishData.StrainId) "
+                 "WHERE PublishXRef.InbredSetId = "
+                 "PublishFreeze.InbredSetId AND "
+                 "PublishData.Id = PublishXRef.DataId AND "
+                 "PublishXRef.Id = %s AND "
+                 "PublishXRef.PhenotypeId = %s "
+                 "AND PublishData.StrainId = Strain.Id "
+                 "AND Strain.Name = \"%s\"") % (trait_name,
+                                                phenotype_id,
+                                                str(strain_name)))
+            strain_id, data_id = cursor.fetchone()
+
+            cursor.execute(("DELETE FROM PublishData "
+                            "WHERE StrainId = %s AND Id = %s")
+                           % (strain_id, data_id))
+            deleted_published_data = cursor.rowcount
+
+            # Delete the PublishSE table
+            cursor.execute(("DELETE FROM PublishSE "
+                            "WHERE StrainId = %s AND DataId = %s") %
+                           (strain_id, data_id))
+            deleted_se_data = cursor.rowcount
+
+            # Delete the NStrain table
+            cursor.execute(("DELETE FROM NStrain "
+                            "WHERE StrainId = %s AND DataId = %s" %
+                            (strain_id, data_id)))
+            deleted_n_strains = cursor.rowcount
+        except Exception as e:
+            conn.rollback()
+            raise MySQLdb.Error
+        conn.commit()
+        cursor.close()
+        cursor.close()
+
+    return (deleted_published_data,
+            deleted_se_data, deleted_n_strains)
+
+
+def insert_sample_data(conn: Any,
+                       trait_name: str,
+                       strain_name: str,
+                       phenotype_id: int,
+                       value: Union[int, float, str],
+                       error: Union[int, float, str],
+                       count: Union[int, str]):
+    """Given the right parameters, insert sample-data to the relevant table.
+
+    """
+
+    inserted_published_data, inserted_se_data, inserted_n_strains = 0, 0, 0
+    with conn.cursor() as cursor:
+        try:
+            cursor.execute("SELECT DataId FROM PublishXRef WHERE Id = %s AND "
+                           "PhenotypeId = %s", (trait_name, phenotype_id))
+            data_id = cursor.fetchone()
+
+            cursor.execute("SELECT Id FROM Strain WHERE Name = %s",
+                           (strain_name,))
+            strain_id = cursor.fetchone()
+
+            # Insert the PublishData table
+            cursor.execute(("INSERT INTO PublishData (Id, StrainId, value)"
+                            "VALUES (%s, %s, %s)"),
+                           (data_id, strain_id, value))
+            inserted_published_data = cursor.rowcount
+
+            # Insert into the PublishSE table if error is specified
+            if error and error != "x":
+                cursor.execute(("INSERT INTO PublishSE (StrainId, DataId, "
+                                " error) VALUES (%s, %s, %s)") %
+                               (strain_id, data_id, error))
+            inserted_se_data = cursor.rowcount
+
+            # Insert into the NStrain table
+            if count and count != "x":
+                cursor.execute(("INSERT INTO NStrain "
+                                "(StrainId, DataId, error) "
+                                "VALUES (%s, %s, %s)") %
+                               (strain_id, data_id, count))
+            inserted_n_strains = cursor.rowcount
+        except Exception as e:
+            conn.rollback()
+            raise MySQLdb.Error
+    return (inserted_published_data,
+            inserted_se_data, inserted_n_strains)
+
+
 def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any):
     """Retrieve trait information for type `Publish` traits.
 
@@ -195,11 +318,12 @@ def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any):
         cursor.execute(
             query,
             {
-                k:v for k, v in trait_data_source.items()
+                k: v for k, v in trait_data_source.items()
                 if k in ["trait_name", "trait_dataset_id"]
             })
         return dict(zip([k.lower() for k in keys], cursor.fetchone()))
 
+
 def set_confidential_field(trait_type, trait_info):
     """Post processing function for 'Publish' trait types.
 
@@ -212,6 +336,7 @@ def set_confidential_field(trait_type, trait_info):
                 and not trait_info.get("pubmed_id", None)) else 0}
     return trait_info
 
+
 def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any):
     """Retrieve trait information for type `ProbeSet` traits.
 
@@ -239,11 +364,12 @@ def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any):
         cursor.execute(
             query,
             {
-                k:v for k, v in trait_data_source.items()
+                k: v for k, v in trait_data_source.items()
                 if k in ["trait_name", "trait_dataset_name"]
             })
         return dict(zip(keys, cursor.fetchone()))
 
+
 def retrieve_geno_trait_info(trait_data_source: Dict[str, Any], conn: Any):
     """Retrieve trait information for type `Geno` traits.
 
@@ -263,11 +389,12 @@ def retrieve_geno_trait_info(trait_data_source: Dict[str, Any], conn: Any):
         cursor.execute(
             query,
             {
-                k:v for k, v in trait_data_source.items()
+                k: v for k, v in trait_data_source.items()
                 if k in ["trait_name", "trait_dataset_name"]
             })
         return dict(zip(keys, cursor.fetchone()))
 
+
 def retrieve_temp_trait_info(trait_data_source: Dict[str, Any], conn: Any):
     """Retrieve trait information for type `Temp` traits.
 
@@ -280,11 +407,12 @@ def retrieve_temp_trait_info(trait_data_source: Dict[str, Any], conn: Any):
         cursor.execute(
             query,
             {
-                k:v for k, v in trait_data_source.items()
+                k: v for k, v in trait_data_source.items()
                 if k in ["trait_name"]
             })
         return dict(zip(keys, cursor.fetchone()))
 
+
 def set_haveinfo_field(trait_info):
     """
     Common postprocessing function for all trait types.
@@ -292,6 +420,7 @@ def set_haveinfo_field(trait_info):
     Sets the value for the 'haveinfo' field."""
     return {**trait_info, "haveinfo": 1 if trait_info else 0}
 
+
 def set_homologene_id_field_probeset(trait_info, conn):
     """
     Postprocessing function for 'ProbeSet' traits.
@@ -307,7 +436,7 @@ def set_homologene_id_field_probeset(trait_info, conn):
         cursor.execute(
             query,
             {
-                k:v for k, v in trait_info.items()
+                k: v for k, v in trait_info.items()
                 if k in ["geneid", "group"]
             })
         res = cursor.fetchone()
@@ -315,12 +444,13 @@ def set_homologene_id_field_probeset(trait_info, conn):
             return {**trait_info, "homologeneid": res[0]}
     return {**trait_info, "homologeneid": None}
 
+
 def set_homologene_id_field(trait_type, trait_info, conn):
     """
     Common postprocessing function for all trait types.
 
     Sets the value for the 'homologene' key."""
-    set_to_null = lambda ti: {**ti, "homologeneid": None}
+    def set_to_null(ti): return {**ti, "homologeneid": None}
     functions_table = {
         "Temp": set_to_null,
         "Geno": set_to_null,
@@ -329,6 +459,7 @@ def set_homologene_id_field(trait_type, trait_info, conn):
     }
     return functions_table[trait_type](trait_info)
 
+
 def load_publish_qtl_info(trait_info, conn):
     """
     Load extra QTL information for `Publish` traits
@@ -349,6 +480,7 @@ def load_publish_qtl_info(trait_info, conn):
         return dict(zip(["locus", "lrs", "additive"], cursor.fetchone()))
     return {"locus": "", "lrs": "", "additive": ""}
 
+
 def load_probeset_qtl_info(trait_info, conn):
     """
     Load extra QTL information for `ProbeSet` traits
@@ -371,6 +503,7 @@ def load_probeset_qtl_info(trait_info, conn):
             ["locus", "lrs", "pvalue", "mean", "additive"], cursor.fetchone()))
     return {"locus": "", "lrs": "", "pvalue": "", "mean": "", "additive": ""}
 
+
 def load_qtl_info(qtl, trait_type, trait_info, conn):
     """
     Load extra QTL information for traits
@@ -399,6 +532,7 @@ def load_qtl_info(qtl, trait_type, trait_info, conn):
 
     return qtl_info_functions[trait_type](trait_info, conn)
 
+
 def build_trait_name(trait_fullname):
     """
     Initialises the trait's name, and other values from the search data provided
@@ -425,6 +559,7 @@ def build_trait_name(trait_fullname):
         "cellid": name_parts[2] if len(name_parts) == 3 else ""
     }
 
+
 def retrieve_probeset_sequence(trait, conn):
     """
     Retrieve a 'ProbeSet' trait's sequence information
@@ -446,6 +581,7 @@ def retrieve_probeset_sequence(trait, conn):
         seq = cursor.fetchone()
         return {**trait, "sequence": seq[0] if seq else ""}
 
+
 def retrieve_trait_info(
         threshold: int, trait_full_name: str, conn: Any,
         qtl=None):
@@ -501,6 +637,7 @@ def retrieve_trait_info(
         }
     return trait_info
 
+
 def retrieve_temp_trait_data(trait_info: dict, conn: Any):
     """
     Retrieve trait data for `Temp` traits.
@@ -520,9 +657,10 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any):
             {"trait_name": trait_info["trait_name"]})
         return [dict(zip(
             ["sample_name", "value", "se_error", "nstrain", "id"], row))
-                for row in cursor.fetchall()]
+            for row in cursor.fetchall()]
     return []
 
+
 def retrieve_species_id(group, conn: Any):
     """
     Retrieve a species id given the Group value
@@ -534,6 +672,7 @@ def retrieve_species_id(group, conn: Any):
         return cursor.fetchone()[0]
     return None
 
+
 def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
     """
     Retrieve trait data for `Geno` traits.
@@ -559,9 +698,10 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
                  trait_info["db"]["group"], conn)})
         return [dict(zip(
             ["sample_name", "value", "se_error", "id"], row))
-                for row in cursor.fetchall()]
+            for row in cursor.fetchall()]
     return []
 
+
 def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
     """
     Retrieve trait data for `Publish` traits.
@@ -590,9 +730,10 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
              "dataset_id": trait_info["db"]["dataset_id"]})
         return [dict(zip(
             ["sample_name", "value", "se_error", "nstrain", "id"], row))
-                for row in cursor.fetchall()]
+            for row in cursor.fetchall()]
     return []
 
+
 def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
     """
     Retrieve trait data for `Probe Data` types.
@@ -623,9 +764,10 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
              "dataset_id": trait_info["db"]["dataset_id"]})
         return [dict(zip(
             ["sample_name", "value", "se_error", "id"], row))
-                for row in cursor.fetchall()]
+            for row in cursor.fetchall()]
     return []
 
+
 def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
     """
     Retrieve trait data for `ProbeSet` traits.
@@ -652,9 +794,10 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
              "dataset_name": trait_info["db"]["dataset_name"]})
         return [dict(zip(
             ["sample_name", "value", "se_error", "id"], row))
-                for row in cursor.fetchall()]
+            for row in cursor.fetchall()]
     return []
 
+
 def with_samplelist_data_setup(samplelist: Sequence[str]):
     """
     Build function that computes the trait data from provided list of samples.
@@ -681,6 +824,7 @@ def with_samplelist_data_setup(samplelist: Sequence[str]):
         return None
     return setup_fn
 
+
 def without_samplelist_data_setup():
     """
     Build function that computes the trait data.
@@ -701,6 +845,7 @@ def without_samplelist_data_setup():
         return None
     return setup_fn
 
+
 def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tuple()):
     """
     Retrieve trait data
@@ -740,15 +885,17 @@ def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tupl
             "data": dict(map(
                 lambda x: (
                     x["sample_name"],
-                    {k:v for k, v in x.items() if x != "sample_name"}),
+                    {k: v for k, v in x.items() if x != "sample_name"}),
                 data))}
     return {}
 
+
 def generate_traits_filename(base_path: str = TMPDIR):
     """Generate a unique filename for use with generated traits files."""
     return "{}/traits_test_file_{}.txt".format(
         os.path.abspath(base_path), random_string(10))
 
+
 def export_informative(trait_data: dict, inc_var: bool = False) -> tuple:
     """
     Export informative strain
@@ -770,5 +917,6 @@ def export_informative(trait_data: dict, inc_var: bool = False) -> tuple:
         return acc
     return reduce(
         __exporter__,
-        filter(lambda td: td["value"] is not None, trait_data["data"].values()),
+        filter(lambda td: td["value"] is not None,
+               trait_data["data"].values()),
         (tuple(), tuple(), tuple()))