-rw-r--r--  gn3/csvcmp.py          64
-rw-r--r--  gn3/db/sample_data.py  93
2 files changed, 85 insertions, 72 deletions
diff --git a/gn3/csvcmp.py b/gn3/csvcmp.py
index ac09cc3..82d74d0 100644
--- a/gn3/csvcmp.py
+++ b/gn3/csvcmp.py
@@ -1,50 +1,59 @@
+"""This module contains functions for manipulating and working with csv
+texts"""
import json
import os
import uuid
from gn3.commands import run_cmd
-def extract_strain_name(csv_header, data, seek="Strain Name"):
+def extract_strain_name(csv_header, data, seek="Strain Name") -> str:
+ """Extract a strain's name given a csv header"""
for column, value in zip(csv_header.split(","), data.split(",")):
if seek in column:
return value
return ""
-def create_dirs_if_not_exists(dirs: list):
+def create_dirs_if_not_exists(dirs: list) -> None:
+ """Create directories from a list"""
for dir_ in dirs:
if not os.path.exists(dir_):
os.makedirs(dir_)
def remove_insignificant_edits(diff_data, epsilon=0.001):
- _mod = []
+ """Remove or ignore edits that are not within ε"""
+ __mod = []
for mod in diff_data.get("Modifications"):
original = mod.get("Original").split(",")
current = mod.get("Current").split(",")
- for i, (x, y) in enumerate(zip(original, current)):
- if (x.replace('.', '').isdigit() and
- y.replace('.', '').isdigit() and
- abs(float(x) - float(y)) < epsilon):
- current[i] = x
+ for i, (_x, _y) in enumerate(zip(original, current)):
+ if (
+ _x.replace(".", "").isdigit()
+ and _y.replace(".", "").isdigit()
+ and abs(float(_x) - float(_y)) < epsilon
+ ):
+ current[i] = _x
if not (__o := ",".join(original)) == (__c := ",".join(current)):
- _mod.append({
- "Original": __o,
- "Current": __c,
- })
- diff_data['Modifications'] = _mod
+ __mod.append(
+ {
+ "Original": __o,
+ "Current": __c,
+ }
+ )
+ diff_data["Modifications"] = __mod
return diff_data
-def csv_diff(base_csv, delta_csv, tmp_dir="/tmp"):
+def csv_diff(base_csv, delta_csv, tmp_dir="/tmp") -> dict:
+ """Diff 2 csv strings"""
base_csv_list = base_csv.strip().split("\n")
delta_csv_list = delta_csv.strip().split("\n")
- base_csv_header, delta_csv_header, header = "", "", ""
+ base_csv_header, delta_csv_header = "", ""
for i, line in enumerate(base_csv_list):
if line.startswith("Strain Name,Value,SE,Count"):
- header = line
- base_csv_header, delta_csv_header= line, delta_csv_list[i]
+ base_csv_header, delta_csv_header = line, delta_csv_list[i]
break
longest_header = max(base_csv_header, delta_csv_header)
@@ -53,22 +62,21 @@ def csv_diff(base_csv, delta_csv, tmp_dir="/tmp"):
base_csv = base_csv.replace("Strain Name,Value,SE,Count",
longest_header, 1)
else:
- delta_csv = delta_csv.replace("Strain Name,Value,SE,Count",
- longest_header, 1)
+ delta_csv = delta_csv.replace(
+ "Strain Name,Value,SE,Count", longest_header, 1
+ )
file_name1 = os.path.join(tmp_dir, str(uuid.uuid4()))
file_name2 = os.path.join(tmp_dir, str(uuid.uuid4()))
- with open(file_name1, "w") as f_:
+ with open(file_name1, "w", encoding="utf-8") as _f:
_l = len(longest_header.split(","))
- f_.write(fill_csv(csv_text=base_csv,
- width=_l))
- with open(file_name2, "w") as f_:
- f_.write(fill_csv(delta_csv,
- width=_l))
+ _f.write(fill_csv(csv_text=base_csv, width=_l))
+ with open(file_name2, "w", encoding="utf-8") as _f:
+ _f.write(fill_csv(delta_csv, width=_l))
# Now we can run the diff!
_r = run_cmd(cmd=('"csvdiff '
- f'{file_name1} {file_name2} '
+ f"{file_name1} {file_name2} "
'--format json"'))
if _r.get("code") == 0:
_r = json.loads(_r.get("output"))
@@ -86,6 +94,7 @@ def csv_diff(base_csv, delta_csv, tmp_dir="/tmp"):
def fill_csv(csv_text, width, value="x"):
+ """Fill a csv text with 'value' if it's length is less than width"""
data = []
for line in csv_text.strip().split("\n"):
if line.startswith("Strain") or line.startswith("#"):
@@ -95,6 +104,5 @@ def fill_csv(csv_text, width, value="x"):
for i, val in enumerate(_n):
if not val.strip():
_n[i] = value
- data.append(
- ",".join(_n + [value] * (width - len(_n))))
+ data.append(",".join(_n + [value] * (width - len(_n))))
return "\n".join(data)
diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py
index b2d6aed..e3daa21 100644
--- a/gn3/db/sample_data.py
+++ b/gn3/db/sample_data.py
@@ -1,7 +1,8 @@
-from gn3.csvcmp import extract_strain_name
-from typing import Any, Tuple
+"""Module containing functions that work with sample data"""
+from typing import Any, Tuple, Dict, Callable
import MySQLdb
+from gn3.csvcmp import extract_strain_name
_MAP = {
@@ -14,6 +15,10 @@ _MAP = {
def __extract_actions(original_data: str,
updated_data: str,
csv_header: str) -> dict:
+ """Return a dictionary containing elements that need to be deleted, inserted,
+or updated.
+
+ """
original_data = original_data.strip().split(",")
updated_data = updated_data.strip().split(",")
csv_header = csv_header.strip().split(",")
@@ -28,7 +33,7 @@ def __extract_actions(original_data: str,
strain_name = _o
if _o == _u: # No change
continue
- elif _o and _u == "x": # Deletion
+ if _o and _u == "x": # Deletion
result["delete"]["data"].append(_o)
result["delete"]["csv_header"].append(_h)
elif _o == "x" and _u: # Insert
@@ -84,23 +89,23 @@ def get_trait_csv_sample_data(conn: Any,
if not case_attr_columns:
return ("Strain Name,Value,SE,Count\n" +
"\n".join(csv_data.keys()))
- else:
- columns = sorted(case_attr_columns)
- csv = ("Strain Name,Value,SE,Count," +
- ",".join(columns) + "\n")
- for key, value in csv_data.items():
- if not value:
- csv += (key + (len(case_attr_columns) * ",x") + "\n")
- else:
- vals = [str(value.get(column, "x")) for column in columns]
- csv += (key + "," + ",".join(vals) + "\n")
- return csv
+ columns = sorted(case_attr_columns)
+ csv = ("Strain Name,Value,SE,Count," +
+ ",".join(columns) + "\n")
+ for key, value in csv_data.items():
+ if not value:
+ csv += (key + (len(case_attr_columns) * ",x") + "\n")
+ else:
+ vals = [str(value.get(column, "x")) for column in columns]
+ csv += (key + "," + ",".join(vals) + "\n")
+ return csv
return "No Sample Data Found"
def get_sample_data_ids(conn: Any, publishxref_id: int,
phenotype_id: int,
strain_name: str) -> Tuple:
+ """Get the strain_id, publishdata_id and inbredset_id for a given strain"""
strain_id, publishdata_id, inbredset_id = None, None, None
with conn.cursor() as cursor:
cursor.execute("SELECT st.id, pd.Id, pf.InbredSetId "
@@ -125,7 +130,8 @@ def get_sample_data_ids(conn: Any, publishxref_id: int,
return (strain_id, publishdata_id, inbredset_id)
-def update_sample_data(conn: Any, # pylint: disable=[R0913]
+# pylint: disable=[R0913, R0914]
+def update_sample_data(conn: Any,
trait_name: str,
original_data: str,
updated_data: str,
@@ -136,12 +142,14 @@ def update_sample_data(conn: Any, # pylint: disable=[R0913]
def __update_data(conn, table, value):
if value and value != "x":
with conn.cursor() as cursor:
- sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
+ sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) +
+ " = %s")
_val = _MAP.get(table)[-1]
cursor.execute((f"UPDATE {table} SET {_val} = %s "
f"WHERE {sub_query}"),
(value, strain_id, data_id))
return cursor.rowcount
+ return 0
def __update_case_attribute(conn, value, strain_id,
case_attr, inbredset_id):
@@ -177,6 +185,7 @@ def update_sample_data(conn: Any, # pylint: disable=[R0913]
if __actions.get("update"):
_csv_header = __actions["update"]["csv_header"]
_data = __actions["update"]["data"]
+ # pylint: disable=[E1101]
for header, value in zip(_csv_header.split(","),
_data.split(",")):
header = header.strip()
@@ -208,9 +217,9 @@ def update_sample_data(conn: Any, # pylint: disable=[R0913]
phenotype_id=phenotype_id)
if _rowcount:
count += 1
- except Exception as e: # pylint: disable=[C0103, W0612]
+ except Exception as _e:
conn.rollback()
- raise MySQLdb.Error
+ raise MySQLdb.Error(_e) from _e
conn.commit()
return count
@@ -223,28 +232,24 @@ def delete_sample_data(conn: Any,
"""Given the right parameters, delete sample-data from the relevant
tables."""
def __delete_data(conn, table):
- if value and value != "x":
- sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
- with conn.cursor() as cursor:
- cursor.execute((f"DELETE FROM {table} "
- f"WHERE {sub_query}"),
- (strain_id, data_id))
- return cursor.rowcount
- return 0
+ sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
+ with conn.cursor() as cursor:
+ cursor.execute((f"DELETE FROM {table} "
+ f"WHERE {sub_query}"),
+ (strain_id, data_id))
+ return cursor.rowcount
def __delete_case_attribute(conn, strain_id,
case_attr, inbredset_id):
- if value != "x":
- with conn.cursor() as cursor:
- cursor.execute(
- "DELETE FROM CaseAttributeXRefNew "
- "WHERE StrainId = %s AND CaseAttributeId = "
- "(SELECT CaseAttributeId FROM "
- "CaseAttribute WHERE Name = %s) "
- "AND InbredSetId = %s",
- (strain_id, case_attr, inbredset_id))
- return cursor.rowcount
- return 0
+ with conn.cursor() as cursor:
+ cursor.execute(
+ "DELETE FROM CaseAttributeXRefNew "
+ "WHERE StrainId = %s AND CaseAttributeId = "
+ "(SELECT CaseAttributeId FROM "
+ "CaseAttribute WHERE Name = %s) "
+ "AND InbredSetId = %s",
+ (strain_id, case_attr, inbredset_id))
+ return cursor.rowcount
strain_id, data_id, inbredset_id = get_sample_data_ids(
conn=conn, publishxref_id=trait_name,
@@ -260,9 +265,8 @@ def delete_sample_data(conn: Any,
count = 0
try:
- for header, value in zip(csv_header.split(","), data.split(",")):
+ for header in csv_header.split(","):
header = header.strip()
- value = value.strip()
if header in none_case_attrs:
count += none_case_attrs.get(header)()
else:
@@ -271,14 +275,15 @@ def delete_sample_data(conn: Any,
strain_id=strain_id,
case_attr=header,
inbredset_id=inbredset_id)
- except Exception as e: # pylint: disable=[C0103, W0612]
+ except Exception as _e:
conn.rollback()
- raise MySQLdb.Error
+ raise MySQLdb.Error(_e) from _e
conn.commit()
return count
-def insert_sample_data(conn: Any, # pylint: disable=[R0913]
+# pylint: disable=[R0913, R0914]
+def insert_sample_data(conn: Any,
trait_name: str,
data: str,
csv_header: str,
@@ -355,6 +360,6 @@ def insert_sample_data(conn: Any, # pylint: disable=[R0913]
case_attr=header,
value=value)
return count
- except Exception as e: # pylint: disable=[C0103, W0612]
+ except Exception as _e:
conn.rollback()
- raise MySQLdb.Error
+ raise MySQLdb.Error(_e) from _e
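
A similarly hedged sketch for the sample-data side (illustrative, not part of this commit): the connection parameters, trait id, and phenotype id are placeholders, and the call assumes update_sample_data keeps the signature visible upstream (conn, trait_name, original_data, updated_data, csv_header, phenotype_id).

# Hypothetical sketch, not from the commit; ids and credentials are placeholders.
import MySQLdb
from gn3.db.sample_data import update_sample_data

conn = MySQLdb.connect(db="db_webqtl", user="gn3", passwd="...")  # placeholder credentials
count = update_sample_data(
    conn=conn,
    trait_name="10007",  # used as the PublishXRef id; placeholder
    original_data="BXD1,18,x,0",
    updated_data="BXD1,19,x,0",
    csv_header="Strain Name,Value,SE,Count",
    phenotype_id=10,  # placeholder
)
conn.close()
print(f"{count} value(s) updated")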