about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- gn3/db/sample_data.py 265
1 file changed, 148 insertions, 117 deletions
diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py
index f73954f..73fbd95 100644
--- a/gn3/db/sample_data.py
+++ b/gn3/db/sample_data.py
@@ -13,11 +13,11 @@ _MAP = {
 }
 
 
-def __extract_actions(original_data: str,
-                      updated_data: str,
-                      csv_header: str) -> Dict:
+def __extract_actions(
+    original_data: str, updated_data: str, csv_header: str
+) -> Dict:
     """Return a dictionary containing elements that need to be deleted, inserted,
-or updated.
+    or updated.
 
     """
     result: Dict[str, Any] = {
@@ -26,9 +26,11 @@ or updated.
         "update": {"data": [], "csv_header": []},
     }
     strain_name = ""
-    for _o, _u, _h in zip(original_data.strip().split(","),
-                          updated_data.strip().split(","),
-                          csv_header.strip().split(",")):
+    for _o, _u, _h in zip(
+        original_data.strip().split(","),
+        updated_data.strip().split(","),
+        csv_header.strip().split(","),
+    ):
         if _h == "Strain Name":
             strain_name = _o
         if _o == _u:  # No change
@@ -46,33 +48,38 @@ or updated.
         if not val["data"]:
             result[key] = None
         else:
-            result[key]["data"] = (f"{strain_name}," +
-                                   ",".join(result[key]["data"]))
-            result[key]["csv_header"] = ("Strain Name," +
-                                         ",".join(result[key]["csv_header"]))
+            result[key]["data"] = f"{strain_name}," + ",".join(
+                result[key]["data"]
+            )
+            result[key]["csv_header"] = "Strain Name," + ",".join(
+                result[key]["csv_header"]
+            )
     return result
 
 
-def get_trait_csv_sample_data(conn: Any,
-                              trait_name: int, phenotype_id: int) -> str:
+def get_trait_csv_sample_data(
+    conn: Any, trait_name: int, phenotype_id: int
+) -> str:
     """Fetch a trait and return it as a csv string"""
-    __query = ("SELECT concat(st.Name, ',', ifnull(pd.value, 'x'), ',', "
-               "ifnull(ps.error, 'x'), ',', ifnull(ns.count, 'x')) as 'Data' "
-               ",ifnull(ca.Name, 'x') as 'CaseAttr', "
-               "ifnull(cxref.value, 'x') as 'Value' "
-               "FROM PublishFreeze pf "
-               "JOIN PublishXRef px ON px.InbredSetId = pf.InbredSetId "
-               "JOIN PublishData pd ON pd.Id = px.DataId "
-               "JOIN Strain st ON pd.StrainId = st.Id "
-               "LEFT JOIN PublishSE ps ON ps.DataId = pd.Id "
-               "AND ps.StrainId = pd.StrainId "
-               "LEFT JOIN NStrain ns ON ns.DataId = pd.Id "
-               "AND ns.StrainId = pd.StrainId "
-               "LEFT JOIN CaseAttributeXRefNew cxref ON "
-               "(cxref.InbredSetId = px.InbredSetId AND "
-               "cxref.StrainId = st.Id) "
-               "LEFT JOIN CaseAttribute ca ON ca.Id = cxref.CaseAttributeId "
-               "WHERE px.Id = %s AND px.PhenotypeId = %s ORDER BY st.Name")
+    __query = (
+        "SELECT concat(st.Name, ',', ifnull(pd.value, 'x'), ',', "
+        "ifnull(ps.error, 'x'), ',', ifnull(ns.count, 'x')) as 'Data' "
+        ",ifnull(ca.Name, 'x') as 'CaseAttr', "
+        "ifnull(cxref.value, 'x') as 'Value' "
+        "FROM PublishFreeze pf "
+        "JOIN PublishXRef px ON px.InbredSetId = pf.InbredSetId "
+        "JOIN PublishData pd ON pd.Id = px.DataId "
+        "JOIN Strain st ON pd.StrainId = st.Id "
+        "LEFT JOIN PublishSE ps ON ps.DataId = pd.Id "
+        "AND ps.StrainId = pd.StrainId "
+        "LEFT JOIN NStrain ns ON ns.DataId = pd.Id "
+        "AND ns.StrainId = pd.StrainId "
+        "LEFT JOIN CaseAttributeXRefNew cxref ON "
+        "(cxref.InbredSetId = px.InbredSetId AND "
+        "cxref.StrainId = st.Id) "
+        "LEFT JOIN CaseAttribute ca ON ca.Id = cxref.CaseAttributeId "
+        "WHERE px.Id = %s AND px.PhenotypeId = %s ORDER BY st.Name"
+    )
     case_attr_columns = set()
     csv_data: Dict = {}
     with conn.cursor() as cursor:
