about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/authentication.py4
-rw-r--r--gn3/db/traits.py47
-rw-r--r--tests/unit/db/test_traits.py33
3 files changed, 53 insertions, 31 deletions
diff --git a/gn3/authentication.py b/gn3/authentication.py
index 4aedacd..d0b35bc 100644
--- a/gn3/authentication.py
+++ b/gn3/authentication.py
@@ -163,8 +163,4 @@ def create_group(conn: Redis, group_name: Optional[str],
         }
         conn.hset("groups", group_id, json.dumps(group))
         return group
-    # This might break stuff, but it fixes the linting error regarding
-    # inconsistent return types.
-    # @BonfaceKilz please review this and replace with appropriate return and
-    # remove these comments.
     return None
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index 68b6059..338b320 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -110,7 +110,6 @@ def get_trait_csv_sample_data(conn: Any,
 
 
 def update_sample_data(conn: Any, #pylint: disable=[R0913]
-
                        trait_name: str,
                        strain_name: str,
                        phenotype_id: int,
@@ -203,25 +202,30 @@ def delete_sample_data(conn: Any,
                  "AND Strain.Name = \"%s\"") % (trait_name,
                                                 phenotype_id,
                                                 str(strain_name)))
-            strain_id, data_id = cursor.fetchone()
 
-            cursor.execute(("DELETE FROM PublishData "
+            # Check if it exists if the data was already deleted:
+            if _result := cursor.fetchone():
+                strain_id, data_id = _result
+
+            # Only run if the strain_id and data_id exist
+            if strain_id and data_id:
+                cursor.execute(("DELETE FROM PublishData "
                             "WHERE StrainId = %s AND Id = %s")
-                           % (strain_id, data_id))
-            deleted_published_data = cursor.rowcount
-
-            # Delete the PublishSE table
-            cursor.execute(("DELETE FROM PublishSE "
-                            "WHERE StrainId = %s AND DataId = %s") %
-                           (strain_id, data_id))
-            deleted_se_data = cursor.rowcount
-
-            # Delete the NStrain table
-            cursor.execute(("DELETE FROM NStrain "
-                            "WHERE StrainId = %s AND DataId = %s" %
-                            (strain_id, data_id)))
-            deleted_n_strains = cursor.rowcount
-        except Exception as e: #pylint: disable=[C0103, W0612]
+                               % (strain_id, data_id))
+                deleted_published_data = cursor.rowcount
+
+                # Delete the PublishSE table
+                cursor.execute(("DELETE FROM PublishSE "
+                                "WHERE StrainId = %s AND DataId = %s") %
+                               (strain_id, data_id))
+                deleted_se_data = cursor.rowcount
+
+                # Delete the NStrain table
+                cursor.execute(("DELETE FROM NStrain "
+                                "WHERE StrainId = %s AND DataId = %s" %
+                                (strain_id, data_id)))
+                deleted_n_strains = cursor.rowcount
+        except Exception as e:  #pylint: disable=[C0103, W0612]
             conn.rollback()
             raise MySQLdb.Error
         conn.commit()
@@ -254,6 +258,13 @@ def insert_sample_data(conn: Any, #pylint: disable=[R0913]
                            (strain_name,))
             strain_id = cursor.fetchone()
 
+            # Return early if an insert already exists!
+            cursor.execute("SELECT Id FROM PublishData where Id = %s "
+                           "AND StrainId = %s",
+                           (data_id, strain_id))
+            if cursor.fetchone():  # This strain already exists
+                return (0, 0, 0)
+
             # Insert the PublishData table
             cursor.execute(("INSERT INTO PublishData (Id, StrainId, value)"
                             "VALUES (%s, %s, %s)"),
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
index 4aa9389..75f3d4c 100644
--- a/tests/unit/db/test_traits.py
+++ b/tests/unit/db/test_traits.py
@@ -202,8 +202,6 @@ class TestTraitsDBFunctions(TestCase):
         """
         # pylint: disable=C0103
         db_mock = mock.MagicMock()
-
-        STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
         PUBLISH_DATA_SQL: str = (
             "UPDATE PublishData SET value = %s "
             "WHERE StrainId = %s AND Id = %s")
@@ -216,16 +214,33 @@ class TestTraitsDBFunctions(TestCase):
 
         with db_mock.cursor() as cursor:
             type(cursor).rowcount = 1
+            mock_fetchone = mock.MagicMock()
+            mock_fetchone.return_value = (1, 1)
+            type(cursor).fetchone = mock_fetchone
             self.assertEqual(update_sample_data(
                 conn=db_mock, strain_name="BXD11",
-                strain_id=10, publish_data_id=8967049,
-                value=18.7, error=2.3, count=2),
-                             (1, 1, 1, 1))
+                trait_name="1",
+                phenotype_id=10, value=18.7,
+                error=2.3, count=2),
+                             (1, 1, 1))
             cursor.execute.assert_has_calls(
-                [mock.call(STRAIN_ID_SQL, ('BXD11', 10)),
-                 mock.call(PUBLISH_DATA_SQL, (18.7, 10, 8967049)),
-                 mock.call(PUBLISH_SE_SQL, (2.3, 10, 8967049)),
-                 mock.call(N_STRAIN_SQL, (2, 10, 8967049))]
+                [mock.call('SELECT Strain.Id, PublishData.Id FROM'
+                           ' (PublishData, Strain, PublishXRef, '
+                           'PublishFreeze) LEFT JOIN PublishSE ON '
+                           '(PublishSE.DataId = PublishData.Id '
+                           'AND PublishSE.StrainId = '
+                           'PublishData.StrainId) LEFT JOIN NStrain ON '
+                           '(NStrain.DataId = PublishData.Id AND '
+                           'NStrain.StrainId = PublishData.StrainId) WHERE '
+                           'PublishXRef.InbredSetId = '
+                           'PublishFreeze.InbredSetId AND PublishData.Id = '
+                           'PublishXRef.DataId AND PublishXRef.Id = 1 AND '
+                           'PublishXRef.PhenotypeId = 10 AND '
+                           'PublishData.StrainId = Strain.Id AND '
+                           'Strain.Name = "BXD11"'),
+                 mock.call(PUBLISH_DATA_SQL, (18.7, 1, 1)),
+                 mock.call(PUBLISH_SE_SQL, (2.3, 1, 1)),
+                 mock.call(N_STRAIN_SQL, (2, 1, 1))]
             )
 
     def test_set_haveinfo_field(self):