about summary refs log tree commit diff
path: root/gn3/db/sample_data.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/db/sample_data.py')
-rw-r--r--gn3/db/sample_data.py228
1 files changed, 196 insertions, 32 deletions
diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py
index 8db40e3..4e01a3a 100644
--- a/gn3/db/sample_data.py
+++ b/gn3/db/sample_data.py
@@ -59,20 +59,32 @@ def __extract_actions(
     return result
 
 def get_mrna_sample_data(
-    conn: Any, probeset_id: str, dataset_name: str
+    conn: Any, probeset_id: int, dataset_name: str, probeset_name: str = None  # type: ignore
 ) -> Dict:
     """Fetch a mRNA Assay (ProbeSet in the DB) trait's sample data and return it as a dict"""
     with conn.cursor() as cursor:
-        cursor.execute("""
-SELECT st.Name, ifnull(psd.value, 'x'), ifnull(psse.error, 'x'), ifnull(ns.count, 'x')
-FROM ProbeSetFreeze psf
-    JOIN ProbeSetXRef psx ON psx.ProbeSetFreezeId = psf.Id
-    JOIN ProbeSet ps ON ps.Id = psx.ProbeSetId
-    JOIN ProbeSetData psd ON psd.Id = psx.DataId
-    JOIN Strain st ON psd.StrainId = st.Id
-    LEFT JOIN ProbeSetSE psse ON psse.DataId = psd.Id AND psse.StrainId = psd.StrainId
-    LEFT JOIN NStrain ns ON ns.DataId = psd.Id AND ns.StrainId = psd.StrainId
-WHERE ps.Id = %s AND psf.Name= %s""", (probeset_id, dataset_name))
+        if probeset_name:
+            cursor.execute("""
+    SELECT st.Name, ifnull(psd.value, 'x'), ifnull(psse.error, 'x'), ifnull(ns.count, 'x')
+    FROM ProbeSetFreeze psf
+        JOIN ProbeSetXRef psx ON psx.ProbeSetFreezeId = psf.Id
+        JOIN ProbeSet ps ON ps.Id = psx.ProbeSetId
+        JOIN ProbeSetData psd ON psd.Id = psx.DataId
+        JOIN Strain st ON psd.StrainId = st.Id
+        LEFT JOIN ProbeSetSE psse ON psse.DataId = psd.Id AND psse.StrainId = psd.StrainId
+        LEFT JOIN NStrain ns ON ns.DataId = psd.Id AND ns.StrainId = psd.StrainId
+    WHERE ps.Name = %s AND psf.Name= %s""", (probeset_name, dataset_name))
+        else:
+            cursor.execute("""
+    SELECT st.Name, ifnull(psd.value, 'x'), ifnull(psse.error, 'x'), ifnull(ns.count, 'x')
+    FROM ProbeSetFreeze psf
+        JOIN ProbeSetXRef psx ON psx.ProbeSetFreezeId = psf.Id
+        JOIN ProbeSet ps ON ps.Id = psx.ProbeSetId
+        JOIN ProbeSetData psd ON psd.Id = psx.DataId
+        JOIN Strain st ON psd.StrainId = st.Id
+        LEFT JOIN ProbeSetSE psse ON psse.DataId = psd.Id AND psse.StrainId = psd.StrainId
+        LEFT JOIN NStrain ns ON ns.DataId = psd.Id AND ns.StrainId = psd.StrainId
+    WHERE ps.Id = %s AND psf.Name= %s""", (probeset_id, dataset_name))
 
         sample_data = {}
         for data in cursor.fetchall():
@@ -118,18 +130,28 @@ WHERE ps.Id = %s AND psf.Name= %s""", (probeset_id, dataset_name))
         return "\n".join(trait_csv)
 
 def get_pheno_sample_data(
-    conn: Any, trait_name: int, phenotype_id: int
+    conn: Any, trait_name: int, phenotype_id: int, group_id: int = None  # type: ignore
 ) -> Dict:
     """Fetch a phenotype (Publish in the DB) trait's sample data and return it as a dict"""
     with conn.cursor() as cursor:
-        cursor.execute("""
-SELECT st.Name, ifnull(pd.value, 'x'), ifnull(ps.error, 'x'), ifnull(ns.count, 'x')
-FROM PublishFreeze pf JOIN PublishXRef px ON px.InbredSetId = pf.InbredSetId
-     JOIN PublishData pd ON pd.Id = px.DataId JOIN Strain st ON pd.StrainId = st.Id
-     LEFT JOIN PublishSE ps ON ps.DataId = pd.Id AND ps.StrainId = pd.StrainId
-     LEFT JOIN NStrain ns ON ns.DataId = pd.Id AND ns.StrainId = pd.StrainId
-WHERE px.Id = %s AND px.PhenotypeId = %s
-ORDER BY st.Name""", (trait_name, phenotype_id))
+        if group_id:
+            cursor.execute("""
+    SELECT st.Name, ifnull(ROUND(pd.value, 2), 'x'), ifnull(ROUND(ps.error, 3), 'x'), ifnull(ns.count, 'x')
+    FROM PublishFreeze pf JOIN PublishXRef px ON px.InbredSetId = pf.InbredSetId
+        JOIN PublishData pd ON pd.Id = px.DataId JOIN Strain st ON pd.StrainId = st.Id
+        LEFT JOIN PublishSE ps ON ps.DataId = pd.Id AND ps.StrainId = pd.StrainId
+        LEFT JOIN NStrain ns ON ns.DataId = pd.Id AND ns.StrainId = pd.StrainId
+    WHERE px.Id = %s AND px.InbredSetId = %s
+    ORDER BY st.Name""", (trait_name, group_id))
+        else:
+            cursor.execute("""
+    SELECT st.Name, ifnull(pd.value, 'x'), ifnull(ps.error, 'x'), ifnull(ns.count, 'x')
+    FROM PublishFreeze pf JOIN PublishXRef px ON px.InbredSetId = pf.InbredSetId
+        JOIN PublishData pd ON pd.Id = px.DataId JOIN Strain st ON pd.StrainId = st.Id
+        LEFT JOIN PublishSE ps ON ps.DataId = pd.Id AND ps.StrainId = pd.StrainId
+        LEFT JOIN NStrain ns ON ns.DataId = pd.Id AND ns.StrainId = pd.StrainId
+    WHERE px.Id = %s AND px.PhenotypeId = %s
+    ORDER BY st.Name""", (trait_name, phenotype_id))
 
         sample_data = {}
         for data in cursor.fetchall():
@@ -302,8 +324,8 @@ def update_sample_data(
     if data_type == "mrna":
         strain_id, data_id, inbredset_id = get_mrna_sample_data_ids(
             conn=conn,
-            probeset_id=int(probeset_id),
-            dataset_name=dataset_name,
+            probeset_id=int(probeset_id),# pylint: disable=[possibly-used-before-assignment]
+            dataset_name=dataset_name,# pylint: disable=[possibly-used-before-assignment]
             strain_name=extract_strain_name(csv_header, original_data),
         )
         none_case_attrs = {
@@ -315,8 +337,8 @@ def update_sample_data(
     else:
         strain_id, data_id, inbredset_id = get_pheno_sample_data_ids(
             conn=conn,
-            publishxref_id=int(trait_name),
-            phenotype_id=phenotype_id,
+            publishxref_id=int(trait_name),# pylint: disable=[possibly-used-before-assignment]
+            phenotype_id=phenotype_id,# pylint: disable=[possibly-used-before-assignment]
             strain_name=extract_strain_name(csv_header, original_data),
         )
         none_case_attrs = {
@@ -422,8 +444,8 @@ def delete_sample_data(
     if data_type == "mrna":
         strain_id, data_id, inbredset_id = get_mrna_sample_data_ids(
             conn=conn,
-            probeset_id=int(probeset_id),
-            dataset_name=dataset_name,
+            probeset_id=int(probeset_id),# pylint: disable=[possibly-used-before-assignment]
+            dataset_name=dataset_name,# pylint: disable=[possibly-used-before-assignment]
             strain_name=extract_strain_name(csv_header, data),
         )
         none_case_attrs: Dict[str, Any] = {
@@ -435,8 +457,8 @@ def delete_sample_data(
     else:
         strain_id, data_id, inbredset_id = get_pheno_sample_data_ids(
             conn=conn,
-            publishxref_id=int(trait_name),
-            phenotype_id=phenotype_id,
+            publishxref_id=int(trait_name),# pylint: disable=[possibly-used-before-assignment]
+            phenotype_id=phenotype_id,# pylint: disable=[possibly-used-before-assignment]
             strain_name=extract_strain_name(csv_header, data),
         )
         none_case_attrs = {
@@ -528,8 +550,8 @@ def insert_sample_data(
     if data_type == "mrna":
         strain_id, data_id, inbredset_id = get_mrna_sample_data_ids(
             conn=conn,
-            probeset_id=int(probeset_id),
-            dataset_name=dataset_name,
+            probeset_id=int(probeset_id),# pylint: disable=[possibly-used-before-assignment]
+            dataset_name=dataset_name,# pylint: disable=[possibly-used-before-assignment]
             strain_name=extract_strain_name(csv_header, data),
         )
         none_case_attrs = {
@@ -541,8 +563,8 @@ def insert_sample_data(
     else:
         strain_id, data_id, inbredset_id = get_pheno_sample_data_ids(
             conn=conn,
-            publishxref_id=int(trait_name),
-            phenotype_id=phenotype_id,
+            publishxref_id=int(trait_name),# pylint: disable=[possibly-used-before-assignment]
+            phenotype_id=phenotype_id,# pylint: disable=[possibly-used-before-assignment]
             strain_name=extract_strain_name(csv_header, data),
         )
         none_case_attrs = {
@@ -584,3 +606,145 @@ def insert_sample_data(
         return count
     except Exception as _e:
         raise MySQLdb.Error(_e) from _e
+
+def batch_update_sample_data(
+    conn: Any, diff_data: Dict
+):
+    """Given sample data diffs, execute all relevant update/insert/delete queries"""
+    def __fetch_data_id(conn, db_type, trait_id, dataset_name):
+        with conn.cursor() as cursor:
+            if db_type == "Publish":
+                cursor.execute(
+                    (
+                        f"SELECT {db_type}XRef.DataId "
+                        f"FROM {db_type}XRef, {db_type}Freeze "
+                        f"WHERE {db_type}XRef.InbredSetId = {db_type}Freeze.InbredSetId AND "
+                        f"{db_type}XRef.Id = %s AND "
+                        f"{db_type}Freeze.Name = %s"
+                    ), (trait_id, dataset_name)
+                )
+            elif db_type == "ProbeSet":
+                cursor.execute(
+                    (
+                        f"SELECT {db_type}XRef.DataId "
+                        f"FROM {db_type}XRef, {db_type}, {db_type}Freeze "
+                        f"WHERE {db_type}XRef.InbredSetId = {db_type}Freeze.InbredSetId AND "
+                        f"{db_type}XRef.ProbeSetId = {db_type}.Id AND "
+                        f"{db_type}.Name = %s AND "
+                        f"{db_type}Freeze.Name = %s"
+                    ), (trait_id, dataset_name)
+                )
+            return cursor.fetchone()[0]
+
+    def __fetch_strain_id(conn, strain_name):
+        with conn.cursor() as cursor:
+            cursor.execute(
+                "SELECT Id FROM Strain WHERE Name = %s", (strain_name,)
+            )
+            return cursor.fetchone()[0]
+
+    def __update_query(conn, db_type, data_id, strain_id, diffs):
+        with conn.cursor() as cursor:
+            if 'value' in diffs:
+                cursor.execute(
+                    (
+                        f"UPDATE {db_type}Data "
+                        "SET value = %s "
+                        "WHERE Id = %s AND StrainId = %s"
+                    ), (diffs['value']['Current'], data_id, strain_id)
+                )
+            if 'error' in diffs:
+                cursor.execute(
+                    (
+                        f"UPDATE {db_type}SE "
+                        "SET error = %s "
+                        "WHERE DataId = %s AND StrainId = %s"
+                    ), (diffs['error']['Current'], data_id, strain_id)
+                )
+            if 'n_cases' in diffs:
+                cursor.execute(
+                    (
+                        "UPDATE NStrain "
+                        "SET count = %s "
+                        "WHERE DataId = %s AND StrainId = %s"
+                    ), (diffs['n_cases']['Current'], data_id, strain_id)
+                )
+
+        conn.commit()
+
+    def __insert_query(conn, db_type, data_id, strain_id, diffs):
+        with conn.cursor() as cursor:
+            if 'value' in diffs:
+                cursor.execute(
+                    (
+                        f"INSERT INTO {db_type}Data (Id, StrainId, value)"
+                        "VALUES (%s, %s, %s)"
+                    ), (data_id, strain_id, diffs['value'])
+                )
+            if 'error' in diffs:
+                cursor.execute(
+                    (
+                        f"INSERT INTO {db_type}SE (DataId, StrainId, error)"
+                        "VALUES (%s, %s, %s)"
+                    ), (data_id, strain_id, diffs['error'])
+                )
+            if 'n_cases' in diffs:
+                cursor.execute(
+                    (
+                        "INSERT INTO NStrain (DataId, StrainId, count)"
+                        "VALUES (%s, %s, %s)"
+                    ), (data_id, strain_id, diffs['n_cases'])
+                )
+
+        conn.commit()
+
+    def __delete_query(conn, db_type, data_id, strain_id, diffs):
+        with conn.cursor() as cursor:
+            if 'value' in diffs:
+                cursor.execute(
+                    (
+                        f"DELETE FROM {db_type}Data "
+                        "WHERE Id = %s AND StrainId = %s"
+                    ), (data_id, strain_id)
+                )
+            if 'error' in diffs:
+                cursor.execute(
+                    (
+                        f"DELETE FROM {db_type}SE "
+                        "WHERE DataId = %s AND StrainId = %s"
+                    ), (data_id, strain_id)
+                )
+            if 'n_cases' in diffs:
+                cursor.execute(
+                    (
+                        "DELETE FROM NStrain "
+                        "WHERE DataId = %s AND StrainId = %s"
+                    ), (data_id, strain_id)
+                )
+
+        conn.commit()
+
+    def __update_data(conn, db_type, data_id, diffs, update_type):
+        for strain in diffs:
+            strain_id = __fetch_strain_id(conn, strain)
+            if update_type == "update":
+                __update_query(conn, db_type, data_id, strain_id, diffs[strain])
+            elif update_type == "insert":
+                __insert_query(conn, db_type, data_id, strain_id, diffs[strain])
+            elif update_type == "delete":
+                __delete_query(conn, db_type, data_id, strain_id, diffs[strain])
+
+    for key in diff_data:
+        dataset, trait = key.split(":")
+        if "Publish" in dataset:
+            db_type = "Publish"
+        else:
+            db_type = "ProbeSet"
+
+        data_id = __fetch_data_id(conn, db_type, trait, dataset)
+
+        __update_data(conn, db_type, data_id, diff_data[key]['Modifications'], 'update')
+        __update_data(conn, db_type, data_id, diff_data[key]['Additions'], 'insert')
+        __update_data(conn, db_type, data_id, diff_data[key]['Deletions'], 'delete')
+
+    return diff_data