diff options
Diffstat (limited to 'gn3/csvcmp.py')
-rw-r--r-- | gn3/csvcmp.py | 146 |
1 files changed, 146 insertions, 0 deletions
diff --git a/gn3/csvcmp.py b/gn3/csvcmp.py new file mode 100644 index 0000000..8db89ca --- /dev/null +++ b/gn3/csvcmp.py @@ -0,0 +1,146 @@ +"""This module contains functions for manipulating and working with csv +texts""" +from typing import Any, List + +import json +import os +import uuid +from gn3.commands import run_cmd + + +def extract_strain_name(csv_header, data, seek="Strain Name") -> str: + """Extract a strain's name given a csv header""" + for column, value in zip(csv_header.split(","), data.split(",")): + if seek in column: + return value + return "" + + +def create_dirs_if_not_exists(dirs: list) -> None: + """Create directories from a list""" + for dir_ in dirs: + if not os.path.exists(dir_): + os.makedirs(dir_) + + +def remove_insignificant_edits(diff_data, epsilon=0.001): + """Remove or ignore edits that are not within ε""" + __mod = [] + if diff_data.get("Modifications"): + for mod in diff_data.get("Modifications"): + original = mod.get("Original").split(",") + current = mod.get("Current").split(",") + for i, (_x, _y) in enumerate(zip(original, current)): + if ( + _x.replace(".", "").isdigit() + and _y.replace(".", "").isdigit() + and abs(float(_x) - float(_y)) < epsilon + ): + current[i] = _x + if not (__o := ",".join(original)) == (__c := ",".join(current)): + __mod.append( + { + "Original": __o, + "Current": __c, + } + ) + diff_data["Modifications"] = __mod + return diff_data + + +def clean_csv_text(csv_text: str) -> str: + """Remove extra white space elements in all elements of the CSV file""" + _csv_text = [] + for line in csv_text.strip().split("\n"): + _csv_text.append( + ",".join([el.strip() for el in line.split(",")])) + return "\n".join(_csv_text) + + +def csv_diff(base_csv, delta_csv, tmp_dir="/tmp") -> dict: + """Diff 2 csv strings""" + base_csv = clean_csv_text(base_csv) + delta_csv = clean_csv_text(delta_csv) + base_csv_list = base_csv.split("\n") + delta_csv_list = delta_csv.split("\n") + + base_csv_header, delta_csv_header = "", "" + for i, line in enumerate(base_csv_list): + if line.startswith("Strain Name,Value,SE,Count"): + base_csv_header, delta_csv_header = line, delta_csv_list[i] + break + longest_header = max(base_csv_header, delta_csv_header) + + if base_csv_header != delta_csv_header: + if longest_header != base_csv_header: + base_csv = base_csv.replace("Strain Name,Value,SE,Count", + longest_header, 1) + else: + delta_csv = delta_csv.replace( + "Strain Name,Value,SE,Count", longest_header, 1 + ) + file_name1 = os.path.join(tmp_dir, str(uuid.uuid4())) + file_name2 = os.path.join(tmp_dir, str(uuid.uuid4())) + + with open(file_name1, "w", encoding="utf-8") as _f: + _l = len(longest_header.split(",")) + _f.write(fill_csv(csv_text=base_csv, width=_l)) + with open(file_name2, "w", encoding="utf-8") as _f: + _f.write(fill_csv(delta_csv, width=_l)) + + # Now we can run the diff! + _r = run_cmd(cmd=('"csvdiff ' + f"{file_name1} {file_name2} " + '--format json"')) + if _r.get("code") == 0: + _r = json.loads(_r.get("output", "")) + if any(_r.values()): + _r["Columns"] = max(base_csv_header, delta_csv_header) + else: + _r = {} + + # Clean Up! + if os.path.exists(file_name1): + os.remove(file_name1) + if os.path.exists(file_name2): + os.remove(file_name2) + return _r + + +def fill_csv(csv_text, width, value="x"): + """Fill a csv text with 'value' if it's length is less than width""" + data = [] + for line in csv_text.strip().split("\n"): + if line.startswith("Strain") or line.startswith("#"): + data.append(line) + elif line: + _n = line.split(",") + for i, val in enumerate(_n): + if not val.strip(): + _n[i] = value + data.append(",".join(_n + [value] * (width - len(_n)))) + return "\n".join(data) + + +def get_allowable_sampledata_headers(conn: Any) -> List: + """Get a list of all the case-attributes stored in the database""" + attributes = ["Strain Name", "Value", "SE", "Count"] + with conn.cursor() as cursor: + cursor.execute("SELECT Name from CaseAttribute") + attributes += [attributes[0] for attributes in + cursor.fetchall()] + return attributes + + +def extract_invalid_csv_headers(allowed_headers: List, csv_text: str) -> List: + """Check whether a csv text's columns contains valid headers""" + csv_header = [] + for line in csv_text.split("\n"): + if line.startswith("Strain Name"): + csv_header = [_l.strip() for _l in line.split(",")] + break + invalid_headers = [] + for header in csv_header: + if header not in allowed_headers: + invalid_headers.append(header) + return invalid_headers |