about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/csvcmp.py12
-rw-r--r--tests/unit/test_csvcmp.py19
2 files changed, 31 insertions, 0 deletions
diff --git a/gn3/csvcmp.py b/gn3/csvcmp.py
index 43b795d..aa057b7 100644
--- a/gn3/csvcmp.py
+++ b/gn3/csvcmp.py
@@ -1,5 +1,7 @@
 """This module contains functions for manipulating and working with csv
 texts"""
+from typing import Any, List
+
 import json
 import os
 import uuid
@@ -106,3 +108,13 @@ def fill_csv(csv_text, width, value="x"):
                     _n[i] = value
             data.append(",".join(_n + [value] * (width - len(_n))))
     return "\n".join(data)
+
+
+def get_allowable_sampledata_headers(conn: Any) -> List:
+    """Get a list of all the case-attributes stored in the database"""
+    attributes = ["Strain Name", "Value", "SE", "Count"]
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT Name from CaseAttribute")
+        attributes += [attributes[0] for attributes in
+                       cursor.fetchall()]
+    return attributes
diff --git a/tests/unit/test_csvcmp.py b/tests/unit/test_csvcmp.py
index c7ffe2f..984a61f 100644
--- a/tests/unit/test_csvcmp.py
+++ b/tests/unit/test_csvcmp.py
@@ -3,6 +3,7 @@ import pytest
 
 from gn3.csvcmp import csv_diff
 from gn3.csvcmp import fill_csv
+from gn3.csvcmp import get_allowable_sampledata_headers
 from gn3.csvcmp import remove_insignificant_edits
 from gn3.csvcmp import extract_strain_name
 
@@ -113,3 +114,21 @@ def test_extract_strain_name():
         extract_strain_name(csv_header="Strain Name,Value,SE,Count", data="BXD1,18,x,0")
         == "BXD1"
     )
+
+
+@pytest.mark.unit_test
+def test_get_allowable_csv_headers(mocker):
+    """Test that all the csv headers are fetched properly"""
+    mock_conn = mocker.MagicMock()
+    expected_values = [
+        "Strain Name", "Value", "SE", "Count",
+        "Condition", "Tissue", "Sex", "Age",
+        "Ethn.", "PMI (hrs)", "pH", "Color",
+    ]
+    with mock_conn.cursor() as cursor:
+        cursor.fetchall.return_value = (
+            ('Condition',), ('Tissue',), ('Sex',),
+            ('Age',), ('Ethn.',), ('PMI (hrs)',), ('pH',), ('Color',))
+        assert get_allowable_sampledata_headers(mock_conn) == expected_values
+        cursor.execute.assert_called_once_with(
+            "SELECT Name from CaseAttribute")