aboutsummaryrefslogtreecommitdiff
path: root/gn3/csvcmp.py
blob: aa057b73ccba55a7ab093dc75480fc0c77745722 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""This module contains functions for manipulating and working with csv
texts"""
from typing import Any, List

import json
import os
import uuid
from gn3.commands import run_cmd


def extract_strain_name(csv_header, data, seek="Strain Name") -> str:
    """Return the value in `data` that sits under the `seek` column.

    `csv_header` and `data` are comma-separated strings that are paired up
    position by position.  The value of the first column whose name
    contains `seek` is returned; an empty string is returned when no
    column matches.
    """
    columns = csv_header.split(",")
    values = data.split(",")
    matches = (value for column, value in zip(columns, values)
               if seek in column)
    return next(matches, "")


def create_dirs_if_not_exists(dirs: list) -> None:
    """Create every directory in `dirs`, including missing parents.

    Directories that already exist are left untouched.  Creation is done
    with `exist_ok=True` instead of a separate existence check, so another
    process creating the same directory at the same moment does not make
    this function fail.

    Raises `FileExistsError` if one of the paths exists but is not a
    directory.
    """
    for dir_ in dirs:
        os.makedirs(dir_, exist_ok=True)


def _parse_number(text):
    """Return `text` converted to a float, or None if it is not numeric."""
    try:
        return float(text)
    except ValueError:
        return None


def remove_insignificant_edits(diff_data, epsilon=0.001):
    """Drop numeric edits that differ from the original by less than ε.

    Each entry of `diff_data["Modifications"]` holds an "Original" and a
    "Current" comma-separated row.  Fields are compared position by
    position; when both fields are numbers and their absolute difference
    is below `epsilon`, the edit is treated as noise and the current field
    is reset to the original one.  A modification is kept only if the rows
    still differ afterwards.

    `diff_data` is updated in place and also returned.  When it has no
    modifications (key missing, None or empty) it is returned unchanged.
    """
    modifications = diff_data.get("Modifications")
    if not modifications:
        # Nothing to filter, e.g. the empty dict produced by a failed diff.
        return diff_data

    significant = []
    for mod in modifications:
        original = mod.get("Original").split(",")
        current = mod.get("Current").split(",")
        for i, (old, new) in enumerate(zip(original, current)):
            old_num, new_num = _parse_number(old), _parse_number(new)
            if (
                old_num is not None
                and new_num is not None
                and abs(old_num - new_num) < epsilon
            ):
                current[i] = old
        if (__o := ",".join(original)) != (__c := ",".join(current)):
            significant.append(
                {
                    "Original": __o,
                    "Current": __c,
                }
            )
    diff_data["Modifications"] = significant
    return diff_data


def csv_diff(base_csv, delta_csv, tmp_dir="/tmp") -> dict:
    """Diff 2 csv strings with the external `csvdiff` command.

    The header of each text is the first line that starts with
    "Strain Name,Value,SE,Count".  When the two headers differ, the one
    with fewer columns is replaced by the one with more columns, and every
    data row is padded to that width (see `fill_csv`) so that both files
    have the same shape.

    Both texts are written to uniquely named files inside `tmp_dir`; the
    files are removed again before returning, even when an error occurs.

    Returns the parsed JSON output of `csvdiff`, with an extra "Columns"
    key holding the widest header when the diff is not empty.  Returns an
    empty dict when the command exits with a non-zero code.
    """
    header_prefix = "Strain Name,Value,SE,Count"

    def _find_header(csv_text):
        # Look each header up in its own text: the two files need not have
        # the same number of lines before the header.
        return next(
            (line for line in csv_text.strip().split("\n")
             if line.startswith(header_prefix)),
            "",
        )

    base_csv_header = _find_header(base_csv)
    delta_csv_header = _find_header(delta_csv)
    # "Longest" means most columns; a plain max() on strings would compare
    # them alphabetically.
    longest_header = max(base_csv_header, delta_csv_header,
                         key=lambda header: header.count(","))

    if base_csv_header != delta_csv_header:
        # Swap the whole shorter header line, so that none of its own extra
        # columns are left dangling behind the new header.
        if longest_header != base_csv_header:
            if base_csv_header:
                base_csv = base_csv.replace(base_csv_header,
                                            longest_header, 1)
        elif delta_csv_header:
            delta_csv = delta_csv.replace(delta_csv_header,
                                          longest_header, 1)

    file_name1 = os.path.join(tmp_dir, str(uuid.uuid4()))
    file_name2 = os.path.join(tmp_dir, str(uuid.uuid4()))
    width = len(longest_header.split(","))

    try:
        with open(file_name1, "w", encoding="utf-8") as _f:
            _f.write(fill_csv(csv_text=base_csv, width=width))
        with open(file_name2, "w", encoding="utf-8") as _f:
            _f.write(fill_csv(delta_csv, width=width))

        # Now we can run the diff!
        # NOTE(review): run_cmd is assumed to return a mapping with "code"
        # and "output" keys -- confirm against gn3.commands.
        _r = run_cmd(cmd=('"csvdiff '
                          f"{file_name1} {file_name2} "
                          '--format json"'))
        if _r.get("code") == 0:
            _r = json.loads(_r.get("output", ""))
            if any(_r.values()):
                _r["Columns"] = longest_header
        else:
            _r = {}
    finally:
        # Clean Up!  Runs on errors too, so no temporary file is leaked.
        for file_name in (file_name1, file_name2):
            if os.path.exists(file_name):
                os.remove(file_name)
    return _r


def fill_csv(csv_text, width, value="x"):
    """Pad the data rows of a csv text with `value` up to `width` fields.

    Lines starting with "Strain" or "#" are copied unchanged and empty
    lines are dropped.  In every other line, fields that are blank are
    replaced by `value`, and `value` is appended until the line has
    `width` fields.  Lines that already have `width` fields or more are
    not shortened.
    """
    filled_rows = []
    for row in csv_text.strip().split("\n"):
        if not row:
            continue
        if row.startswith(("Strain", "#")):
            filled_rows.append(row)
            continue
        cells = [cell if cell.strip() else value for cell in row.split(",")]
        cells.extend([value] * (width - len(cells)))
        filled_rows.append(",".join(cells))
    return "\n".join(filled_rows)


def get_allowable_sampledata_headers(conn: Any) -> List:
    """Return the fixed sample-data headers plus the case-attribute names.

    The names are read from the `Name` column of the `CaseAttribute`
    table through a cursor obtained from `conn`, and are appended after
    "Strain Name", "Value", "SE" and "Count".
    """
    with conn.cursor() as cursor:
        cursor.execute("SELECT Name from CaseAttribute")
        case_attributes = [row[0] for row in cursor.fetchall()]
    return ["Strain Name", "Value", "SE", "Count", *case_attributes]