@@ -87,72 +94,79 @@ def get_trait_csv_sample_data(conn: Any,
                 csv_data[sample][case_attr] = None if value == "x" else value
                 case_attr_columns.add(case_attr)
         if not case_attr_columns:
-            return ("Strain Name,Value,SE,Count\n" +
-                    "\n".join(csv_data.keys()))
+            return "Strain Name,Value,SE,Count\n" + "\n".join(csv_data.keys())
         columns = sorted(case_attr_columns)
-        csv = ("Strain Name,Value,SE,Count," +
-               ",".join(columns) + "\n")
+        csv = "Strain Name,Value,SE,Count," + ",".join(columns) + "\n"
         for key, value in csv_data.items():
             if not value:
-                csv += (key + (len(case_attr_columns) * ",x") + "\n")
+                csv += key + (len(case_attr_columns) * ",x") + "\n"
             else:
                 vals = [str(value.get(column, "x")) for column in columns]
-                csv += (key + "," + ",".join(vals) + "\n")
+                csv += key + "," + ",".join(vals) + "\n"
         return csv
     return "No Sample Data Found"
 
 
-def get_sample_data_ids(conn: Any, publishxref_id: int,
-                        phenotype_id: int,
-                        strain_name: str) -> Tuple:
+def get_sample_data_ids(
+    conn: Any, publishxref_id: int, phenotype_id: int, strain_name: str
+) -> Tuple:
     """Get the strain_id, publishdata_id and inbredset_id for a given strain"""
     strain_id, publishdata_id, inbredset_id = None, None, None
     with conn.cursor() as cursor:
-        cursor.execute("SELECT st.id, pd.Id, pf.InbredSetId "
-                       "FROM PublishData pd "
-                       "JOIN Strain st ON pd.StrainId = st.Id "
-                       "JOIN PublishXRef px ON px.DataId = pd.Id "
-                       "JOIN PublishFreeze pf ON pf.InbredSetId "
-                       "= px.InbredSetId WHERE px.Id = %s "
-                       "AND px.PhenotypeId = %s AND st.Name = %s",
-                       (publishxref_id, phenotype_id, strain_name))
+        cursor.execute(
+            "SELECT st.id, pd.Id, pf.InbredSetId "
+            "FROM PublishData pd "
+            "JOIN Strain st ON pd.StrainId = st.Id "
+            "JOIN PublishXRef px ON px.DataId = pd.Id "
+            "JOIN PublishFreeze pf ON pf.InbredSetId "
+            "= px.InbredSetId WHERE px.Id = %s "
+            "AND px.PhenotypeId = %s AND st.Name = %s",
+            (publishxref_id, phenotype_id, strain_name),
+        )
         if _result := cursor.fetchone():
             strain_id, publishdata_id, inbredset_id = _result
         if not all([strain_id, publishdata_id, inbredset_id]):
             # Applies for data to be inserted:
-            cursor.execute("SELECT DataId, InbredSetId FROM PublishXRef "
-                           "WHERE Id = %s AND PhenotypeId = %s",
-                           (publishxref_id, phenotype_id))
+            cursor.execute(
+                "SELECT DataId, InbredSetId FROM PublishXRef "
+                "WHERE Id = %s AND PhenotypeId = %s",
+                (publishxref_id, phenotype_id),
+            )
             publishdata_id, inbredset_id = cursor.fetchone()
-            cursor.execute("SELECT Id FROM Strain WHERE Name = %s",
-                           (strain_name,))
+            cursor.execute(
+                "SELECT Id FROM Strain WHERE Name = %s", (strain_name,)
+            )
             strain_id = cursor.fetchone()[0]
     return (strain_id, publishdata_id, inbredset_id)
 
 
 # pylint: disable=[R0913, R0914]
-def update_sample_data(conn: Any,
-                       trait_name: str,
-                       original_data: str,
-                       updated_data: str,
-                       csv_header: str,
-                       phenotype_id: int) -> int:
+def update_sample_data(
+    conn: Any,
+    trait_name: str,
+    original_data: str,
+    updated_data: str,
+    csv_header: str,
+    phenotype_id: int,
+) -> int:
     """Given the right parameters, update sample-data from the relevant
     table."""
+
     def __update_data(conn, table, value):
         if value and value != "x":
             with conn.cursor() as cursor:
-                sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) +
-                             " = %s")
+                sub_query = " = %s AND ".join(_MAP.get(table)[:2]) + " = %s"
                 _val = _MAP.get(table)[-1]
-                cursor.execute((f"UPDATE {table} SET {_val} = %s "
-                                f"WHERE {sub_query}"),
-                               (value, strain_id, data_id))
+                cursor.execute(
+                    (f"UPDATE {table} SET {_val} = %s " f"WHERE {sub_query}"),
+                    (value, strain_id, data_id),
+                )
                 return cursor.rowcount
         return 0
 
