about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/db/case_attributes.py22
-rw-r--r--tests/unit/db/test_case_attributes.py102
2 files changed, 112 insertions, 12 deletions
diff --git a/gn3/db/case_attributes.py b/gn3/db/case_attributes.py
index afebb7f..e14d99b 100644
--- a/gn3/db/case_attributes.py
+++ b/gn3/db/case_attributes.py
@@ -82,14 +82,17 @@ def queue_edit(cursor, directory: Path, edit: CaseAttributeEdit) -> Optional[int
 
 
 def update_case_attribute(cursor, directory: Path,
-                          change_id: int, edit: CaseAttributeEdit) -> int:
+                          change_id: int, edit: CaseAttributeEdit) -> bool:
     directory = f"{directory}/case-attributes/{edit.inbredset_id}"
     if not os.path.exists(directory):
         os.makedirs(directory)
     env = lmdb.open(directory, map_size=8_000_000)  # 1 MB
-    modifications = {}
-    if edit.changes.get("Modifications").get("Current"):
-        modifications = edit.get("Modifications").get("Current")
+    modifications = dict()
+    if edit.changes.get("Modifications") and \
+       edit.changes.get("Modifications").get("Current"):
+        modifications = edit.changes.get("Modifications").get("Current")
+    if not modifications:
+        return False
     for strain, changes in modifications.items():
         for case_attribute, value in changes.items():
             value = str(value).strip()
@@ -101,7 +104,7 @@ def update_case_attribute(cursor, directory: Path,
             cursor.execute("SELECT CaseAttributeId, Name AS CaseAttributeName "
                            "FROM CaseAttribute WHERE InbredSetId = %s "
                            "AND Name = %s",
-                           (inbredset_id, edit.inbredset_id,))
+                           (edit.inbredset_id, case_attribute,))
             case_attr_id, _ = cursor.fetchone()
             cursor.execute(
                 "INSERT INTO CaseAttributeXRefNew"
@@ -110,16 +113,17 @@ def update_case_attribute(cursor, directory: Path,
                 "ON DUPLICATE KEY UPDATE Value=VALUES(value)",
                 (edit.inbredset_id, strain_id, case_attr_id, value,))
             cursor.execute(
-                "UPDATE caseattributes_audit SET ",
+                "UPDATE caseattributes_audit SET "
                 "status = %s WHERE id = %s",
                 (str(edit.status), change_id,))
             with env.begin(write=True) as txn:
                 review_ids, approved_ids = set(), set()
                 if reviews := txn.get(b"review"):
                     review_ids = pickle.loads(reviews)
-                    review_ids.remove(change_id)
-                if approvals := txn.get(b"review"):
+                if approvals := txn.get(b"approved"):
                     approved_ids = pickle.loads(approvals)
-                    approved_ids.add(change_id)
+                review_ids.remove(change_id)
+                approved_ids.add(change_id)
                 txn.put(b"review", pickle.dumps(review_ids))
                 txn.put(b"approved", pickle.dumps(approved_ids))
+                return True
diff --git a/tests/unit/db/test_case_attributes.py b/tests/unit/db/test_case_attributes.py
index 18de4d3..790dbe6 100644
--- a/tests/unit/db/test_case_attributes.py
+++ b/tests/unit/db/test_case_attributes.py
@@ -1,13 +1,16 @@
 """Test cases for gn3.db.case_attributes.py"""
 
 import pytest
+import pickle
 import tempfile
 import os
+from pathlib import Path
 from pytest_mock import MockFixture
 from gn3.db.case_attributes import queue_edit
 from gn3.db.case_attributes import (
     CaseAttributeEdit,
-    EditStatus
+    EditStatus,
+    update_case_attribute
 )
 
 
@@ -19,9 +22,11 @@ def test_queue_edit(mocker: MockFixture) -> None:
         TMPDIR = os.environ.get("TMPDIR", tempfile.gettempdir())
         caseattr_id = queue_edit(
             cursor,
-            status=EditStatus.review
             directory=TMPDIR,
-            edit=CaseAttributeEdit(inbredset_id=1, user_id="xxxx", changes={"a": 1, "b": 2}))
+            edit=CaseAttributeEdit(
+                inbredset_id=1, status=EditStatus.review,
+                user_id="xxxx", changes={"a": 1, "b": 2}
+            ))
         cursor.execute.assert_called_once_with(
             "INSERT INTO "
             "caseattributes_audit(status, editor, json_diff_data) "
@@ -29,3 +34,94 @@ def test_queue_edit(mocker: MockFixture) -> None:
             "ON DUPLICATE KEY UPDATE status=%s",
             ('review', 'xxxx', '{"a": 1, "b": 2}', 'review'))
         assert 28 == caseattr_id
