about summary refs log tree commit diff
path: root/gn3/csvcmp.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/csvcmp.py')
-rw-r--r--gn3/csvcmp.py146
1 files changed, 146 insertions, 0 deletions
diff --git a/gn3/csvcmp.py b/gn3/csvcmp.py
new file mode 100644
index 0000000..8db89ca
--- /dev/null
+++ b/gn3/csvcmp.py
@@ -0,0 +1,146 @@
+"""This module contains functions for manipulating and working with csv
+texts"""
+from typing import Any, List
+
+import json
+import os
+import uuid
+from gn3.commands import run_cmd
+
+
+def extract_strain_name(csv_header, data, seek="Strain Name") -> str:
+    """Extract a strain's name given a csv header"""
+    for column, value in zip(csv_header.split(","), data.split(",")):
+        if seek in column:
+            return value
+    return ""
+
+
+def create_dirs_if_not_exists(dirs: list) -> None:
+    """Create directories from a list"""
+    for dir_ in dirs:
+        if not os.path.exists(dir_):
+            os.makedirs(dir_)
+
+
+def remove_insignificant_edits(diff_data, epsilon=0.001):
+    """Remove or ignore edits that are not within ε"""
+    __mod = []
+    if diff_data.get("Modifications"):
+        for mod in diff_data.get("Modifications"):
+            original = mod.get("Original").split(",")
+            current = mod.get("Current").split(",")
+            for i, (_x, _y) in enumerate(zip(original, current)):
+                if (
+                    _x.replace(".", "").isdigit()
+                    and _y.replace(".", "").isdigit()
+                    and abs(float(_x) - float(_y)) < epsilon
+                ):
+                    current[i] = _x
+            if not (__o := ",".join(original)) == (__c := ",".join(current)):
+                __mod.append(
+                    {
+                        "Original": __o,
+                        "Current": __c,
+                    }
+                )
+        diff_data["Modifications"] = __mod
+    return diff_data
+
+
+def clean_csv_text(csv_text: str) -> str:
+    """Remove extra white space elements in all elements of the CSV file"""
+    _csv_text = []
+    for line in csv_text.strip().split("\n"):
+        _csv_text.append(
+            ",".join([el.strip() for el in line.split(",")]))
+    return "\n".join(_csv_text)
+
+
+def csv_diff(base_csv, delta_csv, tmp_dir="/tmp") -> dict:
+    """Diff 2 csv strings"""
+    base_csv = clean_csv_text(base_csv)
+    delta_csv = clean_csv_text(delta_csv)
+    base_csv_list = base_csv.split("\n")
+    delta_csv_list = delta_csv.split("\n")
+
+    base_csv_header, delta_csv_header = "", ""
+    for i, line in enumerate(base_csv_list):
+        if line.startswith("Strain Name,Value,SE,Count"):
+            base_csv_header, delta_csv_header = line, delta_csv_list[i]
+            break
+    longest_header = max(base_csv_header, delta_csv_header)
+
+    if base_csv_header != delta_csv_header:
+        if longest_header != base_csv_header:
+            base_csv = base_csv.replace("Strain Name,Value,SE,Count",
+                                        longest_header, 1)
+        else:
+            delta_csv = delta_csv.replace(
+                "Strain Name,Value,SE,Count", longest_header, 1
+            )
+    file_name1 = os.path.join(tmp_dir, str(uuid.uuid4()))
+    file_name2 = os.path.join(tmp_dir, str(uuid.uuid4()))
+
+    with open(file_name1, "w", encoding="utf-8") as _f:
+        _l = len(longest_header.split(","))
+        _f.write(fill_csv(csv_text=base_csv, width=_l))
+    with open(file_name2, "w", encoding="utf-8") as _f:
+        _f.write(fill_csv(delta_csv, width=_l))
+
+    # Now we can run the diff!
+    _r = run_cmd(cmd=('"csvdiff '
+                      f"{file_name1} {file_name2} "
+                      '--format json"'))
+    if _r.get("code") == 0:
+        _r = json.loads(_r.get("output", ""))
+        if any(_r.values()):
+            _r["Columns"] = max(base_csv_header, delta_csv_header)
+    else:
+        _r = {}
+
+    # Clean Up!
+    if os.path.exists(file_name1):
+        os.remove(file_name1)
+    if os.path.exists(file_name2):
+        os.remove(file_name2)
+    return _r
+
+
+def fill_csv(csv_text, width, value="x"):
+    """Fill a csv text with 'value' if it's length is less than width"""
+    data = []
+    for line in csv_text.strip().split("\n"):
+        if line.startswith("Strain") or line.startswith("#"):
+            data.append(line)
+        elif line:
+            _n = line.split(",")
+            for i, val in enumerate(_n):
+                if not val.strip():
+                    _n[i] = value
+            data.append(",".join(_n + [value] * (width - len(_n))))
+    return "\n".join(data)
+
+
+def get_allowable_sampledata_headers(conn: Any) -> List:
+    """Get a list of all the case-attributes stored in the database"""
+    attributes = ["Strain Name", "Value", "SE", "Count"]
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT Name from CaseAttribute")
+        attributes += [attributes[0] for attributes in
+                       cursor.fetchall()]
+    return attributes
+
+
+def extract_invalid_csv_headers(allowed_headers: List, csv_text: str) -> List:
+    """Check whether a csv text's columns contains valid headers"""
+    csv_header = []
+    for line in csv_text.split("\n"):
+        if line.startswith("Strain Name"):
+            csv_header = [_l.strip() for _l in line.split(",")]
+            break
+    invalid_headers = []
+    for header in csv_header:
+        if header not in allowed_headers:
+            invalid_headers.append(header)
+    return invalid_headers