-    def __update_case_attribute(conn, value, strain_id,
-                                case_attr, inbredset_id):
+    def __update_case_attribute(
+        conn, value, strain_id, case_attr, inbredset_id
+    ):
         if value != "x":
             with conn.cursor() as cursor:
                 cursor.execute(
@@ -162,14 +176,17 @@ def update_sample_data(conn: Any,
                     "(SELECT CaseAttributeId FROM "
                     "CaseAttribute WHERE Name = %s) "
                     "AND InbredSetId = %s",
-                    (value, strain_id, case_attr, inbredset_id))
+                    (value, strain_id, case_attr, inbredset_id),
+                )
                 return cursor.rowcount
         return 0
 
     strain_id, data_id, inbredset_id = get_sample_data_ids(
-        conn=conn, publishxref_id=int(trait_name),
+        conn=conn,
+        publishxref_id=int(trait_name),
         phenotype_id=phenotype_id,
-        strain_name=extract_strain_name(csv_header, original_data))
+        strain_name=extract_strain_name(csv_header, original_data),
+    )
 
     none_case_attrs: Dict[str, Callable] = {
         "Strain Name": lambda x: 0,
@@ -179,15 +196,16 @@ def update_sample_data(conn: Any,
     }
     count = 0
     try:
-        __actions = __extract_actions(original_data=original_data,
-                                      updated_data=updated_data,
-                                      csv_header=csv_header)
+        __actions = __extract_actions(
+            original_data=original_data,
+            updated_data=updated_data,
+            csv_header=csv_header,
+        )
         if __actions.get("update"):
             _csv_header = __actions["update"]["csv_header"]
             _data = __actions["update"]["data"]
             # pylint: disable=[E1101]
-            for header, value in zip(_csv_header.split(","),
-                                     _data.split(",")):
+            for header, value in zip(_csv_header.split(","), _data.split(",")):
                 header = header.strip()
                 value = value.strip()
                 if header in none_case_attrs:
@@ -198,14 +216,16 @@ def update_sample_data(conn: Any,
                         value=none_case_attrs[header](value),
                         strain_id=strain_id,
                         case_attr=header,
-                        inbredset_id=inbredset_id)
+                        inbredset_id=inbredset_id,
+                    )
         if __actions.get("delete"):
             _rowcount = delete_sample_data(
                 conn=conn,
                 trait_name=trait_name,
                 data=__actions["delete"]["data"],
                 csv_header=__actions["delete"]["csv_header"],
-                phenotype_id=phenotype_id)
+                phenotype_id=phenotype_id,
+            )
             if _rowcount:
                 count += 1
         if __actions.get("insert"):
@@ -214,7 +234,8 @@ def update_sample_data(conn: Any,
                 trait_name=trait_name,
                 data=__actions["insert"]["data"],
                 csv_header=__actions["insert"]["csv_header"],
-                phenotype_id=phenotype_id)
+                phenotype_id=phenotype_id,
+            )
             if _rowcount:
                 count += 1
     except Exception as _e:
@@ -224,23 +245,22 @@ def update_sample_data(conn: Any,
     return count
 
 
-def delete_sample_data(conn: Any,
-                       trait_name: str,
-                       data: str,
-                       csv_header: str,
-                       phenotype_id: int) -> int:
+def delete_sample_data(
+    conn: Any, trait_name: str, data: str, csv_header: str, phenotype_id: int
+) -> int:
     """Given the right parameters, delete sample-data from the relevant
     tables."""
+
     def __delete_data(conn, table):
-        sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
+        sub_query = " = %s AND ".join(_MAP.get(table)[:2]) + " = %s"
         with conn.cursor() as cursor:
-            cursor.execute((f"DELETE FROM {table} "
-                            f"WHERE {sub_query}"),
-                           (strain_id, data_id))
+            cursor.execute(
+                (f"DELETE FROM {table} " f"WHERE {sub_query}"),
+                (strain_id, data_id),
+            )
             return cursor.rowcount
 
-    def __delete_case_attribute(conn, strain_id,
-                                case_attr, inbredset_id):
+    def __delete_case_attribute(conn, strain_id, case_attr, inbredset_id):
         with conn.cursor() as cursor:
             cursor.execute(
                 "DELETE FROM CaseAttributeXRefNew "
@@ -248,13 +268,16 @@ def delete_sample_data(conn: Any,
                 "(SELECT CaseAttributeId FROM "
                 "CaseAttribute WHERE Name = %s) "
                 "AND InbredSetId = %s",
-                (strain_id, case_attr, inbredset_id))
+                (strain_id, case_attr, inbredset_id),
+            )
             return cursor.rowcount
 
     strain_id, data_id, inbredset_id = get_sample_data_ids(
-        conn=conn, publishxref_id=int(trait_name),
+        conn=conn,
+        publishxref_id=int(trait_name),
         phenotype_id=phenotype_id,
-        strain_name=extract_strain_name(csv_header, data))
+        strain_name=extract_strain_name(csv_header, data),
+    )
 
     none_case_attrs: Dict[str, Any] = {
         "Strain Name": lambda: 0,
@@ -274,7 +297,8 @@ def delete_sample_data(conn: Any,
                     conn=conn,
                     strain_id=strain_id,
                     case_attr=header,
-                    inbredset_id=inbredset_id)
+                    inbredset_id=inbredset_id,
+                )
     except Exception as _e:
         conn.rollback()
         raise MySQLdb.Error(_e) from _e
