From 59d2d0a8175c4a06c5363a8a3d3addd683ef9a8d Mon Sep 17 00:00:00 2001
From: BonfaceKilz
Date: Tue, 1 Mar 2022 14:14:43 +0300
Subject: Allow deleting case-attribute data during deletion

* gn3/db/sample_data.py (delete_sample_data): Modify this function to allow
deleting case-attribute values.
---
 gn3/db/sample_data.py | 93 +++++++++++++++++++++++++++++++--------------------
 1 file changed, 57 insertions(+), 36 deletions(-)

(limited to 'gn3/db')

diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py
index 708bfd5..06b5767 100644
--- a/gn3/db/sample_data.py
+++ b/gn3/db/sample_data.py
@@ -163,49 +163,70 @@ def update_sample_data(conn: Any,  # pylint: disable=[R0913]
 
 def delete_sample_data(conn: Any,
                        trait_name: str,
-                       strain_name: str,
+                       data: str,
+                       csv_header: str,
                        phenotype_id: int):
     """Given the right parameters, delete sample-data from the relevant
-    table."""
-    strain_id, data_id, _ = get_sample_data_ids(
+    tables."""
+    def __delete_data(conn, table):
+        if value and value != "x":
+            _map = {
+                "PublishData": "StrainId = %s AND Id = %s",
+                "PublishSE": "StrainId = %s AND DataId = %s",
+                "NStrain": "StrainId = %s AND DataId = %s",
+            }
+            with conn.cursor() as cursor:
+                cursor.execute((f"DELETE FROM {table} "
+                                f"WHERE {_map.get(table)}"),
+                               (strain_id, data_id))
+                return cursor.rowcount
+        return 0
+
+    def __delete_case_attribute(conn, strain_id,
+                                case_attr, inbredset_id):
+        if value != "x":
+            with conn.cursor() as cursor:
+                cursor.execute(
+                    ("DELETE FROM CaseAttributeXRefNew "
+                     "WHERE StrainId = "
+                     "(SELECT CaseAttributeId FROM "
+                     f"CaseAttribute WHERE NAME = %s) "
+                     "AND InbredSetId = %s"),
+                    (strain_id, case_attr, inbredset_id)
+                )
+                return cursor.rowcount
+        return 0
+
+    strain_id, data_id, inbredset_id = get_sample_data_ids(
         conn=conn, publishxref_id=trait_name,
         phenotype_id=phenotype_id,
         strain_name=strain_name)
 
-    deleted_published_data: int = 0
-    deleted_se_data: int = 0
-    deleted_n_strains: int = 0
-
-    with conn.cursor() as cursor:
-        # Delete the PublishData table
-        try:
-            # Only run if the strain_id and data_id exist
-            if strain_id and data_id:
-                cursor.execute(("DELETE FROM PublishData "
-                                "WHERE StrainId = %s AND Id = %s")
-                               % (strain_id, data_id))
-                deleted_published_data = cursor.rowcount
-
-                # Delete the PublishSE table
-                cursor.execute(("DELETE FROM PublishSE "
-                                "WHERE StrainId = %s AND DataId = %s") %
-                               (strain_id, data_id))
-                deleted_se_data = cursor.rowcount
-
-                # Delete the NStrain table
-                cursor.execute(("DELETE FROM NStrain "
-                                "WHERE StrainId = %s AND DataId = %s" %
-                                (strain_id, data_id)))
-                deleted_n_strains = cursor.rowcount
-        except Exception as e:  #pylint: disable=[C0103, W0612]
-            conn.rollback()
-            raise MySQLdb.Error
-        conn.commit()
-        cursor.close()
-        cursor.close()
+    none_case_attrs = {
+        "Strain Name": lambda: 0,
+        "Value": lambda: __delete_data(conn, "PublishData"),
+        "SE": lambda: __delete_data(conn, "PublishSE"),
+        "Count": lambda: __delete_data(conn, "NStrain"),
+    }
+    count = 0
 
-    return (deleted_published_data,
-            deleted_se_data, deleted_n_strains)
+    try:
+        for header, value in zip(csv_header.split(","), data.split(",")):
+            header = header.strip()
+            value = value.strip()
+            if header in none_case_attrs:
+                count += none_case_attrs.get(header)()
+            else:
+                count += __delete_case_attribute(
+                    conn=conn,
+                    strain_id=strain_id,
+                    case_attr=header,
+                    inbredset_id=inbredset_id)
+    except Exception as e:  # pylint: disable=[C0103, W0612]
+        conn.rollback()
+        raise MySQLdb.Error
+    conn.commit()
+    return count
 
 
 def insert_sample_data(conn: Any,  # pylint: disable=[R0913]
-- 
cgit v1.2.3