about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/db/sample_data.py316
1 files changed, 297 insertions, 19 deletions
diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py
index d620a1b..8fdc618 100644
--- a/gn3/db/sample_data.py
+++ b/gn3/db/sample_data.py
@@ -8,7 +8,9 @@ from gn3.csvcmp import parse_csv_column
 
 
 _MAP = {
+    "ProbeSetData": ("StrainId", "Id", "value"),
     "PublishData": ("StrainId", "Id", "value"),
+    "ProbeSetSE": ("StrainId", "DataId", "error"),
     "PublishSE": ("StrainId", "DataId", "error"),
     "NStrain": ("StrainId", "DataId", "count"),
 }
@@ -171,21 +173,22 @@ WHERE px.Id = %s AND px.PhenotypeId = %s ORDER BY st.Name""",
         return "\n".join(trait_csv)
 
 def get_mrna_sample_data_ids(
-    conn: Any, probsetxref_id: int, strain_name: str
+    conn: Any, probeset_id: int, dataset_name: str, strain_name: str
 ) -> Tuple:
     """Get the strain_id, probesetdata_id and inbredset_id for a given strain"""
     strain_id, probesetdata_id, inbredset_id = None, None, None
     with conn.cursor() as cursor:
-        cursor.execute(
-            "SELECT st.id, pd.Id, pf.InbredSetId "
-            "FROM ProbeSetData pd "
-            "JOIN Strain st ON pd.StrainId = st.Id "
-            "JOIN ProbeSetXRef px ON px.DataId = pd.Id "
-            "JOIN ProbeSetFreeze pf ON pf.InbredSetId "
-            "= px.InbredSetId WHERE px.Id = %s "
-            "AND px.PhenotypeId = %s AND st.Name = %s",
-            (probsetxref_id, strain_name),
-        )
+        cursor.execute("""
+SELECT st.id, psd.Id, pf.InbredSetId
+FROM ProbeFreeze pf
+    JOIN ProbeSetFreeze psf ON psf.ProbeFreezeId = pf.Id
+    JOIN ProbeSetXRef psx ON psx.ProbeSetFreezeId = psf.Id
+    JOIN ProbeSet ps ON ps.Id = psx.ProbeSetId
+    JOIN ProbeSetData psd ON psd.Id = psx.DataId
+    JOIN Strain st ON psd.StrainId = st.Id
+    LEFT JOIN ProbeSetSE psse ON psse.DataId = psd.Id AND psse.StrainId = psd.StrainId
+    LEFT JOIN NStrain ns ON ns.DataId = psd.Id AND ns.StrainId = psd.StrainId
+WHERE ps.Id = %s AND psf.Name= %s AND st.Name = %s""", (probeset_id, dataset_name, strain_name))
         if _result := cursor.fetchone():
             strain_id, probesetdata_id, inbredset_id = _result
         if not all([strain_id, probesetdata_id, inbredset_id]):
@@ -195,7 +198,7 @@ def get_mrna_sample_data_ids(
                 "WHERE Id = %s",
                 (probesetxref_id),
             )
-            probsetdata_id, inbredset_id = cursor.fetchone()
+            probesetdata_id, inbredset_id = cursor.fetchone()
             cursor.execute(
                 "SELECT Id FROM Strain WHERE Name = %s", (strain_name,)
             )
@@ -234,9 +237,123 @@ def get_pheno_sample_data_ids(
             strain_id = cursor.fetchone()[0]
     return (strain_id, publishdata_id, inbredset_id)
 
+# pylint: disable=[R0913, R0914]
+def update_mrna_sample_data(
+    conn: Any,
+    original_data: str,
+    updated_data: str,
+    csv_header: str,
+    probeset_id: int,
+    dataset_name: str
+) -> int:
+    """Given the right parameters, update sample-data from the relevant
+    table."""
+
+    def __update_data(conn, table, value):
+        if value and value != "x":
+            with conn.cursor() as cursor:
+                sub_query = " = %s AND ".join(_MAP.get(table)[:2]) + " = %s"
+                _val = _MAP.get(table)[-1]
+                cursor.execute(
+                    (f"UPDATE {table} SET {_val} = %s " f"WHERE {sub_query}"),
+                    (value, strain_id, data_id),
+                )
+                conn.commit()
+                return cursor.rowcount
+        return 0
+
+    def __update_case_attribute(
+        conn, value, strain_id, case_attr, inbredset_id
+    ):
+        if value != "x":
+            (id_, name) = parse_csv_column(case_attr)
+            with conn.cursor() as cursor:
+                if id_:
+                    cursor.execute(
+                        "UPDATE CaseAttributeXRefNew "
+                        "SET Value = %s "
+                        "WHERE StrainId = %s AND CaseAttributeId = %s "
+                        "AND InbredSetId = %s",
+                        (value, strain_id, id_, inbredset_id),
+                    )
+                else:
+                    cursor.execute(
+                        "UPDATE CaseAttributeXRefNew "
+                        "SET Value = %s "
+                        "WHERE StrainId = %s AND CaseAttributeId = "
+                        "(SELECT CaseAttributeId FROM "
+                        "CaseAttribute WHERE Name = %s) "
+                        "AND InbredSetId = %s",
+                        (value, strain_id, name, inbredset_id),
+                    )
+                conn.commit()
+                return cursor.rowcount
+        return 0
+
+    strain_id, data_id, inbredset_id = get_mrna_sample_data_ids(
+        conn=conn,
+        probeset_id=int(probeset_id),
+        dataset_name=dataset_name,
+        strain_name=extract_strain_name(csv_header, original_data),
+    )
+
+    none_case_attrs: Dict[str, Callable] = {
+        "Strain Name": lambda x: 0,
+        "Value": lambda x: __update_data(conn, "ProbeSetData", x),
+        "SE": lambda x: __update_data(conn, "ProbeSetSE", x),
+        "Count": lambda x: __update_data(conn, "NStrain", x),
+    }
+    count = 0
+    # try:
+    __actions = __extract_actions(
+        original_data=original_data,
+        updated_data=updated_data,
+        csv_header=csv_header,
+    )
+
+    if __actions.get("update"):
+        _csv_header = __actions["update"]["csv_header"]
+        _data = __actions["update"]["data"]
+        # pylint: disable=[E1101]
+        for header, value in zip(_csv_header.split(","), _data.split(",")):
+            header = header.strip()
+            value = value.strip()
+            if header in none_case_attrs:
+                count += none_case_attrs[header](value)
+            else:
+                count += __update_case_attribute(
+                    conn=conn,
+                    value=value,
+                    strain_id=strain_id,
+                    case_attr=header,
+                    inbredset_id=inbredset_id,
+                )
+    if __actions.get("delete"):
+        _rowcount = delete_mrna_sample_data(
+            conn=conn,
+            data=__actions["delete"]["data"],
+            csv_header=__actions["delete"]["csv_header"],
+            probeset_id=probeset_id,
+            dataset_name=dataset_name
+        )
+        if _rowcount:
+            count += 1
+    if __actions.get("insert"):
+        _rowcount = insert_mrna_sample_data(
+            conn=conn,
+            data=__actions["insert"]["data"],
+            csv_header=__actions["insert"]["csv_header"],
+            probeset_id=probeset_id,
+            dataset_name=dataset_name
+        )
+        if _rowcount:
+            count += 1
+    # except Exception as _e:
+    #     raise MySQLdb.Error(_e) from _e
+    return count
 
 # pylint: disable=[R0913, R0914]
-def update_sample_data(
+def update_pheno_sample_data(
     conn: Any,
     trait_name: str,
     original_data: str,
@@ -288,7 +405,7 @@ def update_sample_data(
                 return cursor.rowcount
         return 0
 
-    strain_id, data_id, inbredset_id = get_sample_data_ids(
+    strain_id, data_id, inbredset_id = get_pheno_sample_data_ids(
         conn=conn,
         publishxref_id=int(trait_name),
         phenotype_id=phenotype_id,
@@ -337,7 +454,7 @@ def update_sample_data(
             if _rowcount:
                 count += 1
         if __actions.get("insert"):
-            _rowcount = insert_sample_data(
+            _rowcount = insert_pheno_sample_data(
                 conn=conn,
                 trait_name=trait_name,
                 data=__actions["insert"]["data"],
@@ -350,8 +467,76 @@ def update_sample_data(
         raise MySQLdb.Error(_e) from _e
     return count
 
+def delete_mrna_sample_data(
+    conn: Any, data: str, csv_header: str, probeset_id: int, dataset_name: str
+) -> int:
+    """Given the right parameters, delete sample-data from the relevant
+    tables."""
+
+    def __delete_data(conn, table):
+        sub_query = " = %s AND ".join(_MAP.get(table)[:2]) + " = %s"
+        with conn.cursor() as cursor:
+            cursor.execute(
+                (f"DELETE FROM {table} " f"WHERE {sub_query}"),
+                (strain_id, data_id),
+            )
+            conn.commit()
+            return cursor.rowcount
+
+    def __delete_case_attribute(conn, strain_id, case_attr, inbredset_id):
+        with conn.cursor() as cursor:
+            (id_, name) = parse_csv_column(case_attr)
+            if id_:
+                cursor.execute(
+                    "DELETE FROM CaseAttributeXRefNew "
+                    "WHERE StrainId = %s AND CaseAttributeId = %s "
+                    "AND InbredSetId = %s",
+                    (strain_id, id_, inbredset_id),
+                )
+            else:
+                cursor.execute(
+                    "DELETE FROM CaseAttributeXRefNew "
+                    "WHERE StrainId = %s AND CaseAttributeId = "
+                    "(SELECT CaseAttributeId FROM "
+                    "CaseAttribute WHERE Name = %s) "
+                    "AND InbredSetId = %s",
+                    (strain_id, name, inbredset_id),
+                )
+            conn.commit()
+            return cursor.rowcount
+
+    strain_id, data_id, inbredset_id = get_mrna_sample_data_ids(
+        conn=conn,
+        probeset_id=int(probeset_id),
+        dataset_name=dataset_name,
+        strain_name=extract_strain_name(csv_header, original_data),
+    )
+
+    none_case_attrs: Dict[str, Any] = {
+        "Strain Name": lambda: 0,
+        "Value": lambda: __delete_data(conn, "PublishData"),
+        "SE": lambda: __delete_data(conn, "PublishSE"),
+        "Count": lambda: __delete_data(conn, "NStrain"),
+    }
+    count = 0
+
+    try:
+        for header in csv_header.split(","):
+            header = header.strip()
+            if header in none_case_attrs:
+                count += none_case_attrs[header]()
+            else:
+                count += __delete_case_attribute(
+                    conn=conn,
+                    strain_id=strain_id,
+                    case_attr=header,
+                    inbredset_id=inbredset_id,
+                )
+    except Exception as _e:
+        raise MySQLdb.Error(_e) from _e
+    return count
 
-def delete_sample_data(
+def delete_pheno_sample_data(
     conn: Any, trait_name: str, data: str, csv_header: str, phenotype_id: int
 ) -> int:
     """Given the right parameters, delete sample-data from the relevant
