about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- gn3/csvcmp.py 14
-rw-r--r-- tests/unit/test_csvcmp.py 15
2 files changed, 29 insertions, 0 deletions
diff --git a/gn3/csvcmp.py b/gn3/csvcmp.py
index aa057b7..dd3f72b 100644
--- a/gn3/csvcmp.py
+++ b/gn3/csvcmp.py
@@ -118,3 +118,17 @@ def get_allowable_sampledata_headers(conn: Any) -> List:
         attributes += [attributes[0] for attributes in
                        cursor.fetchall()]
     return attributes
+
+
+def extract_invalid_csv_headers(allowed_headers: List, csv_text: str) -> List:
+    """Return the csv header-row columns that are not in allowed_headers"""
+    csv_header = []
+    for line in csv_text.split("\n"):
+        if line.startswith("Strain Name"):  # first such line is the header row
+            csv_header = [_l.strip() for _l in line.split(",")]
+            break  # later lines are never inspected
+    invalid_headers = []  # stays empty when no header row was found
+    for header in csv_header:
+        if header not in allowed_headers:
+            invalid_headers.append(header)
+    return invalid_headers
diff --git a/tests/unit/test_csvcmp.py b/tests/unit/test_csvcmp.py
index 984a61f..735979d 100644
--- a/tests/unit/test_csvcmp.py
+++ b/tests/unit/test_csvcmp.py
@@ -2,6 +2,7 @@
 import pytest
 
 from gn3.csvcmp import csv_diff
+from gn3.csvcmp import extract_invalid_csv_headers
 from gn3.csvcmp import fill_csv
 from gn3.csvcmp import get_allowable_sampledata_headers
 from gn3.csvcmp import remove_insignificant_edits
@@ -132,3 +133,17 @@ def test_get_allowable_csv_headers(mocker):
         assert get_allowable_sampledata_headers(mock_conn) == expected_values
         cursor.execute.assert_called_once_with(
             "SELECT Name from CaseAttribute")
+
+
+@pytest.mark.unit_test
+def test_extract_invalid_csv_headers_with_some_wrong_headers():
+    """Test that headers absent from the allowed list are extracted from a csv
+string"""
+    allowed_headers = [
+        "Strain Name", "Value", "SE", "Count",
+        "Condition", "Tissue", "Sex", "Age",
+        "Ethn.", "PMI (hrs)", "pH", "Color",
+    ]
+
+    csv_text = "Strain Name, Value, SE, Colour"  # "Colour" is not allowed
+    assert extract_invalid_csv_headers(allowed_headers, csv_text) == ["Colour"]