aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gn3/db/sample_data.py161
-rw-r--r--tests/unit/db/test_sample_data.py43
2 files changed, 121 insertions, 83 deletions
diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py
index 06c3cc5..cfa4a3d 100644
--- a/gn3/db/sample_data.py
+++ b/gn3/db/sample_data.py
@@ -122,93 +122,92 @@ def get_sample_data_ids(conn: Any, publishxref_id: int,
def update_sample_data(conn: Any, # pylint: disable=[R0913]
trait_name: str,
- strain_name: str,
- phenotype_id: int,
- value: Union[int, float, str],
- error: Union[int, float, str],
- count: Union[int, str]):
+ original_data: str,
+ updated_data: str,
+ csv_header: str,
+ phenotype_id: int) -> int:
"""Given the right parameters, update sample-data from the relevant
table."""
- strain_id, data_id, _ = get_sample_data_ids(
- conn=conn, publishxref_id=trait_name,
- phenotype_id=phenotype_id, strain_name=strain_name)
-
- updated_published_data: int = 0
- updated_se_data: int = 0
- updated_n_strains: int = 0
-
- with conn.cursor() as cursor:
- # Update the PublishData table
- if value == "x":
- cursor.execute(("DELETE FROM PublishData "
- "WHERE StrainId = %s AND Id = %s")
- % (strain_id, data_id))
- updated_published_data = cursor.rowcount
- else:
- cursor.execute(("UPDATE PublishData SET value = %s "
- "WHERE StrainId = %s AND Id = %s"),
- (value, strain_id, data_id))
- updated_published_data = cursor.rowcount
+ def __update_data(conn, table, value):
+ if value and value != "x":
+ with conn.cursor() as cursor:
+ sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
+ _val = _MAP.get(table)[-1]
+ cursor.execute((f"UPDATE {table} SET {_val} = %s "
+ f"WHERE {sub_query}"),
+ (value, strain_id, data_id))
+ return cursor.rowcount
- if not updated_published_data:
+ def __update_case_attribute(conn, value, strain_id,
+ case_attr, inbredset_id):
+ if value != "x":
+ with conn.cursor() as cursor:
cursor.execute(
- "SELECT * FROM "
- "PublishData WHERE StrainId = "
- "%s AND Id = %s" % (strain_id, data_id))
- if not cursor.fetchone():
- cursor.execute(("INSERT INTO PublishData (Id, StrainId, "
- " value) VALUES (%s, %s, %s)") %
- (data_id, strain_id, value))
- updated_published_data = cursor.rowcount
+ "UPDATE CaseAttributeXRefNew "
+ "SET Value = %s "
+ "WHERE StrainId = %s AND CaseAttributeId = "
+ "(SELECT CaseAttributeId FROM "
+ "CaseAttribute WHERE Name = %s) "
+ "AND InbredSetId = %s",
+ (value, strain_id, case_attr, inbredset_id))
+ return cursor.rowcount
+ return 0
- # Update the PublishSE table
- if error == "x":
- cursor.execute(("DELETE FROM PublishSE "
- "WHERE StrainId = %s AND DataId = %s") %
- (strain_id, data_id))
- updated_se_data = cursor.rowcount
- else:
- cursor.execute(("UPDATE PublishSE SET error = %s "
- "WHERE StrainId = %s AND DataId = %s"),
- (None if error == "x" else error,
- strain_id, data_id))
- updated_se_data = cursor.rowcount
- if not updated_se_data:
- cursor.execute(
- "SELECT * FROM "
- "PublishSE WHERE StrainId = "
- "%s AND DataId = %s" % (strain_id, data_id))
- if not cursor.fetchone():
- cursor.execute(("INSERT INTO PublishSE (StrainId, DataId, "
- " error) VALUES (%s, %s, %s)") %
- (strain_id, data_id,
- None if error == "x" else error))
- updated_se_data = cursor.rowcount
+ strain_id, data_id, inbredset_id = get_sample_data_ids(
+ conn=conn, publishxref_id=trait_name,
+ phenotype_id=phenotype_id,
+ strain_name=extract_strain_name(csv_header, original_data))
- # Update the NStrain table
- if count == "x":
- cursor.execute(("DELETE FROM NStrain "
- "WHERE StrainId = %s AND DataId = %s" %
- (strain_id, data_id)))
- updated_n_strains = cursor.rowcount
- else:
- cursor.execute(("UPDATE NStrain SET count = %s "
- "WHERE StrainId = %s AND DataId = %s"),
- (count, strain_id, data_id))
- updated_n_strains = cursor.rowcount
- if not updated_n_strains:
- cursor.execute(
- "SELECT * FROM "
- "NStrain WHERE StrainId = "
- "%s AND DataId = %s" % (strain_id, data_id))
- if not cursor.fetchone():
- cursor.execute(("INSERT INTO NStrain "
- "(StrainId, DataId, count) "
- "VALUES (%s, %s, %s)") %
- (strain_id, data_id, count))
- updated_n_strains = cursor.rowcount
- return (updated_published_data,
- updated_se_data, updated_n_strains)
+ none_case_attrs = {
+ "Strain Name": lambda x: 0,
+ "Value": lambda x: __update_data(conn, "PublishData", x),
+ "SE": lambda x: __update_data(conn, "PublishSE", x),
+ "Count": lambda x: __update_data(conn, "NStrain", x),
+ }
+ count = 0
+ try:
+ __actions = __extract_actions(original_data=original_data,
+ updated_data=updated_data,
+ csv_header=csv_header)
+ if __actions.get("update"):
+ _csv_header = __actions["update"]["csv_header"]
+ _data = __actions["update"]["data"]
+ for header, value in zip(_csv_header.split(","),
+ _data.split(",")):
+ header = header.strip()
+ value = value.strip()
+ if header in none_case_attrs:
+ count += none_case_attrs.get(header)(value)
+ else:
+ count += __update_case_attribute(
+ conn=conn,
+ value=none_case_attrs.get(header)(value),
+ strain_id=strain_id,
+ case_attr=header,
+ inbredset_id=inbredset_id)
+ if __actions.get("delete"):
+ _rowcount = delete_sample_data(
+ conn=conn,
+ trait_name=trait_name,
+ data=__actions["delete"]["data"],
+ csv_header=__actions["delete"]["csv_header"],
+ phenotype_id=phenotype_id)
+ if _rowcount:
+ count += 1
+ if __actions.get("insert"):
+ _rowcount = insert_sample_data(
+ conn=conn,
+ trait_name=trait_name,
+ data=__actions["insert"]["data"],
+ csv_header=__actions["insert"]["csv_header"],
+ phenotype_id=phenotype_id)
+ if _rowcount:
+ count += 1
+ except Exception as e: # pylint: disable=[C0103, W0612]
+ conn.rollback()
+ raise MySQLdb.Error
+ conn.commit()
+ return count
def delete_sample_data(conn: Any,
diff --git a/tests/unit/db/test_sample_data.py b/tests/unit/db/test_sample_data.py
index 4015ba5..80166fc 100644
--- a/tests/unit/db/test_sample_data.py
+++ b/tests/unit/db/test_sample_data.py
@@ -1,8 +1,10 @@
import pytest
+import gn3
-from gn3.db.sample_data import insert_sample_data
-from gn3.db.sample_data import delete_sample_data
from gn3.db.sample_data import __extract_actions
+from gn3.db.sample_data import delete_sample_data
+from gn3.db.sample_data import insert_sample_data
+from gn3.db.sample_data import update_sample_data
@pytest.mark.unit_test
@@ -94,3 +96,40 @@ def test_extract_actions():
"insert": {"data": "2,F", "csv_header": "SE,Sex"},
"update": {"data": "19,1", "csv_header": "Value,Count"},
})
+
+
+@pytest.mark.unit_test
+def test_update_sample_data(mocker):
+ mock_conn = mocker.MagicMock()
+ strain_id, data_id, inbredset_id = 1, 17373, 20
+ with mock_conn.cursor() as cursor:
+ # cursor.fetchone.side_effect = (0, [19, ], 0)
+ mocker.patch('gn3.db.sample_data.get_sample_data_ids',
+ return_value=(strain_id, data_id, inbredset_id))
+ mocker.patch('gn3.db.sample_data.insert_sample_data',
+ return_value=1)
+ mocker.patch('gn3.db.sample_data.delete_sample_data',
+ return_value=1)
+ update_sample_data(conn=mock_conn,
+ trait_name=35,
+ original_data="BXD1,18,x,0,x",
+ updated_data="BXD1,x,2,1,F",
+ csv_header="Strain Name,Value,SE,Count,Sex",
+ phenotype_id=10007)
+ gn3.db.sample_data.insert_sample_data.assert_called_once_with(
+ conn=mock_conn,
+ trait_name=35,
+ data="2,F",
+ csv_header="SE,Sex",
+ phenotype_id=10007)
+ gn3.db.sample_data.delete_sample_data.assert_called_once_with(
+ conn=mock_conn,
+ trait_name=35,
+ data="18",
+ csv_header="Value",
+ phenotype_id=10007)
+ cursor.execute.assert_has_calls(
+ [mocker.call("UPDATE NStrain SET count = %s "
+ "WHERE StrainId = %s AND DataId = %s",
+ ('1', strain_id, data_id))],
+ any_order=False)