@@ -389,7 +574,7 @@ def delete_sample_data(
             conn.commit()
             return cursor.rowcount
 
-    strain_id, data_id, inbredset_id = get_sample_data_ids(
+    strain_id, data_id, inbredset_id = get_pheno_sample_data_ids(
         conn=conn,
         publishxref_id=int(trait_name),
         phenotype_id=phenotype_id,
@@ -420,9 +605,102 @@ def delete_sample_data(
         raise MySQLdb.Error(_e) from _e
     return count
 
+# pylint: disable=[R0913, R0914]
+def insert_mrna_sample_data(
+    conn: Any, data: str, csv_header: str, probeset_id: int, dataset_name: str
+) -> int:
+    """Given the right parameters, insert sample-data to the relevant table."""
+
+    def __insert_data(conn, table, value):
+        if value and value != "x":
+            with conn.cursor() as cursor:
+                columns = ", ".join(_MAP.get(table))
+                cursor.execute(
+                    (
+                        f"INSERT INTO {table} "
+                        f"({columns}) "
+                        f"VALUES (%s, %s, %s)"
+                    ),
+                    (strain_id, data_id, value),
+                )
+                conn.commit()
+                return cursor.rowcount
+        return 0
+
+    def __insert_case_attribute(conn, case_attr, value):
+        if value != "x":
+            with conn.cursor() as cursor:
+                (id_, name) = parse_csv_column(case_attr)
+                if not id_:
+                    cursor.execute(
+                        "SELECT Id FROM CaseAttribute WHERE Name = %s",
+                        (name,),
+                    )
+                    if case_attr_id := cursor.fetchone():
+                        id_ = case_attr_id[0]
+
+                cursor.execute(
+                    "SELECT StrainId FROM "
+                    "CaseAttributeXRefNew WHERE StrainId = %s "
+                    "AND CaseAttributeId = %s "
+                    "AND InbredSetId = %s",
+                    (strain_id, id_, inbredset_id),
+                )
+                if (not cursor.fetchone()) and id_:
+                    cursor.execute(
+                        "INSERT INTO CaseAttributeXRefNew "
+                        "(StrainId, CaseAttributeId, Value, InbredSetId) "
+                        "VALUES (%s, %s, %s, %s)",
+                        (strain_id, id_, value, inbredset_id),
+                    )
+                    row_count = cursor.rowcount
+                    conn.commit()
+                    return row_count
+                conn.commit()
+        return 0
+
+    strain_id, data_id, inbredset_id = get_mrna_sample_data_ids(
+        conn=conn,
+        probeset_id=int(probeset_id),
+        dataset_name=dataset_name,
+        strain_name=extract_strain_name(csv_header, data),
+    )
+
+    none_case_attrs: Dict[str, Any] = {
+        "Strain Name": lambda _: 0,
+        "Value": lambda x: __insert_data(conn, "PublishData", x),
+        "SE": lambda x: __insert_data(conn, "PublishSE", x),
+        "Count": lambda x: __insert_data(conn, "NStrain", x),
+    }
+
+    try:
+        count = 0
+
+        # Check if the data already exists:
+        with conn.cursor() as cursor:
+            cursor.execute(
+                "SELECT Id FROM PublishData where Id = %s "
+                "AND StrainId = %s",
+                (data_id, strain_id))
+            data_exists = cursor.fetchone()
+        if data_exists:  # Data already exists
+            return count
+
+        for header, value in zip(csv_header.split(","), data.split(",")):
+            header = header.strip()
+            value = value.strip()
+            if header in none_case_attrs:
+                count += none_case_attrs[header](value)
+            else:
+                count += __insert_case_attribute(
+                    conn=conn, case_attr=header, value=value
+                )
+        return count
+    except Exception as _e:
+        raise MySQLdb.Error(_e) from _e
 
 # pylint: disable=[R0913, R0914]
-def insert_sample_data(
+def insert_pheno_sample_data(
     conn: Any, trait_name: str, data: str, csv_header: str, phenotype_id: int
 ) -> int:
     """Given the right parameters, insert sample-data to the relevant table."""
@@ -475,7 +753,7 @@ def insert_sample_data(
                 conn.commit()
         return 0
 
-    strain_id, data_id, inbredset_id = get_sample_data_ids(
+    strain_id, data_id, inbredset_id = get_pheno_sample_data_ids(
         conn=conn,
         publishxref_id=int(trait_name),
         phenotype_id=phenotype_id,