Diffstat (limited to 'gn3/db/sample_data.py')
-rw-r--r--  gn3/db/sample_data.py | 365
1 file changed, 365 insertions(+), 0 deletions(-)
diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py
new file mode 100644
index 0000000..f73954f
--- /dev/null
+++ b/gn3/db/sample_data.py
@@ -0,0 +1,365 @@
+"""Module containing functions that work with sample data"""
+from typing import Any, Tuple, Dict, Callable
+
+import MySQLdb
+
+from gn3.csvcmp import extract_strain_name
+
+
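+# Maps each sample-data table to the columns that identify and hold a
+# record: (strain-id column, data-id column, value column).  Note that
+# PublishData keys its data rows by "Id" while PublishSE and NStrain
+# use "DataId".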
+_MAP = {
+ "PublishData": ("StrainId", "Id", "value"),
+ "PublishSE": ("StrainId", "DataId", "error"),
+ "NStrain": ("StrainId", "DataId", "count"),
+}
+
+
+def __extract_actions(original_data: str,
+ updated_data: str,
+ csv_header: str) -> Dict:
+ """Return a dictionary containing elements that need to be deleted, inserted,
+or updated.
+
+ """
+ result: Dict[str, Any] = {
+ "delete": {"data": [], "csv_header": []},
+ "insert": {"data": [], "csv_header": []},
+ "update": {"data": [], "csv_header": []},
+ }
+ strain_name = ""
+ for _o, _u, _h in zip(original_data.strip().split(","),
+ updated_data.strip().split(","),
+ csv_header.strip().split(",")):
+ if _h == "Strain Name":
+ strain_name = _o
+ if _o == _u: # No change
+ continue
+ if _o and _u == "x": # Deletion
+ result["delete"]["data"].append(_o)
+ result["delete"]["csv_header"].append(_h)
+ elif _o == "x" and _u: # Insert
+ result["insert"]["data"].append(_u)
+ result["insert"]["csv_header"].append(_h)
+ elif _o and _u: # Update
+ result["update"]["data"].append(_u)
+ result["update"]["csv_header"].append(_h)
+ for key, val in result.items():
+ if not val["data"]:
+ result[key] = None
+ else:
+ result[key]["data"] = (f"{strain_name}," +
+ ",".join(result[key]["data"]))
+ result[key]["csv_header"] = ("Strain Name," +
+ ",".join(result[key]["csv_header"]))
+ return result
+
+
+def get_trait_csv_sample_data(conn: Any,
+ trait_name: int, phenotype_id: int) -> str:
+ """Fetch a trait and return it as a csv string"""
+ __query = ("SELECT concat(st.Name, ',', ifnull(pd.value, 'x'), ',', "
+ "ifnull(ps.error, 'x'), ',', ifnull(ns.count, 'x')) as 'Data' "
+ ",ifnull(ca.Name, 'x') as 'CaseAttr', "
+ "ifnull(cxref.value, 'x') as 'Value' "
+ "FROM PublishFreeze pf "
+ "JOIN PublishXRef px ON px.InbredSetId = pf.InbredSetId "
+ "JOIN PublishData pd ON pd.Id = px.DataId "
+ "JOIN Strain st ON pd.StrainId = st.Id "
+ "LEFT JOIN PublishSE ps ON ps.DataId = pd.Id "
+ "AND ps.StrainId = pd.StrainId "
+ "LEFT JOIN NStrain ns ON ns.DataId = pd.Id "
+ "AND ns.StrainId = pd.StrainId "
+ "LEFT JOIN CaseAttributeXRefNew cxref ON "
+ "(cxref.InbredSetId = px.InbredSetId AND "
+ "cxref.StrainId = st.Id) "
+ "LEFT JOIN CaseAttribute ca ON ca.Id = cxref.CaseAttributeId "
+ "WHERE px.Id = %s AND px.PhenotypeId = %s ORDER BY st.Name")
+ case_attr_columns = set()
+ csv_data: Dict = {}
+ with conn.cursor() as cursor:
+ cursor.execute(__query, (trait_name, phenotype_id))
+ for data in cursor.fetchall():
+ if data[1] == "x":
+ csv_data[data[0]] = None
+ else:
+ sample, case_attr, value = data[0], data[1], data[2]
+ if not csv_data.get(sample):
+ csv_data[sample] = {}
+ csv_data[sample][case_attr] = None if value == "x" else value
+ case_attr_columns.add(case_attr)
+    if not csv_data:
+        return "No Sample Data Found"
+    if not case_attr_columns:
+        return ("Strain Name,Value,SE,Count\n" +
+                "\n".join(csv_data.keys()))
+ columns = sorted(case_attr_columns)
+ csv = ("Strain Name,Value,SE,Count," +
+ ",".join(columns) + "\n")
+ for key, value in csv_data.items():
+ if not value:
+ csv += (key + (len(case_attr_columns) * ",x") + "\n")
+ else:
+ vals = [str(value.get(column, "x")) for column in columns]
+ csv += (key + "," + ",".join(vals) + "\n")
+    return csv
+
+
+def get_sample_data_ids(conn: Any, publishxref_id: int,
+ phenotype_id: int,
+ strain_name: str) -> Tuple:
+ """Get the strain_id, publishdata_id and inbredset_id for a given strain"""
+ strain_id, publishdata_id, inbredset_id = None, None, None
+ with conn.cursor() as cursor:
+ cursor.execute("SELECT st.id, pd.Id, pf.InbredSetId "
+ "FROM PublishData pd "
+ "JOIN Strain st ON pd.StrainId = st.Id "
+ "JOIN PublishXRef px ON px.DataId = pd.Id "
+ "JOIN PublishFreeze pf ON pf.InbredSetId "
+ "= px.InbredSetId WHERE px.Id = %s "
+ "AND px.PhenotypeId = %s AND st.Name = %s",
+ (publishxref_id, phenotype_id, strain_name))
+ if _result := cursor.fetchone():
+ strain_id, publishdata_id, inbredset_id = _result
+ if not all([strain_id, publishdata_id, inbredset_id]):
+            # No existing data row; fetch the ids needed for an insert:
+ cursor.execute("SELECT DataId, InbredSetId FROM PublishXRef "
+ "WHERE Id = %s AND PhenotypeId = %s",
+ (publishxref_id, phenotype_id))
+ publishdata_id, inbredset_id = cursor.fetchone()
+ cursor.execute("SELECT Id FROM Strain WHERE Name = %s",
+ (strain_name,))
+ strain_id = cursor.fetchone()[0]
+ return (strain_id, publishdata_id, inbredset_id)
+
+
+# pylint: disable=[R0913, R0914]
+def update_sample_data(conn: Any,
+ trait_name: str,
+ original_data: str,
+ updated_data: str,
+ csv_header: str,
+ phenotype_id: int) -> int:
+ """Given the right parameters, update sample-data from the relevant
+ table."""
+ def __update_data(conn, table, value):
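+        # "table" is always one of the fixed _MAP keys, never user
+        # input, so interpolating it (and its column names) into the
+        # SQL is safe; the data values are bound as parameters.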
+ if value and value != "x":
+ with conn.cursor() as cursor:
+ sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) +
+ " = %s")
+ _val = _MAP.get(table)[-1]
+ cursor.execute((f"UPDATE {table} SET {_val} = %s "
+ f"WHERE {sub_query}"),
+ (value, strain_id, data_id))
+ return cursor.rowcount
+ return 0
+
+ def __update_case_attribute(conn, value, strain_id,
+ case_attr, inbredset_id):
+ if value != "x":
+ with conn.cursor() as cursor:
+ cursor.execute(
+ "UPDATE CaseAttributeXRefNew "
+ "SET Value = %s "
+ "WHERE StrainId = %s AND CaseAttributeId = "
+ "(SELECT CaseAttributeId FROM "
+ "CaseAttribute WHERE Name = %s) "
+ "AND InbredSetId = %s",
+ (value, strain_id, case_attr, inbredset_id))
+ return cursor.rowcount
+ return 0
+
+ strain_id, data_id, inbredset_id = get_sample_data_ids(
+ conn=conn, publishxref_id=int(trait_name),
+ phenotype_id=phenotype_id,
+ strain_name=extract_strain_name(csv_header, original_data))
+
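+    # Dispatch table for the fixed csv columns; any header not listed
+    # here is handled as a case attribute instead.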
+ none_case_attrs: Dict[str, Callable] = {
+ "Strain Name": lambda x: 0,
+ "Value": lambda x: __update_data(conn, "PublishData", x),
+ "SE": lambda x: __update_data(conn, "PublishSE", x),
+ "Count": lambda x: __update_data(conn, "NStrain", x),
+ }
+ count = 0
+ try:
+ __actions = __extract_actions(original_data=original_data,
+ updated_data=updated_data,
+ csv_header=csv_header)
+ if __actions.get("update"):
+ _csv_header = __actions["update"]["csv_header"]
+ _data = __actions["update"]["data"]
+ # pylint: disable=[E1101]
+ for header, value in zip(_csv_header.split(","),
+ _data.split(",")):
+ header = header.strip()
+ value = value.strip()
+ if header in none_case_attrs:
+ count += none_case_attrs[header](value)
+ else:
+                    count += __update_case_attribute(
+                        conn=conn,
+                        value=value,
+                        strain_id=strain_id,
+                        case_attr=header,
+                        inbredset_id=inbredset_id)
+ if __actions.get("delete"):
+ _rowcount = delete_sample_data(
+ conn=conn,
+ trait_name=trait_name,
+ data=__actions["delete"]["data"],
+ csv_header=__actions["delete"]["csv_header"],
+ phenotype_id=phenotype_id)
+ if _rowcount:
+ count += 1
+ if __actions.get("insert"):
+ _rowcount = insert_sample_data(
+ conn=conn,
+ trait_name=trait_name,
+ data=__actions["insert"]["data"],
+ csv_header=__actions["insert"]["csv_header"],
+ phenotype_id=phenotype_id)
+ if _rowcount:
+ count += 1
+ except Exception as _e:
+ conn.rollback()
+ raise MySQLdb.Error(_e) from _e
+ conn.commit()
+ return count
+
+
+def delete_sample_data(conn: Any,
+ trait_name: str,
+ data: str,
+ csv_header: str,
+ phenotype_id: int) -> int:
+ """Given the right parameters, delete sample-data from the relevant
+ tables."""
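+    # Every csv header column is processed: "Strain Name" is a no-op,
+    # Value/SE/Count map to table deletes via _MAP, and any other
+    # column is removed as a case attribute.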
+ def __delete_data(conn, table):
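+        # Joins the first two _MAP column names into
+        # "StrainId = %s AND <id column> = %s"; the table name itself
+        # is a fixed _MAP key, not user input.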
+ sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
+ with conn.cursor() as cursor:
+ cursor.execute((f"DELETE FROM {table} "
+ f"WHERE {sub_query}"),
+ (strain_id, data_id))
+ return cursor.rowcount
+
+ def __delete_case_attribute(conn, strain_id,
+ case_attr, inbredset_id):
+ with conn.cursor() as cursor:
+ cursor.execute(
+ "DELETE FROM CaseAttributeXRefNew "
+ "WHERE StrainId = %s AND CaseAttributeId = "
+ "(SELECT CaseAttributeId FROM "
+ "CaseAttribute WHERE Name = %s) "
+ "AND InbredSetId = %s",
+ (strain_id, case_attr, inbredset_id))
+ return cursor.rowcount
+
+ strain_id, data_id, inbredset_id = get_sample_data_ids(
+ conn=conn, publishxref_id=int(trait_name),
+ phenotype_id=phenotype_id,
+ strain_name=extract_strain_name(csv_header, data))
+
+ none_case_attrs: Dict[str, Any] = {
+ "Strain Name": lambda: 0,
+ "Value": lambda: __delete_data(conn, "PublishData"),
+ "SE": lambda: __delete_data(conn, "PublishSE"),
+ "Count": lambda: __delete_data(conn, "NStrain"),
+ }
+ count = 0
+
+ try:
+ for header in csv_header.split(","):
+ header = header.strip()
+ if header in none_case_attrs:
+ count += none_case_attrs[header]()
+ else:
+ count += __delete_case_attribute(
+ conn=conn,
+ strain_id=strain_id,
+ case_attr=header,
+ inbredset_id=inbredset_id)
+ except Exception as _e:
+ conn.rollback()
+ raise MySQLdb.Error(_e) from _e
+ conn.commit()
+ return count
+
+
+# pylint: disable=[R0913, R0914]
+def insert_sample_data(conn: Any,
+ trait_name: str,
+ data: str,
+ csv_header: str,
+ phenotype_id: int) -> int:
+ """Given the right parameters, insert sample-data to the relevant table.
+
+ """
+ def __insert_data(conn, table, value):
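+        # Only insert real values; "x" is the csv placeholder for a
+        # missing cell.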
+ if value and value != "x":
+ with conn.cursor() as cursor:
+ columns = ", ".join(_MAP.get(table))
+ cursor.execute((f"INSERT INTO {table} "
+ f"({columns}) "
+ f"VALUES (%s, %s, %s)"),
+ (strain_id, data_id, value))
+ return cursor.rowcount
+ return 0
+
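+    # Case attributes are stored by id: look the id up by name, then
+    # skip the insert when the (strain, attribute, inbredset) triple
+    # already has a row in CaseAttributeXRefNew.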
+ def __insert_case_attribute(conn, case_attr, value):
+ if value != "x":
+ with conn.cursor() as cursor:
+ cursor.execute("SELECT Id FROM "
+ "CaseAttribute WHERE Name = %s",
+ (case_attr,))
+ if case_attr_id := cursor.fetchone():
+ case_attr_id = case_attr_id[0]
+ cursor.execute("SELECT StrainId FROM "
+ "CaseAttributeXRefNew WHERE StrainId = %s "
+ "AND CaseAttributeId = %s "
+ "AND InbredSetId = %s",
+ (strain_id, case_attr_id, inbredset_id))
+                row_count = 0
+                if (not cursor.fetchone()) and case_attr_id:
+                    cursor.execute(
+                        "INSERT INTO CaseAttributeXRefNew "
+                        "(StrainId, CaseAttributeId, Value, InbredSetId) "
+                        "VALUES (%s, %s, %s, %s)",
+                        (strain_id, case_attr_id, value, inbredset_id))
+                    row_count = cursor.rowcount
+                return row_count
+ return 0
+
+ strain_id, data_id, inbredset_id = get_sample_data_ids(
+ conn=conn, publishxref_id=int(trait_name),
+ phenotype_id=phenotype_id,
+ strain_name=extract_strain_name(csv_header, data))
+
+ none_case_attrs: Dict[str, Any] = {
+ "Strain Name": lambda _: 0,
+ "Value": lambda x: __insert_data(conn, "PublishData", x),
+ "SE": lambda x: __insert_data(conn, "PublishSE", x),
+ "Count": lambda x: __insert_data(conn, "NStrain", x),
+ }
+
+ try:
+ count = 0
+
+ # Check if the data already exists:
+ with conn.cursor() as cursor:
+ cursor.execute(
+ "SELECT Id FROM PublishData where Id = %s "
+ "AND StrainId = %s",
+ (data_id, strain_id))
+ if cursor.fetchone(): # Data already exists
+ return count
+
+ for header, value in zip(csv_header.split(","), data.split(",")):
+ header = header.strip()
+ value = value.strip()
+ if header in none_case_attrs:
+ count += none_case_attrs[header](value)
+ else:
+ count += __insert_case_attribute(
+ conn=conn,
+ case_attr=header,
+ value=value)
+        conn.commit()
+        return count
+ except Exception as _e:
+ conn.rollback()
+ raise MySQLdb.Error(_e) from _e