+
+
+@pytest.mark.unit_test
+def test_update_case_attribute_success(mocker: MockFixture) -> None:
+    """Test successful case attribute update with valid modifications."""
+    mock_cursor, mock_conn = mocker.MagicMock(), mocker.MagicMock()
+    mock_conn.cursor.return_value = mock_cursor
+    mock_lmdb = mocker.patch("gn3.db.case_attributes.lmdb")
+    mock_env, mock_txn = mocker.MagicMock(), mocker.MagicMock()
+    mock_lmdb.open.return_value = mock_env
+    mock_env.begin.return_value.__enter__.return_value = mock_txn
+    mock_txn.get.side_effect = [
+        pickle.dumps({100}),  # b"review" key
+        None,                 # b"approved" key
+    ]
+
+    TMPDIR = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
+    edit = CaseAttributeEdit(
+        inbredset_id=1,
+        user_id="test_user",
+        status=EditStatus.approved,
+        changes={
+            "Modifications": {
+                "Current": {
+                    "Strain1": {"Attribute1": "Value1"}
+                }
+            }
+        }
+    )
+    change_id = 100
+
+    # Mock cursor fetch results
+    mock_cursor.fetchone.side_effect = [
+        (10, "Strain1"),          # Strain query
+        (20, "Attribute1"),       # CaseAttribute query
+    ]
+
+    assert update_case_attribute(mock_cursor, TMPDIR, change_id, edit)
+
+    # Assertions for lmdb interactions
+    mock_lmdb.open.assert_called_once_with(
+        f"{TMPDIR}/case-attributes/1", map_size=8_000_000)
+    mock_env.begin.assert_called_once_with(write=True)
+    mock_txn.get.assert_has_calls([
+        mocker.call(b"review"),
+        mocker.call(b"approved")
+    ])
+    mock_txn.put.assert_has_calls([
+        mocker.call(b"review", pickle.dumps(set())),
+        mocker.call(b"approved", pickle.dumps({100}))
+    ])
+
+    # Assertions for SQL executions
+    mock_cursor.execute.assert_has_calls([
+        mocker.call(
+            "SELECT Id AS StrainId, Name AS StrainName FROM Strain WHERE Name = %s",
+            ("Strain1",)
+        ),
+        mocker.call(
+            "SELECT CaseAttributeId, Name AS CaseAttributeName FROM CaseAttribute "
+            "WHERE InbredSetId = %s AND Name = %s",
+            (1, "Attribute1")
+        ),
+        mocker.call(
+            "INSERT INTO CaseAttributeXRefNew(InbredSetId, StrainId, CaseAttributeId, Value) "
+            "VALUES (%s, %s, %s, %s) ON DUPLICATE KEY UPDATE Value=VALUES(value)",
+            (1, 10, 20, "Value1")
+        ),
+        mocker.call(
+            "UPDATE caseattributes_audit SET status = %s WHERE id = %s",
+            ("approved", 100)
+        )
+    ])
+
+
+@pytest.mark.unit_test
+def test_update_case_attribute_no_modifications(mocker: MockFixture) -> None:
+    """Test update_case_attribute with no modifications in edit.changes."""
+    mock_cursor, mock_conn = mocker.MagicMock(), mocker.MagicMock()
+    mock_conn.cursor.return_value = mock_cursor
+    mock_lmdb = mocker.patch("gn3.db.case_attributes.lmdb")
+    mock_env, mock_txn = mocker.MagicMock(), mocker.MagicMock()
+    TMPDIR = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
+    edit = CaseAttributeEdit(
+        inbredset_id=1,
+        user_id="test_user",
+        status=EditStatus.approved,
+        changes={}  # No modifications
+    )
+    change_id = 28
+    assert not update_case_attribute(mock_cursor, TMPDIR, change_id, edit)