about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/csvcmp.py7
-rw-r--r--gn3/db/sample_data.py5
-rw-r--r--tests/unit/test_csvcmp.py8
3 files changed, 18 insertions, 2 deletions
diff --git a/gn3/csvcmp.py b/gn3/csvcmp.py
index ebd323e..360a101 100644
--- a/gn3/csvcmp.py
+++ b/gn3/csvcmp.py
@@ -4,6 +4,13 @@ import uuid
 from gn3.commands import run_cmd
 
 
+def extract_strain_name(csv_header, data, seek="Strain Name"):
+    for column, value in zip(csv_header.split(","), data.split(",")):
+        if seek in column:
+            return value
+    return ""
+
+
 def create_dirs_if_not_exists(dirs: list):
     for dir_ in dirs:
         if not os.path.exists(dir_):
diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py
index 06b5767..a410978 100644
--- a/gn3/db/sample_data.py
+++ b/gn3/db/sample_data.py
@@ -1,3 +1,4 @@
+from gn3.csvcmp import extract_strain_name
 from typing import Any, Tuple, Union
 
 import MySQLdb
@@ -200,7 +201,7 @@ def delete_sample_data(conn: Any,
     strain_id, data_id, inbredset_id = get_sample_data_ids(
         conn=conn, publishxref_id=trait_name,
         phenotype_id=phenotype_id,
-        strain_name=strain_name)
+        strain_name=extract_strain_name(csv_header, data))
 
     none_case_attrs = {
         "Strain Name": lambda: 0,
@@ -273,7 +274,7 @@ def insert_sample_data(conn: Any,  # pylint: disable=[R0913]
     strain_id, data_id, inbredset_id = get_sample_data_ids(
         conn=conn, publishxref_id=trait_name,
         phenotype_id=phenotype_id,
-        strain_name=strain_name)
+        strain_name=extract_strain_name(csv_header, data))
 
     none_case_attrs = {
         "Strain Name": lambda _: 0,
diff --git a/tests/unit/test_csvcmp.py b/tests/unit/test_csvcmp.py
index 4a96f99..fd7aa28 100644
--- a/tests/unit/test_csvcmp.py
+++ b/tests/unit/test_csvcmp.py
@@ -1,6 +1,7 @@
 from gn3.csvcmp import csv_diff
 from gn3.csvcmp import fill_csv
 from gn3.csvcmp import remove_insignificant_edits
+from gn3.csvcmp import extract_strain_name
 
 import pytest
 
@@ -99,3 +100,10 @@ BXD15,14,x,x
                         'Additions': [],
                         'Deletions': [],
                         'Modifications': []})
+
+
+@pytest.mark.unit_test
+def test_extract_strain_name():
+    assert(extract_strain_name(csv_header="Strain Name,Value,SE,Count",
+                               data="BXD1,18,x,0") ==
+           "BXD1")