@@ -283,52 +307,59 @@ def delete_sample_data(conn: Any,
 
 
 # pylint: disable=[R0913, R0914]
-def insert_sample_data(conn: Any,
-                       trait_name: str,
-                       data: str,
-                       csv_header: str,
-                       phenotype_id: int) -> int:
-    """Given the right parameters, insert sample-data to the relevant table.
+def insert_sample_data(
+    conn: Any, trait_name: str, data: str, csv_header: str, phenotype_id: int
+) -> int:
+    """Given the right parameters, insert sample-data to the relevant table."""
 
-    """
     def __insert_data(conn, table, value):
         if value and value != "x":
             with conn.cursor() as cursor:
                 columns = ", ".join(_MAP.get(table))
-                cursor.execute((f"INSERT INTO {table} "
-                                f"({columns}) "
-                                f"VALUES (%s, %s, %s)"),
-                               (strain_id, data_id, value))
+                cursor.execute(
+                    (
+                        f"INSERT INTO {table} "
+                        f"({columns}) "
+                        f"VALUES (%s, %s, %s)"
+                    ),
+                    (strain_id, data_id, value),
+                )
                 return cursor.rowcount
         return 0
 
     def __insert_case_attribute(conn, case_attr, value):
         if value != "x":
             with conn.cursor() as cursor:
-                cursor.execute("SELECT Id FROM "
-                               "CaseAttribute WHERE Name = %s",
-                               (case_attr,))
+                cursor.execute(
+                    "SELECT Id FROM " "CaseAttribute WHERE Name = %s",
+                    (case_attr,),
+                )
                 if case_attr_id := cursor.fetchone():
                     case_attr_id = case_attr_id[0]
-                cursor.execute("SELECT StrainId FROM "
-                               "CaseAttributeXRefNew WHERE StrainId = %s "
-                               "AND CaseAttributeId = %s "
-                               "AND InbredSetId = %s",
-                               (strain_id, case_attr_id, inbredset_id))
+                cursor.execute(
+                    "SELECT StrainId FROM "
+                    "CaseAttributeXRefNew WHERE StrainId = %s "
+                    "AND CaseAttributeId = %s "
+                    "AND InbredSetId = %s",
+                    (strain_id, case_attr_id, inbredset_id),
+                )
                 if (not cursor.fetchone()) and case_attr_id:
                     cursor.execute(
                         "INSERT INTO CaseAttributeXRefNew "
                         "(StrainId, CaseAttributeId, Value, InbredSetId) "
                         "VALUES (%s, %s, %s, %s)",
-                        (strain_id, case_attr_id, value, inbredset_id))
+                        (strain_id, case_attr_id, value, inbredset_id),
+                    )
                     row_count = cursor.rowcount
                     return row_count
         return 0
 
     strain_id, data_id, inbredset_id = get_sample_data_ids(
-        conn=conn, publishxref_id=int(trait_name),
+        conn=conn,
+        publishxref_id=int(trait_name),
         phenotype_id=phenotype_id,
-        strain_name=extract_strain_name(csv_header, data))
+        strain_name=extract_strain_name(csv_header, data),
+    )
 
     none_case_attrs: Dict[str, Any] = {
         "Strain Name": lambda _: 0,
@@ -345,7 +376,8 @@ def insert_sample_data(conn: Any,
             cursor.execute(
                 "SELECT Id FROM PublishData where Id = %s "
                 "AND StrainId = %s",
-                (data_id, strain_id))
+                (data_id, strain_id),
+            )
         if cursor.fetchone():  # Data already exists
             return count
 
@@ -356,9 +388,8 @@ def insert_sample_data(conn: Any,
                 count += none_case_attrs[header](value)
             else:
                 count += __insert_case_attribute(
-                    conn=conn,
-                    case_attr=header,
-                    value=value)
+                    conn=conn, case_attr=header, value=value
+                )
         return count
     except Exception as _e:
         conn.rollback()