about summary refs log tree commit diff
path: root/gn3/db
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/db')
-rw-r--r--gn3/db/sample_data.py93
1 files changed, 49 insertions, 44 deletions
diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py
index b2d6aed..e3daa21 100644
--- a/gn3/db/sample_data.py
+++ b/gn3/db/sample_data.py
@@ -1,7 +1,8 @@
-from gn3.csvcmp import extract_strain_name
-from typing import Any, Tuple
+"""Module containing functions that work with sample data"""
+from typing import Any, Tuple, Dict, Callable
 
 import MySQLdb
+from gn3.csvcmp import extract_strain_name
 
 
 _MAP = {
@@ -14,6 +15,10 @@ _MAP = {
 def __extract_actions(original_data: str,
                       updated_data: str,
                       csv_header: str) -> dict:
+    """Return a dictionary containing elements that need to be deleted, inserted,
+or updated.
+
+    """
     original_data = original_data.strip().split(",")
     updated_data = updated_data.strip().split(",")
     csv_header = csv_header.strip().split(",")
@@ -28,7 +33,7 @@ def __extract_actions(original_data: str,
             strain_name = _o
         if _o == _u:  # No change
             continue
-        elif _o and _u == "x":  # Deletion
+        if _o and _u == "x":  # Deletion
             result["delete"]["data"].append(_o)
             result["delete"]["csv_header"].append(_h)
         elif _o == "x" and _u:  # Insert
@@ -84,23 +89,23 @@ def get_trait_csv_sample_data(conn: Any,
         if not case_attr_columns:
             return ("Strain Name,Value,SE,Count\n" +
                     "\n".join(csv_data.keys()))
-        else:
-            columns = sorted(case_attr_columns)
-            csv = ("Strain Name,Value,SE,Count," +
-                   ",".join(columns) + "\n")
-            for key, value in csv_data.items():
-                if not value:
-                    csv += (key + (len(case_attr_columns) * ",x") + "\n")
-                else:
-                    vals = [str(value.get(column, "x")) for column in columns]
-                    csv += (key + "," + ",".join(vals) + "\n")
-            return csv
+        columns = sorted(case_attr_columns)
+        csv = ("Strain Name,Value,SE,Count," +
+               ",".join(columns) + "\n")
+        for key, value in csv_data.items():
+            if not value:
+                csv += (key + (len(case_attr_columns) * ",x") + "\n")
+            else:
+                vals = [str(value.get(column, "x")) for column in columns]
+                csv += (key + "," + ",".join(vals) + "\n")
+        return csv
     return "No Sample Data Found"
 
 
 def get_sample_data_ids(conn: Any, publishxref_id: int,
                         phenotype_id: int,
                         strain_name: str) -> Tuple:
+    """Get the strain_id, publishdata_id and inbredset_id for a given strain"""
     strain_id, publishdata_id, inbredset_id = None, None, None
     with conn.cursor() as cursor:
         cursor.execute("SELECT st.id, pd.Id, pf.InbredSetId "
@@ -125,7 +130,8 @@ def get_sample_data_ids(conn: Any, publishxref_id: int,
     return (strain_id, publishdata_id, inbredset_id)
 
 
-def update_sample_data(conn: Any,  # pylint: disable=[R0913]
+# pylint: disable=[R0913, R0914]
+def update_sample_data(conn: Any,
                        trait_name: str,
                        original_data: str,
                        updated_data: str,
@@ -136,12 +142,14 @@ def update_sample_data(conn: Any,  # pylint: disable=[R0913]
     def __update_data(conn, table, value):
         if value and value != "x":
             with conn.cursor() as cursor:
-                sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
+                sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) +
+                             " = %s")
                 _val = _MAP.get(table)[-1]
                 cursor.execute((f"UPDATE {table} SET {_val} = %s "
                                 f"WHERE {sub_query}"),
                                (value, strain_id, data_id))
                 return cursor.rowcount
+        return 0
 
     def __update_case_attribute(conn, value, strain_id,
                                 case_attr, inbredset_id):
@@ -177,6 +185,7 @@ def update_sample_data(conn: Any,  # pylint: disable=[R0913]
         if __actions.get("update"):
             _csv_header = __actions["update"]["csv_header"]
             _data = __actions["update"]["data"]
+            # pylint: disable=[E1101]
             for header, value in zip(_csv_header.split(","),
                                      _data.split(",")):
                 header = header.strip()
@@ -208,9 +217,9 @@ def update_sample_data(conn: Any,  # pylint: disable=[R0913]
                 phenotype_id=phenotype_id)
             if _rowcount:
                 count += 1
-    except Exception as e:  # pylint: disable=[C0103, W0612]
+    except Exception as _e:
         conn.rollback()
-        raise MySQLdb.Error
+        raise MySQLdb.Error(_e) from _e
     conn.commit()
     return count
 
@@ -223,28 +232,24 @@ def delete_sample_data(conn: Any,
     """Given the right parameters, delete sample-data from the relevant
     tables."""
     def __delete_data(conn, table):
-        if value and value != "x":
-            sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
-            with conn.cursor() as cursor:
-                cursor.execute((f"DELETE FROM {table} "
-                                f"WHERE {sub_query}"),
-                               (strain_id, data_id))
-                return cursor.rowcount
-        return 0
+        sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
+        with conn.cursor() as cursor:
+            cursor.execute((f"DELETE FROM {table} "
+                            f"WHERE {sub_query}"),
+                           (strain_id, data_id))
+            return cursor.rowcount
 
     def __delete_case_attribute(conn, strain_id,
                                 case_attr, inbredset_id):
-        if value != "x":
-            with conn.cursor() as cursor:
-                cursor.execute(
-                    "DELETE FROM CaseAttributeXRefNew "
-                    "WHERE StrainId = %s AND CaseAttributeId = "
-                    "(SELECT CaseAttributeId FROM "
-                    "CaseAttribute WHERE Name = %s) "
-                    "AND InbredSetId = %s",
-                    (strain_id, case_attr, inbredset_id))
-                return cursor.rowcount
-        return 0
+        with conn.cursor() as cursor:
+            cursor.execute(
+                "DELETE FROM CaseAttributeXRefNew "
+                "WHERE StrainId = %s AND CaseAttributeId = "
+                "(SELECT CaseAttributeId FROM "
+                "CaseAttribute WHERE Name = %s) "
+                "AND InbredSetId = %s",
+                (strain_id, case_attr, inbredset_id))
+            return cursor.rowcount
 
     strain_id, data_id, inbredset_id = get_sample_data_ids(
         conn=conn, publishxref_id=trait_name,
@@ -260,9 +265,8 @@ def delete_sample_data(conn: Any,
     count = 0
 
     try:
-        for header, value in zip(csv_header.split(","), data.split(",")):
+        for header in csv_header.split(","):
             header = header.strip()
-            value = value.strip()
             if header in none_case_attrs:
                 count += none_case_attrs.get(header)()
             else:
@@ -271,14 +275,15 @@ def delete_sample_data(conn: Any,
                     strain_id=strain_id,
                     case_attr=header,
                     inbredset_id=inbredset_id)
-    except Exception as e:  # pylint: disable=[C0103, W0612]
+    except Exception as _e:
         conn.rollback()
-        raise MySQLdb.Error
+        raise MySQLdb.Error(_e) from _e
     conn.commit()
     return count
 
 
-def insert_sample_data(conn: Any,  # pylint: disable=[R0913]
+# pylint: disable=[R0913, R0914]
+def insert_sample_data(conn: Any,
                        trait_name: str,
                        data: str,
                        csv_header: str,
@@ -355,6 +360,6 @@ def insert_sample_data(conn: Any,  # pylint: disable=[R0913]
                     case_attr=header,
                     value=value)
         return count
-    except Exception as e:  # pylint: disable=[C0103, W0612]
+    except Exception as _e:
         conn.rollback()
-        raise MySQLdb.Error
+        raise MySQLdb.Error(_e) from _e