about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-02-14 06:56:32 +0300
committerFrederick Muriuki Muriithi2022-02-17 06:37:30 +0300
commit74044f3c7985308b4996da3a52f91c5c20a19194 (patch)
treed86714b859b31cbbd1755522f8abd8eed16e321b
parent67f517aa0f44f55dc691ffd791bf22ef7af0b02c (diff)
downloadgenenetwork3-74044f3c7985308b4996da3a52f91c5c20a19194.tar.gz
Use pytest's "mark" feature to categorise tests
Use pytest's `mark` feature to explicitly categorise the tests and run them
per category
-rw-r--r--pytest.ini6
-rw-r--r--setup_commands/run_tests.py19
-rw-r--r--tests/integration/test_correlation.py4
-rw-r--r--tests/integration/test_gemma.py14
-rw-r--r--tests/integration/test_general.py8
-rw-r--r--tests/integration/test_wgcna.py3
-rw-r--r--tests/unit/computations/test_correlation.py18
-rw-r--r--tests/unit/computations/test_dictify_by_samples.py1
-rw-r--r--tests/unit/computations/test_diff.py3
-rw-r--r--tests/unit/computations/test_gemma.py8
-rw-r--r--tests/unit/computations/test_parsers.py4
-rw-r--r--tests/unit/computations/test_partial_correlations.py9
-rw-r--r--tests/unit/computations/test_qtlreaper.py4
-rw-r--r--tests/unit/computations/test_rqtl.py2
-rw-r--r--tests/unit/computations/test_slink.py13
-rw-r--r--tests/unit/computations/test_wgcna.py6
-rw-r--r--tests/unit/db/test_audit.py3
-rw-r--r--tests/unit/db/test_correlation.py4
-rw-r--r--tests/unit/db/test_datasets.py6
-rw-r--r--tests/unit/db/test_db.py8
-rw-r--r--tests/unit/db/test_genotypes.py6
-rw-r--r--tests/unit/db/test_species.py5
-rw-r--r--tests/unit/db/test_traits.py15
-rw-r--r--tests/unit/test_authentication.py9
-rw-r--r--tests/unit/test_commands.py9
-rw-r--r--tests/unit/test_data_helpers.py5
-rw-r--r--tests/unit/test_db_utils.py3
-rw-r--r--tests/unit/test_file_utils.py9
-rw-r--r--tests/unit/test_heatmaps.py7
29 files changed, 198 insertions, 13 deletions
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..58eba11
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,6 @@
+[pytest]
+addopts = --strict-markers
+markers =
+	unit_test
+	integration_test
+	performance_test
\ No newline at end of file
diff --git a/setup_commands/run_tests.py b/setup_commands/run_tests.py
index 9a2c9ad..1bb5dab 100644
--- a/setup_commands/run_tests.py
+++ b/setup_commands/run_tests.py
@@ -7,21 +7,18 @@ class RunTests(Command):
     A custom command to run tests.
     """
     description = "Run the tests"
-    commands = {
-        "all": "pytest",
-        "unit": "pytest tests/unit",
-        "integration": "pytest tests/integration",
-        "performance": "pytest tests/performance",
-    }
+    test_types = (
+        "all", "unit", "integration", "performance")
     user_options = [
         ("type=", None,
          f"""Specify the type of tests to run.
-         Valid types are {tuple(commands.keys())}.
+         Valid types are {tuple(test_types)}.
          Default is `all`.""")]
 
     def __init__(self, dist):
         """Initialise the command."""
         super().__init__(dist)
+        self.command = "pytest"
 
     def initialize_options(self):
         """Initialise the default values of all the options."""
@@ -29,13 +26,15 @@ class RunTests(Command):
 
     def finalize_options(self):
         """Set final value of all the options once they are processed."""
-        if self.type not in RunTests.commands.keys():
+        if self.type not in RunTests.test_types:
             raise Exception(f"""
             Invalid test type (self.type) requested!
             Valid types are
-            {tuple(RunTests.commands.keys())}""")
+            {tuple(RunTests.test_types)}""")
 
+        if self.type != "all":
+            self.command = f"pytest -m {self.type}_test"
     def run(self):
         """Run the chosen tests"""
         print(f"Running {self.type} tests")
-        os.system(RunTests.commands[self.type])
+        os.system(self.command)
diff --git a/tests/integration/test_correlation.py b/tests/integration/test_correlation.py
index bdd9bce..cf63c17 100644
--- a/tests/integration/test_correlation.py
+++ b/tests/integration/test_correlation.py
@@ -1,6 +1,7 @@
 """module contains integration tests for correlation"""
 from unittest import TestCase
 from unittest import mock
+import pytest
 from gn3.app import create_app
 
 
@@ -10,6 +11,7 @@ class CorrelationIntegrationTest(TestCase):
     def setUp(self):
         self.app = create_app().test_client()
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.correlation.compute_all_sample_correlation")
     def test_sample_r_correlation(self, mock_compute_samples):
         """Test /api/correlation/sample_r/{method}"""
@@ -61,6 +63,7 @@ class CorrelationIntegrationTest(TestCase):
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.get_json(), api_response)
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.correlation.compute_all_lit_correlation")
     @mock.patch("gn3.api.correlation.database_connector")
     def test_lit_correlation(self, database_connector, mock_compute_corr):
@@ -80,6 +83,7 @@ class CorrelationIntegrationTest(TestCase):
         self.assertEqual(mock_compute_corr.call_count, 1)
         self.assertEqual(response.status_code, 200)
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.correlation.compute_tissue_correlation")
     def test_tissue_correlation(self, mock_tissue_corr):
         """Test api/correlation/tissue_corr/{corr_method}"""
diff --git a/tests/integration/test_gemma.py b/tests/integration/test_gemma.py
index f871173..0515539 100644
--- a/tests/integration/test_gemma.py
+++ b/tests/integration/test_gemma.py
@@ -7,6 +7,8 @@ from dataclasses import dataclass
 from typing import Callable
 from unittest import mock
 
+import pytest
+
 from gn3.app import create_app
 
 
@@ -30,6 +32,7 @@ class GemmaAPITest(unittest.TestCase):
             "TMPDIR": "/tmp"
         }).test_client()
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.gemma.run_cmd")
     def test_get_version(self, mock_run_cmd):
         """Test that the correct response is returned"""
@@ -38,6 +41,7 @@ class GemmaAPITest(unittest.TestCase):
         self.assertEqual(response.get_json(), {"status": 0, "output": "v1.9"})
         self.assertEqual(response.status_code, 200)
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.gemma.redis.Redis")
     def test_check_cmd_status(self, mock_redis):
         """Test that you can check the status of a given command"""
@@ -52,6 +56,7 @@ class GemmaAPITest(unittest.TestCase):
             name="cmd::2021-02-1217-3224-3224-1234", key="status")
         self.assertEqual(response.get_json(), {"status": "test"})
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.gemma.queue_cmd")
     @mock.patch("gn3.computations.gemma.get_hash_of_files")
     @mock.patch("gn3.api.gemma.jsonfile_to_dict")
@@ -94,6 +99,7 @@ class GemmaAPITest(unittest.TestCase):
                 "unique_id": "my-unique-id"
             })
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.gemma.queue_cmd")
     @mock.patch("gn3.computations.gemma.get_hash_of_files")
     @mock.patch("gn3.api.gemma.jsonfile_to_dict")
@@ -137,6 +143,7 @@ class GemmaAPITest(unittest.TestCase):
                 "unique_id": "my-unique-id"
             })
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.gemma.queue_cmd")
     @mock.patch("gn3.computations.gemma.get_hash_of_files")
     @mock.patch("gn3.api.gemma.jsonfile_to_dict")
@@ -187,6 +194,7 @@ class GemmaAPITest(unittest.TestCase):
                 "output_file": "hash-output.json"
             })
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.gemma.queue_cmd")
     @mock.patch("gn3.computations.gemma.get_hash_of_files")
     @mock.patch("gn3.api.gemma.jsonfile_to_dict")
@@ -240,6 +248,7 @@ class GemmaAPITest(unittest.TestCase):
                 "output_file": "hash-output.json"
             })
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.gemma.queue_cmd")
     @mock.patch("gn3.computations.gemma.get_hash_of_files")
     @mock.patch("gn3.api.gemma.jsonfile_to_dict")
@@ -292,6 +301,7 @@ class GemmaAPITest(unittest.TestCase):
                 "output_file": "hash-output.json"
             })
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.gemma.queue_cmd")
     @mock.patch("gn3.computations.gemma.get_hash_of_files")
     @mock.patch("gn3.api.gemma.jsonfile_to_dict")
@@ -346,6 +356,7 @@ class GemmaAPITest(unittest.TestCase):
                 "output_file": "hash-output.json"
             })
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.gemma.queue_cmd")
     @mock.patch("gn3.computations.gemma.get_hash_of_files")
     @mock.patch("gn3.api.gemma.jsonfile_to_dict")
@@ -401,6 +412,7 @@ class GemmaAPITest(unittest.TestCase):
                 "output_file": "hash-output.json"
             })
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.gemma.queue_cmd")
     @mock.patch("gn3.computations.gemma.get_hash_of_files")
     @mock.patch("gn3.api.gemma.jsonfile_to_dict")
@@ -465,6 +477,7 @@ class GemmaAPITest(unittest.TestCase):
                 "output_file": "hash-output.json"
             })
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.gemma.queue_cmd")
     @mock.patch("gn3.computations.gemma.get_hash_of_files")
     @mock.patch("gn3.api.gemma.jsonfile_to_dict")
@@ -530,6 +543,7 @@ class GemmaAPITest(unittest.TestCase):
                 "output_file": "hash-output.json"
             })
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.gemma.queue_cmd")
     @mock.patch("gn3.computations.gemma.get_hash_of_files")
     @mock.patch("gn3.api.gemma.jsonfile_to_dict")
diff --git a/tests/integration/test_general.py b/tests/integration/test_general.py
index 8fc2b43..9d87449 100644
--- a/tests/integration/test_general.py
+++ b/tests/integration/test_general.py
@@ -1,8 +1,10 @@
 """Integration tests for some 'general' API endpoints"""
 import os
 import unittest
-
 from unittest import mock
+
+import pytest
+
 from gn3.app import create_app
 
 
@@ -11,6 +13,7 @@ class GeneralAPITest(unittest.TestCase):
     def setUp(self):
         self.app = create_app().test_client()
 
+    @pytest.mark.integration_test
     def test_metadata_endpoint_exists(self):
         """Test that /metadata/upload exists"""
         response = self.app.post("/api/metadata/upload/d41d86-e4ceEo")
@@ -19,6 +22,7 @@ class GeneralAPITest(unittest.TestCase):
                          {"status": 128,
                           "error": "Please provide a file!"})
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.general.extract_uploaded_file")
     def test_metadata_file_upload(self, mock_extract_upload):
         """Test correct upload of file"""
@@ -37,6 +41,7 @@ class GeneralAPITest(unittest.TestCase):
                          {"status": 0,
                           "token": "d41d86-e4ceEo"})
 
+    @pytest.mark.integration_test
     def test_metadata_file_wrong_upload(self):
         """Test that incorrect upload return correct status code"""
         response = self.app.post("/api/metadata/upload/d41d86-e4ceEo",
@@ -47,6 +52,7 @@ class GeneralAPITest(unittest.TestCase):
                          {"status": 128,
                           "error": "gzip failed to unpack file"})
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.general.run_cmd")
     def test_run_r_qtl(self, mock_run_cmd):
         """Test correct upload of file"""
diff --git a/tests/integration/test_wgcna.py b/tests/integration/test_wgcna.py
index 078449d..5880b40 100644
--- a/tests/integration/test_wgcna.py
+++ b/tests/integration/test_wgcna.py
@@ -3,6 +3,8 @@
 from unittest import TestCase
 from unittest import mock
 
+import pytest
+
 from gn3.app import create_app
 
 
@@ -12,6 +14,7 @@ class WgcnaIntegrationTest(TestCase):
     def setUp(self):
         self.app = create_app().test_client()
 
+    @pytest.mark.integration_test
     @mock.patch("gn3.api.wgcna.call_wgcna_script")
     def test_wgcna_endpoint(self, mock_wgcna_script):
         """test /api/wgcna/run_wgcna endpoint"""
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py
index 7523d99..69d4c52 100644
--- a/tests/unit/computations/test_correlation.py
+++ b/tests/unit/computations/test_correlation.py
@@ -2,6 +2,7 @@
 from unittest import TestCase
 from unittest import mock
 
+import pytest
 from collections import namedtuple
 import math
 from numpy.testing import assert_almost_equal
@@ -91,6 +92,7 @@ class DataBase(QueryableMixin):
 class TestCorrelation(TestCase):
     """Class for testing correlation functions"""
 
+    @pytest.mark.unit_test
     def test_normalize_values(self):
         """Function to test normalizing values """
 
@@ -106,6 +108,7 @@ class TestCorrelation(TestCase):
                 results = normalize_values(a_values, b_values)
                 self.assertEqual(list(zip(*list(results))), expected_result)
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.computations.correlations.compute_corr_coeff_p_value")
     @mock.patch("gn3.computations.correlations.normalize_values")
     def test_compute_sample_r_correlation(self, norm_vals, compute_corr):
@@ -130,6 +133,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(bicor_results, ("1412_at", 0.8, 0.21, 7))
 
+    @pytest.mark.unit_test
     def test_filter_shared_sample_keys(self):
         """Function to  tests shared key between two dicts"""
 
@@ -157,6 +161,7 @@ class TestCorrelation(TestCase):
         self.assertEqual(list(zip(*list(results))), [filtered_this_samplelist,
                                                      filtered_target_samplelist])
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.computations.correlations.compute_sample_r_correlation")
     @mock.patch("gn3.computations.correlations.filter_shared_sample_keys")
     def test_compute_all_sample(self, filter_shared_samples, sample_r_corr):
@@ -199,6 +204,7 @@ class TestCorrelation(TestCase):
             corr_method="pearson", trait_vals=('1.23', '6.565', '6.456'),
             target_samples_vals=('6.266', '6.565', '6.456'))
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.computations.correlations.compute_corr_coeff_p_value")
     def test_tissue_correlation_for_trait(self, mock_compute_corr_coeff):
         """Test given a primary tissue values for a trait  and and a list of\
@@ -217,6 +223,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(tissue_results, expected_tissue_results)
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.computations.correlations.fetch_lit_correlation_data")
     @mock.patch("gn3.computations.correlations.map_to_mouse_gene_id")
     def test_lit_correlation_for_trait(self, mock_mouse_gene_id, fetch_lit_data):
@@ -244,6 +251,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(lit_results, expected_results)
 
+    @pytest.mark.unit_test
     def test_fetch_lit_correlation_data(self):
         """Test for fetching lit correlation data from\
         the database where the input and mouse geneid are none
@@ -257,6 +265,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(results, ("1", 0))
 
+    @pytest.mark.unit_test
     def test_fetch_lit_correlation_data_db_query(self):
         """Test for fetching lit corr coefficent givent the input\
          input trait mouse gene id and mouse gene id
@@ -274,6 +283,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(expected_results, lit_results)
 
+    @pytest.mark.unit_test
     def test_query_lit_correlation_for_db_empty(self):
         """Test that corr coeffient returned is 0 given the\
         db value if corr coefficient is empty
@@ -289,6 +299,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(lit_results, ("16", 0))
 
+    @pytest.mark.unit_test
     def test_query_formatter(self):
         """Test for formatting a query given the query string and also the\
         values
@@ -316,6 +327,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(formatted_query, expected_formatted_query)
 
+    @pytest.mark.unit_test
     def test_query_formatter_no_query_values(self):
         """Test for formatting a query where there are no\
         string placeholder
@@ -325,6 +337,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(formatted_query, query)
 
+    @pytest.mark.unit_test
     def test_map_to_mouse_gene_id(self):
         """Test for converting a gene id to mouse geneid\
         given a species which is not mouse
@@ -348,6 +361,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(results, expected_results)
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.computations.correlations.lit_correlation_for_trait")
     def test_compute_all_lit_correlation(self, mock_lit_corr):
         """Test for compute all lit correlation which acts\
@@ -368,6 +382,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(lit_correlation_results, expected_mocked_lit_results)
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.computations.correlations.tissue_correlation_for_trait")
     @mock.patch("gn3.computations.correlations.process_trait_symbol_dict")
     def test_compute_all_tissue_correlation(self, process_trait_symbol, mock_tissue_corr):
@@ -411,6 +426,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(results, expected_results)
 
+    @pytest.mark.unit_test
     def test_map_shared_keys_to_values(self):
         """test helper function needed to integrate with genenenetwork2\
         given a a samplelist containing dataset sampelist keys\
@@ -431,6 +447,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(results, expected_results)
 
+    @pytest.mark.unit_test
     def test_process_trait_symbol_dict(self):
         """test for processing trait symbol dict\
         and fetch tissue values from tissue value dict\
@@ -449,6 +466,7 @@ class TestCorrelation(TestCase):
 
         self.assertEqual(results, [expected_results])
 
+    @pytest.mark.unit_test
     def test_compute_correlation(self):
         """Test that the new correlation function works the same as the original
         from genenetwork1."""
diff --git a/tests/unit/computations/test_dictify_by_samples.py b/tests/unit/computations/test_dictify_by_samples.py
index decc095..8a1332f 100644
--- a/tests/unit/computations/test_dictify_by_samples.py
+++ b/tests/unit/computations/test_dictify_by_samples.py
@@ -63,6 +63,7 @@ values = st.lists(st.floats())
 variances = st.lists(st.one_of(st.none(), st.floats()))
 other = st.lists(st.integers())
 
+@pytest.mark.unit_test
 @given(svv=st.tuples(
     st.lists(non_empty_samples),
     st.lists(values),
diff --git a/tests/unit/computations/test_diff.py b/tests/unit/computations/test_diff.py
index e4f5dde..128fb60 100644
--- a/tests/unit/computations/test_diff.py
+++ b/tests/unit/computations/test_diff.py
@@ -2,6 +2,8 @@
 import unittest
 import os
 
+import pytest
+
 from gn3.computations.diff import generate_diff
 
 TESTDIFF = """3,4c3,4
@@ -19,6 +21,7 @@ TESTDIFF = """3,4c3,4
 
 class TestDiff(unittest.TestCase):
     """Test cases for computations.diff"""
+    @pytest.mark.unit_test
     def test_generate_diff(self):
         """Test that the correct diff is generated"""
         data = os.path.join(os.path.dirname(__file__).split("unit")[0],
diff --git a/tests/unit/computations/test_gemma.py b/tests/unit/computations/test_gemma.py
index 73dd5eb..b36a93e 100644
--- a/tests/unit/computations/test_gemma.py
+++ b/tests/unit/computations/test_gemma.py
@@ -1,7 +1,9 @@
 """Test cases for procedures defined in computations.gemma"""
 import unittest
-
 from unittest import mock
+
+import pytest
+
 from gn3.computations.gemma import generate_gemma_cmd
 from gn3.computations.gemma import generate_hash_of_string
 from gn3.computations.gemma import generate_pheno_txt_file
@@ -9,6 +11,7 @@ from gn3.computations.gemma import generate_pheno_txt_file
 
 class TestGemma(unittest.TestCase):
     """Test cases for computations.gemma module"""
+    @pytest.mark.unit_test
     def test_generate_pheno_txt_file(self):
         """Test that the pheno text file is generated correctly"""
         open_mock = mock.mock_open()
@@ -26,11 +29,13 @@ class TestGemma(unittest.TestCase):
             mock.call("BXD07 438.700\n")
         ])
 
+    @pytest.mark.unit_test
     def test_generate_hash_of_string(self):
         """Test that a string is hashed correctly"""
         self.assertEqual(generate_hash_of_string("I^iQP&TlSR^z"),
                          "hMVRw8kbEp49rOmoIkhMjA")
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.computations.gemma.get_hash_of_files")
     def test_compute_k_values_without_loco(self, mock_get_hash):
         """Test computing k values without loco"""
@@ -52,6 +57,7 @@ class TestGemma(unittest.TestCase):
                                     "-gk > /tmp/my-token/my-hash-output.json")
                                })
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.computations.gemma.get_hash_of_files")
     def test_generate_gemma_cmd_with_loco(self, mock_get_hash):
         """Test computing k values with loco"""
diff --git a/tests/unit/computations/test_parsers.py b/tests/unit/computations/test_parsers.py
index b51b0bf..f05f766 100644
--- a/tests/unit/computations/test_parsers.py
+++ b/tests/unit/computations/test_parsers.py
@@ -2,17 +2,21 @@
 import unittest
 import os
 
+import pytest
+
 from gn3.computations.parsers import parse_genofile
 
 
 class TestParsers(unittest.TestCase):
     """Test cases for some various parsers"""
 
+    @pytest.mark.unit_test
     def test_parse_genofile_without_existing_file(self):
         """Assert that an error is raised if the genotype file is absent"""
         self.assertRaises(FileNotFoundError, parse_genofile,
                           "/non-existent-file")
 
+    @pytest.mark.unit_test
     def test_parse_genofile_with_existing_file(self):
         """Test that a genotype file is parsed correctly"""
         samples = ["bxd1", "bxd2"]
diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py
index 3690ca4..ee17659 100644
--- a/tests/unit/computations/test_partial_correlations.py
+++ b/tests/unit/computations/test_partial_correlations.py
@@ -3,6 +3,7 @@
 from unittest import TestCase
 
 import pandas
+import pytest
 from numpy.testing import assert_allclose
 
 from gn3.computations.partial_correlations import (
@@ -98,6 +99,7 @@ dictified_control_samples = (
 class TestPartialCorrelations(TestCase):
     """Class for testing partial correlations computation functions"""
 
+    @pytest.mark.unit_test
     def test_control_samples(self):
         """Test that the control_samples works as expected."""
         self.assertEqual(
@@ -112,6 +114,7 @@ class TestPartialCorrelations(TestCase):
               (None, None, None)),
              (6, 4, 3)))
 
+    @pytest.mark.unit_test
     def test_dictify_by_samples(self):
         """
         Test that `dictify_by_samples` generates the appropriate dict
@@ -142,6 +145,7 @@ class TestPartialCorrelations(TestCase):
                  (6, 4, 3))),
             dictified_control_samples)
 
+    @pytest.mark.unit_test
     def test_fix_samples(self):
         """
         Test that `fix_samples` returns only the common samples
@@ -187,6 +191,7 @@ class TestPartialCorrelations(TestCase):
              (None, None, None, None, None, None, None, None, None, None, None,
               None, None)))
 
+    @pytest.mark.unit_test
     def test_find_identical_traits(self):
         """
         Test `gn3.partial_correlations.find_identical_traits`.
@@ -219,6 +224,7 @@ class TestPartialCorrelations(TestCase):
                 self.assertEqual(
                     find_identical_traits(primn, primv, contn, contv), expected)
 
+    @pytest.mark.unit_test
     def test_tissue_correlation_error(self):
         """
         Test that `tissue_correlation` raises specific exceptions for particular
@@ -253,6 +259,7 @@ class TestPartialCorrelations(TestCase):
                 with self.assertRaises(error, msg=error_msg):
                     tissue_correlation(primary, target, method)
 
+    @pytest.mark.unit_test
     def test_tissue_correlation(self): # pylint: disable=R0201
         """
         Test that the correct correlation values are computed for the given:
@@ -269,6 +276,7 @@ class TestPartialCorrelations(TestCase):
                 assert_allclose(
                     tissue_correlation(primary, target, method), expected)
 
+    @pytest.mark.unit_test
     def test_good_dataset_samples_indexes(self):
         """
         Test that `good_dataset_samples_indexes` returns correct indices.
@@ -279,6 +287,7 @@ class TestPartialCorrelations(TestCase):
                 ("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l")),
             (0, 4, 8, 10))
 
+    @pytest.mark.unit_test
     def test_build_data_frame(self):
         """
         Check that the function builds the correct data frame.
diff --git a/tests/unit/computations/test_qtlreaper.py b/tests/unit/computations/test_qtlreaper.py
index 742d106..607f4a6 100644
--- a/tests/unit/computations/test_qtlreaper.py
+++ b/tests/unit/computations/test_qtlreaper.py
@@ -1,5 +1,6 @@
 """Module contains tests for gn3.computations.qtlreaper"""
 from unittest import TestCase
+import pytest
 from gn3.computations.qtlreaper import (
     parse_reaper_main_results,
     organise_reaper_main_results,
@@ -9,6 +10,7 @@ from tests.unit.sample_test_data import organised_trait_1
 class TestQTLReaper(TestCase):
     """Class for testing qtlreaper interface functions."""
 
+    @pytest.mark.unit_test
     def test_parse_reaper_main_results(self):
         """Test that the main results file is parsed correctly."""
         self.assertEqual(
@@ -67,6 +69,7 @@ class TestQTLReaper(TestCase):
                 }
             ])
 
+    @pytest.mark.unit_test
     def test_parse_reaper_permutation_results(self):
         """Test that the permutations results file is parsed correctly."""
         self.assertEqual(
@@ -77,6 +80,7 @@ class TestQTLReaper(TestCase):
              5.63874, 5.71346, 5.71936, 5.74275, 5.76764, 5.79815, 5.81671,
              5.82775, 5.89659, 5.92117, 5.93396, 5.93396, 5.94957])
 
+    @pytest.mark.unit_test
     def test_organise_reaper_main_results(self):
         """Check that results are organised correctly."""
         self.assertEqual(
diff --git a/tests/unit/computations/test_rqtl.py b/tests/unit/computations/test_rqtl.py
index 955d0ab..51df281 100644
--- a/tests/unit/computations/test_rqtl.py
+++ b/tests/unit/computations/test_rqtl.py
@@ -2,10 +2,12 @@
 import unittest
 
 from unittest import mock
+import pytest
 from gn3.computations.rqtl import generate_rqtl_cmd
 
 class TestRqtl(unittest.TestCase):
     """Test cases for computations.rqtl module"""
+    @pytest.mark.unit_test
     @mock.patch("gn3.computations.rqtl.generate_hash_of_string")
     @mock.patch("gn3.computations.rqtl.get_hash_of_files")
     def test_generate_rqtl_command(self, mock_get_hash_files, mock_generate_hash_string):
diff --git a/tests/unit/computations/test_slink.py b/tests/unit/computations/test_slink.py
index 995393b..276133a 100644
--- a/tests/unit/computations/test_slink.py
+++ b/tests/unit/computations/test_slink.py
@@ -1,6 +1,8 @@
 """Module contains tests for slink"""
 from unittest import TestCase
 
+import pytest
+
 from gn3.computations.slink import slink
 from gn3.computations.slink import nearest
 from gn3.computations.slink import LengthError
@@ -9,6 +11,7 @@ from gn3.computations.slink import MirrorError
 class TestSlink(TestCase):
     """Class for testing slink functions"""
 
+    @pytest.mark.unit_test
     def test_nearest_expects_list_of_lists(self):
         """Test that function only accepts a list of lists."""
         # This might be better handled with type-hints and mypy
@@ -18,6 +21,7 @@ class TestSlink(TestCase):
                 with self.assertRaises(ValueError, msg="Expected list or tuple"):
                     nearest(item, 1, 1)
 
+    @pytest.mark.unit_test
     def test_nearest_does_not_allow_empty_lists(self):
         """Test that function does not accept an empty list, or any of the child
         lists to be empty."""
@@ -29,6 +33,7 @@ class TestSlink(TestCase):
                 with self.assertRaises(ValueError):
                     nearest(lst, 1, 1)
 
+    @pytest.mark.unit_test
     def test_nearest_expects_children_are_same_length_as_parent(self):
         """Test that children lists are same length as parent list."""
         for lst in [[[0, 1]],
@@ -40,6 +45,7 @@ class TestSlink(TestCase):
                 with self.assertRaises(LengthError):
                     nearest(lst, 1, 1)
 
+    @pytest.mark.unit_test
     def test_nearest_expects_member_is_zero_distance_from_itself(self):
         """Test that distance of a member from itself is zero"""
         for lst in [[[1]],
@@ -50,6 +56,7 @@ class TestSlink(TestCase):
                 with self.assertRaises(ValueError):
                     nearest(lst, 1, 1)
 
+    @pytest.mark.unit_test
     def test_nearest_expects_distance_atob_is_equal_to_distance_btoa(self):
         """Test that the distance from member A to member B is the same as that
         from member B to member A."""
@@ -60,6 +67,7 @@ class TestSlink(TestCase):
                 with self.assertRaises(MirrorError):
                     nearest(lst, 1, 1)
 
+    @pytest.mark.unit_test
     def test_nearest_expects_zero_or_positive_distances(self):
         """Test that all distances are either zero, or greater than zero."""
         # Based on:
@@ -74,6 +82,7 @@ class TestSlink(TestCase):
                 with self.assertRaises(ValueError, msg="Distances should be positive."):
                     nearest(lst, 1, 1)
 
+    @pytest.mark.unit_test
     def test_nearest_returns_shortest_distance_given_coordinates_to_both_group_members(self):
         """Test that the shortest distance is returned."""
         # This test is named wrong - at least I think it is, from the expected results
@@ -234,6 +243,7 @@ class TestSlink(TestCase):
             with self.subTest(lst=lst):
                 self.assertEqual(nearest(lst, i, j), expected)
 
+    @pytest.mark.unit_test
     def test_nearest_gives_shortest_distance_between_list_of_members_and_member(self):
         """Test that the shortest distance is returned."""
         for members_distances, members_list, member_coordinate, expected_distance in [
@@ -260,6 +270,7 @@ class TestSlink(TestCase):
                         members_distances, member_coordinate, members_list),
                     expected_distance)
 
+    @pytest.mark.unit_test
     def test_nearest_returns_shortest_distance_given_two_lists_of_members(self):
         """Test that the shortest distance is returned."""
         for members_distances, members_list, member_list2, expected_distance in [
@@ -289,12 +300,14 @@ class TestSlink(TestCase):
                         members_distances, member_list2, members_list),
                     expected_distance)
 
+    @pytest.mark.unit_test
     def test_slink_wrong_data_returns_empty_list(self):
         """Test that empty list is returned for wrong data."""
         for data in [1, "test", [], 2.945, nearest, [0]]:
             with self.subTest(data=data):
                 self.assertEqual(slink(data), [])
 
+    @pytest.mark.unit_test
     def test_slink_with_data(self):
         """Test slink with example data, and expected results for each data
         sample."""
diff --git a/tests/unit/computations/test_wgcna.py b/tests/unit/computations/test_wgcna.py
index 5f23a86..3130374 100644
--- a/tests/unit/computations/test_wgcna.py
+++ b/tests/unit/computations/test_wgcna.py
@@ -2,6 +2,8 @@
 from unittest import TestCase
 from unittest import mock
 
+import pytest
+
 from gn3.computations.wgcna import dump_wgcna_data
 from gn3.computations.wgcna import compose_wgcna_cmd
 from gn3.computations.wgcna import call_wgcna_script
@@ -10,6 +12,7 @@ from gn3.computations.wgcna import call_wgcna_script
 class TestWgcna(TestCase):
     """test class for wgcna"""
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.computations.wgcna.process_image")
     @mock.patch("gn3.computations.wgcna.run_cmd")
     @mock.patch("gn3.computations.wgcna.compose_wgcna_cmd")
@@ -95,6 +98,7 @@ class TestWgcna(TestCase):
 
             self.assertEqual(results, expected_output)
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.computations.wgcna.run_cmd")
     @mock.patch("gn3.computations.wgcna.compose_wgcna_cmd")
     @mock.patch("gn3.computations.wgcna.dump_wgcna_data")
@@ -117,6 +121,7 @@ class TestWgcna(TestCase):
             self.assertEqual(call_wgcna_script(
                 "input_file.R", ""), expected_error)
 
+    @pytest.mark.unit_test
     def test_compose_wgcna_cmd(self):
         """test for composing wgcna cmd"""
         wgcna_cmd = compose_wgcna_cmd(
@@ -124,6 +129,7 @@ class TestWgcna(TestCase):
         self.assertEqual(
             wgcna_cmd, "Rscript ./scripts/wgcna.r  /tmp/wgcna.json")
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.computations.wgcna.TMPDIR", "/tmp")
     @mock.patch("gn3.computations.wgcna.uuid.uuid4")
     def test_create_json_file(self, file_name_generator):
diff --git a/tests/unit/db/test_audit.py b/tests/unit/db/test_audit.py
index 7480169..884afc6 100644
--- a/tests/unit/db/test_audit.py
+++ b/tests/unit/db/test_audit.py
@@ -3,6 +3,8 @@ import json
 from unittest import TestCase
 from unittest import mock
 
+import pytest
+
 from gn3.db import insert
 from gn3.db.metadata_audit import MetadataAudit
 
@@ -10,6 +12,7 @@ from gn3.db.metadata_audit import MetadataAudit
 class TestMetadatAudit(TestCase):
     """Test cases for fetching chromosomes"""
 
+    @pytest.mark.unit_test
     def test_insert_into_metadata_audit(self):
         """Test that data is inserted correctly in the audit table
 
diff --git a/tests/unit/db/test_correlation.py b/tests/unit/db/test_correlation.py
index 3f940b2..5afe55f 100644
--- a/tests/unit/db/test_correlation.py
+++ b/tests/unit/db/test_correlation.py
@@ -4,6 +4,8 @@ Tests for the gn3.db.correlations module
 
 from unittest import TestCase
 
+import pytest
+
 from gn3.db.correlations import (
     build_query_sgo_lit_corr,
     build_query_tissue_corr)
@@ -12,6 +14,7 @@ class TestCorrelation(TestCase):
     """Test cases for correlation data fetching functions"""
     maxDiff = None
 
+    @pytest.mark.unit_test
     def test_build_query_sgo_lit_corr(self):
         """
         Test that the literature correlation query is built correctly.
@@ -53,6 +56,7 @@ class TestCorrelation(TestCase):
               "ORDER BY Probeset.Id"),
              2))
 
+    @pytest.mark.unit_test
     def test_build_query_tissue_corr(self):
         """
         Test that the tissue correlation query is built correctly.
diff --git a/tests/unit/db/test_datasets.py b/tests/unit/db/test_datasets.py
index 0b8c2fe..5b86db9 100644
--- a/tests/unit/db/test_datasets.py
+++ b/tests/unit/db/test_datasets.py
@@ -1,6 +1,7 @@
 """Tests for gn3/db/datasets.py"""
 
 from unittest import mock, TestCase
+import pytest
 from gn3.db.datasets import (
     retrieve_dataset_name,
     retrieve_group_fields,
@@ -11,6 +12,7 @@ from gn3.db.datasets import (
 class TestDatasetsDBFunctions(TestCase):
     """Test cases for datasets functions."""
 
+    @pytest.mark.unit_test
     def test_retrieve_dataset_name(self):
         """Test that the function is called correctly."""
         for trait_type, thresh, trait_name, dataset_name, columns, table, expected in [
@@ -42,6 +44,7 @@ class TestDatasetsDBFunctions(TestCase):
                             table=table, cols=columns),
                         {"threshold": thresh, "name": dataset_name})
 
+    @pytest.mark.unit_test
     def test_retrieve_probeset_group_fields(self):
         """
         Test that the `group` and `group_id` fields are retrieved appropriately
@@ -65,6 +68,7 @@ class TestDatasetsDBFunctions(TestCase):
                             " AND ProbeSetFreeze.Name = %(name)s"),
                         {"name": trait_name})
 
+    @pytest.mark.unit_test
     def test_retrieve_group_fields(self):
         """
         Test that the group fields are set up correctly for the different trait
@@ -90,6 +94,7 @@ class TestDatasetsDBFunctions(TestCase):
                             trait_type, trait_name, dataset_info, db_mock),
                         expected)
 
+    @pytest.mark.unit_test
     def test_retrieve_publish_group_fields(self):
         """
         Test that the `group` and `group_id` fields are retrieved appropriately
@@ -112,6 +117,7 @@ class TestDatasetsDBFunctions(TestCase):
                             " AND PublishFreeze.Name = %(name)s"),
                         {"name": trait_name})
 
+    @pytest.mark.unit_test
     def test_retrieve_geno_group_fields(self):
         """
         Test that the `group` and `group_id` fields are retrieved appropriately
diff --git a/tests/unit/db/test_db.py b/tests/unit/db/test_db.py
index e47c9fd..8ac468c 100644
--- a/tests/unit/db/test_db.py
+++ b/tests/unit/db/test_db.py
@@ -2,6 +2,8 @@
 from unittest import TestCase
 from unittest import mock
 
+import pytest
+
 from gn3.db import fetchall
 from gn3.db import fetchone
 from gn3.db import update
@@ -14,6 +16,7 @@ from gn3.db.metadata_audit import MetadataAudit
 class TestCrudMethods(TestCase):
     """Test cases for CRUD methods"""
 
+    @pytest.mark.unit_test
     def test_update_phenotype_with_no_data(self):
         """Test that a phenotype is updated correctly if an empty Phenotype dataclass
         is provided
@@ -24,6 +27,7 @@ class TestCrudMethods(TestCase):
             conn=db_mock, table="Phenotype",
             data=Phenotype(), where=Phenotype()), None)
 
+    @pytest.mark.unit_test
     def test_update_phenotype_with_data(self):
         """
         Test that a phenotype is updated correctly if some
@@ -46,6 +50,7 @@ class TestCrudMethods(TestCase):
                 "Submitter = %s WHERE id = %s AND Owner = %s",
                 ('Test Pre Pub', 'Test Post Pub', 'Rob', 1, 'Rob'))
 
+    @pytest.mark.unit_test
     def test_fetch_phenotype(self):
         """Test that a single phenotype is fetched properly
 
@@ -68,6 +73,7 @@ class TestCrudMethods(TestCase):
                 "SELECT * FROM Phenotype WHERE id = %s AND Owner = %s",
                 (35, 'Rob'))
 
+    @pytest.mark.unit_test
     def test_fetchall_metadataaudit(self):
         """Test that multiple metadata_audit entries are fetched properly
 
@@ -96,6 +102,7 @@ class TestCrudMethods(TestCase):
                  "dataset_id = %s AND editor = %s"),
                 (35, 'Rob'))
 
+    @pytest.mark.unit_test
     # pylint: disable=R0201
     def test_probeset_called_with_right_columns(self):
         """Given a columns argument, test that the correct sql query is
@@ -112,6 +119,7 @@ class TestCrudMethods(TestCase):
                 "Name = %s",
                 ("1446112_at",))
 
+    @pytest.mark.unit_test
     def test_diff_from_dict(self):
         """Test that a correct diff is generated"""
         self.assertEqual(diff_from_dict({"id": 1, "data": "a"},
diff --git a/tests/unit/db/test_genotypes.py b/tests/unit/db/test_genotypes.py
index c125224..28728bf 100644
--- a/tests/unit/db/test_genotypes.py
+++ b/tests/unit/db/test_genotypes.py
@@ -1,5 +1,6 @@
 """Tests gn3.db.genotypes"""
 from unittest import TestCase
+import pytest
 from gn3.db.genotypes import (
     parse_genotype_file,
     parse_genotype_labels,
@@ -10,6 +11,7 @@ from gn3.db.genotypes import (
 class TestGenotypes(TestCase):
     """Tests for functions in `gn3.db.genotypes`."""
 
+    @pytest.mark.unit_test
     def test_parse_genotype_labels(self):
         """Test that the genotype labels are parsed correctly."""
         self.assertEqual(
@@ -22,6 +24,7 @@ class TestGenotypes(TestCase):
              ("type", "test_type"), ("mat", "test_mat"), ("pat", "test_pat"),
              ("het", "test_het"), ("unk", "test_unk")))
 
+    @pytest.mark.unit_test
     def test_parse_genotype_header(self):
         """Test that the genotype header is parsed correctly."""
         for header, expected in [
@@ -43,6 +46,7 @@ class TestGenotypes(TestCase):
             with self.subTest(header=header):
                 self.assertEqual(parse_genotype_header(header), expected)
 
+    @pytest.mark.unit_test
     def test_parse_genotype_data_line(self):
         """Test parsing of data lines."""
         for line, geno_obj, parlist, expected in [
@@ -76,6 +80,7 @@ class TestGenotypes(TestCase):
                     parse_genotype_marker(line, geno_obj, parlist),
                     expected)
 
+    @pytest.mark.unit_test
     def test_build_genotype_chromosomes(self):
         """
         Given `markers` and `geno_obj`, test that `build_genotype_chromosomes`
@@ -115,6 +120,7 @@ class TestGenotypes(TestCase):
                     build_genotype_chromosomes(geno_obj, markers),
                     expected)
 
+    @pytest.mark.unit_test
     def test_parse_genotype_file(self):
         """Test the parsing of genotype files. """
         self.assertEqual(
diff --git a/tests/unit/db/test_species.py b/tests/unit/db/test_species.py
index b2c4844..e883b21 100644
--- a/tests/unit/db/test_species.py
+++ b/tests/unit/db/test_species.py
@@ -2,6 +2,8 @@
 from unittest import TestCase
 from unittest import mock
 
+import pytest
+
 from gn3.db.species import get_chromosome
 from gn3.db.species import get_all_species
 
@@ -9,6 +11,7 @@ from gn3.db.species import get_all_species
 class TestChromosomes(TestCase):
     """Test cases for fetching chromosomes"""
 
+    @pytest.mark.unit_test
     def test_get_chromosome_using_species_name(self):
         """Test that the chromosome is fetched using a species name"""
         db_mock = mock.MagicMock()
@@ -24,6 +27,7 @@ class TestChromosomes(TestCase):
                 "Species.Name = 'TestCase' ORDER BY OrderId"
             )
 
+    @pytest.mark.unit_test
     def test_get_chromosome_using_group_name(self):
         """Test that the chromosome is fetched using a group name"""
         db_mock = mock.MagicMock()
@@ -39,6 +43,7 @@ class TestChromosomes(TestCase):
                 "InbredSet.Name = 'TestCase' ORDER BY OrderId"
             )
 
+    @pytest.mark.unit_test
     def test_get_all_species(self):
         """Test that species are fetched correctly"""
         db_mock = mock.MagicMock()
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
index f3e0bab..d7c0b27 100644
--- a/tests/unit/db/test_traits.py
+++ b/tests/unit/db/test_traits.py
@@ -1,5 +1,6 @@
 """Tests for gn3/db/traits.py"""
 from unittest import mock, TestCase
+import pytest
 from gn3.db.traits import (
     build_trait_name,
     export_trait_data,
@@ -49,6 +50,7 @@ trait_data = {
 class TestTraitsDBFunctions(TestCase):
     "Test cases for traits functions"
 
+    @pytest.mark.unit_test
     def test_retrieve_publish_trait_info(self):
         """Test retrieval of type `Publish` traits."""
         db_mock = mock.MagicMock()
@@ -83,6 +85,7 @@ class TestTraitsDBFunctions(TestCase):
                  " AND PublishXRef.InbredSetId = %(trait_dataset_id)s"),
                 trait_source)
 
+    @pytest.mark.unit_test
     def test_retrieve_probeset_trait_info(self):
         """Test retrieval of type `Probeset` traits."""
         db_mock = mock.MagicMock()
@@ -118,6 +121,7 @@ class TestTraitsDBFunctions(TestCase):
                     "AND ProbeSetFreeze.Name = %(trait_dataset_name)s "
                     "AND ProbeSet.Name = %(trait_name)s"), trait_source)
 
+    @pytest.mark.unit_test
     def test_retrieve_geno_trait_info(self):
         """Test retrieval of type `Geno` traits."""
         db_mock = mock.MagicMock()
@@ -141,6 +145,7 @@ class TestTraitsDBFunctions(TestCase):
                     "AND Geno.Name = %(trait_name)s"),
                 trait_source)
 
+    @pytest.mark.unit_test
     def test_retrieve_temp_trait_info(self):
         """Test retrieval of type `Temp` traits."""
         db_mock = mock.MagicMock()
@@ -153,6 +158,7 @@ class TestTraitsDBFunctions(TestCase):
                 "SELECT name, description FROM Temp WHERE Name = %(trait_name)s",
                 trait_source)
 
+    @pytest.mark.unit_test
     def test_build_trait_name_with_good_fullnames(self):
         """
         Check that the name is built correctly.
@@ -169,6 +175,7 @@ class TestTraitsDBFunctions(TestCase):
             with self.subTest(fullname=fullname):
                 self.assertEqual(build_trait_name(fullname), expected)
 
+    @pytest.mark.unit_test
     def test_build_trait_name_with_bad_fullnames(self):
         """
         Check that an exception is raised if the full name format is wrong.
@@ -178,6 +185,7 @@ class TestTraitsDBFunctions(TestCase):
                 with self.assertRaises(AssertionError, msg="Name format error"):
                     build_trait_name(fullname)
 
+    @pytest.mark.unit_test
     def test_retrieve_trait_info(self):
         """Test that information on traits is retrieved as appropriate."""
         for threshold, trait_fullname, expected in [
@@ -194,6 +202,7 @@ class TestTraitsDBFunctions(TestCase):
                             threshold, trait_fullname, db_mock),
                         expected)
 
+    @pytest.mark.unit_test
     def test_update_sample_data(self):
         """Test that the SQL queries when calling update_sample_data are called with
         the right calls.
@@ -242,6 +251,7 @@ class TestTraitsDBFunctions(TestCase):
                  mock.call(N_STRAIN_SQL, (2, 1, 1))]
             )
 
+    @pytest.mark.unit_test
     def test_set_haveinfo_field(self):
         """Test that the `haveinfo` field is set up correctly"""
         for trait_info, expected in [
@@ -250,6 +260,7 @@ class TestTraitsDBFunctions(TestCase):
             with self.subTest(trait_info=trait_info, expected=expected):
                 self.assertEqual(set_haveinfo_field(trait_info), expected)
 
+    @pytest.mark.unit_test
     def test_set_homologene_id_field(self):
         """Test that the `homologene_id` field is set up correctly"""
         for trait_type, trait_info, expected in [
@@ -264,6 +275,7 @@ class TestTraitsDBFunctions(TestCase):
                     self.assertEqual(
                         set_homologene_id_field(trait_type, trait_info, db_mock), expected)
 
+    @pytest.mark.unit_test
     def test_set_confidential_field(self):
         """Test that the `confidential` field is set up correctly"""
         for trait_type, trait_info, expected in [
@@ -275,6 +287,7 @@ class TestTraitsDBFunctions(TestCase):
                 self.assertEqual(
                     set_confidential_field(trait_type, trait_info), expected)
 
+    @pytest.mark.unit_test
     def test_export_trait_data_dtype(self):
         """
         Test `export_trait_data` with different values for the `dtype` keyword
@@ -290,6 +303,7 @@ class TestTraitsDBFunctions(TestCase):
                     export_trait_data(trait_data, samplelist, dtype=dtype),
                     expected)
 
+    @pytest.mark.unit_test
     def test_export_trait_data_dtype_all_flags(self):
         """
         Test `export_trait_data` with different values for the `dtype` keyword
@@ -331,6 +345,7 @@ class TestTraitsDBFunctions(TestCase):
                         n_exists=nflag),
                     expected)
 
+    @pytest.mark.unit_test
     def test_export_informative(self):
         """Test that the function exports appropriate data."""
         # pylint: disable=W0621
diff --git a/tests/unit/test_authentication.py b/tests/unit/test_authentication.py
index 061b684..59c88ef 100644
--- a/tests/unit/test_authentication.py
+++ b/tests/unit/test_authentication.py
@@ -1,8 +1,10 @@
 """Test cases for authentication.py"""
 import json
 import unittest
-
 from unittest import mock
+
+import pytest
+
 from gn3.authentication import AdminRole
 from gn3.authentication import DataRole
 from gn3.authentication import get_highest_user_access_role
@@ -24,6 +26,7 @@ class TestGetUserMembership(unittest.TestCase):
                 '"created_timestamp": "Oct 06 2021 06:39PM"}')}
         self.conn = conn
 
+    @pytest.mark.unit_test
     def test_user_is_group_member_only(self):
         """Test that a user is only a group member"""
         self.assertEqual(
@@ -34,6 +37,7 @@ class TestGetUserMembership(unittest.TestCase):
             {"member": True,
              "admin": False})
 
+    @pytest.mark.unit_test
     def test_user_is_group_admin_only(self):
         """Test that a user is a group admin only"""
         self.assertEqual(
@@ -44,6 +48,7 @@ class TestGetUserMembership(unittest.TestCase):
             {"member": False,
              "admin": True})
 
+    @pytest.mark.unit_test
     def test_user_is_both_group_member_and_admin(self):
         """Test that a user is both an admin and member of a group"""
         self.assertEqual(
@@ -58,6 +63,7 @@ class TestGetUserMembership(unittest.TestCase):
 class TestCheckUserAccessRole(unittest.TestCase):
     """Test cases for `get_highest_user_access_role`"""
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.authentication.requests.get")
     def test_edit_access(self, requests_mock):
         """Test that the right access roles are set if the user has edit access"""
@@ -79,6 +85,7 @@ class TestCheckUserAccessRole(unittest.TestCase):
                 "admin": AdminRole.EDIT_ACCESS,
             })
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.authentication.requests.get")
     def test_no_access(self, requests_mock):
         """Test that the right access roles are set if the user has no access"""
diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py
index e644e1a..e0efaf7 100644
--- a/tests/unit/test_commands.py
+++ b/tests/unit/test_commands.py
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 from datetime import datetime
 from typing import Callable
 from unittest import mock
+import pytest
 from gn3.commands import compose_gemma_cmd
 from gn3.commands import compose_rqtl_cmd
 from gn3.commands import queue_cmd
@@ -23,6 +24,7 @@ class MockRedis:
 class TestCommands(unittest.TestCase):
     """Test cases for commands.py"""
 
+    @pytest.mark.unit_test
     def test_compose_gemma_cmd_no_extra_args(self):
         """Test that the gemma cmd is composed correctly"""
         self.assertEqual(
@@ -37,6 +39,7 @@ class TestCommands(unittest.TestCase):
              "-p /tmp/gf13Ad0tRX/phenofile.txt"
              " -gk"))
 
+    @pytest.mark.unit_test
     def test_compose_gemma_cmd_extra_args(self):
         """Test that the gemma cmd is composed correctly"""
         self.assertEqual(
@@ -54,6 +57,7 @@ class TestCommands(unittest.TestCase):
              "-p /tmp/gf13Ad0tRX/phenofile.txt"
              " -gk"))
 
+    @pytest.mark.unit_test
     def test_compose_rqtl_cmd(self):
         """Test that the R/qtl cmd is composed correctly"""
         self.assertEqual(
@@ -78,6 +82,7 @@ class TestCommands(unittest.TestCase):
              "--addcovar")
         )
 
+    @pytest.mark.unit_test
     def test_queue_cmd_exception_raised_when_redis_is_down(self):
         """Test that the correct error is raised when Redis is unavailable"""
         self.assertRaises(RedisConnectionError,
@@ -88,6 +93,7 @@ class TestCommands(unittest.TestCase):
                                          hset=mock.MagicMock(),
                                          rpush=mock.MagicMock()))
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.commands.datetime")
     @mock.patch("gn3.commands.uuid4")
     def test_queue_cmd_correct_calls_to_redis(self, mock_uuid4,
@@ -112,6 +118,7 @@ class TestCommands(unittest.TestCase):
         mock_redis_conn.rpush.assert_has_calls(
             [mock.call("GN2::job-queue", actual_unique_id)])
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.commands.datetime")
     @mock.patch("gn3.commands.uuid4")
     def test_queue_cmd_right_calls_to_redis_with_email(self,
@@ -140,11 +147,13 @@ class TestCommands(unittest.TestCase):
         mock_redis_conn.rpush.assert_has_calls(
             [mock.call("GN2::job-queue", actual_unique_id)])
 
+    @pytest.mark.unit_test
     def test_run_cmd_correct_input(self):
         """Test that a correct cmd is processed correctly"""
         self.assertEqual(run_cmd("echo test"),
                          {"code": 0, "output": "test\n"})
 
+    @pytest.mark.unit_test
     def test_run_cmd_incorrect_input(self):
         """Test that an incorrect cmd is processed correctly"""
         result = run_cmd("echoo test")
diff --git a/tests/unit/test_data_helpers.py b/tests/unit/test_data_helpers.py
index 88ea469..b6de42e 100644
--- a/tests/unit/test_data_helpers.py
+++ b/tests/unit/test_data_helpers.py
@@ -4,6 +4,8 @@ Test functions in gn3.data_helpers
 
 from unittest import TestCase
 
+import pytest
+
 from gn3.data_helpers import partition_by, partition_all, parse_csv_line
 
 class TestDataHelpers(TestCase):
@@ -11,6 +13,7 @@ class TestDataHelpers(TestCase):
     Test functions in gn3.data_helpers
     """
 
+    @pytest.mark.unit_test
     def test_partition_all(self):
         """
         Test that `gn3.data_helpers.partition_all` partitions sequences as expected.
@@ -36,6 +39,7 @@ class TestDataHelpers(TestCase):
             with self.subTest(n=count, items=items):
                 self.assertEqual(partition_all(count, items), expected)
 
+    @pytest.mark.unit_test
     def test_parse_csv_line(self):
         """
         Test parsing a single line from a CSV file
@@ -60,6 +64,7 @@ class TestDataHelpers(TestCase):
                         line=line, delimiter=delimiter, quoting=quoting),
                     expected)
 
+    @pytest.mark.unit_test
     def test_partition_by(self):
         """
         Test that `partition_by` groups the data using the given predicate
diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py
index 0f2de9e..dd0cd5d 100644
--- a/tests/unit/test_db_utils.py
+++ b/tests/unit/test_db_utils.py
@@ -3,6 +3,7 @@
 from unittest import TestCase
 from unittest import mock
 
+import pytest
 from types import SimpleNamespace
 
 from gn3.db_utils import database_connector
@@ -12,6 +13,7 @@ from gn3.db_utils import parse_db_url
 class TestDatabase(TestCase):
     """class contains testd for db connection functions"""
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.db_utils.mdb")
     @mock.patch("gn3.db_utils.parse_db_url")
     def test_database_connector(self, mock_db_parser, mock_sql):
@@ -28,6 +30,7 @@ class TestDatabase(TestCase):
         self.assertIsInstance(
             results, tuple, "database not created successfully")
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.db_utils.SQL_URI",
                 "mysql://username:4321@localhost/test")
     def test_parse_db_url(self):
diff --git a/tests/unit/test_file_utils.py b/tests/unit/test_file_utils.py
index 75be4f6..77fea88 100644
--- a/tests/unit/test_file_utils.py
+++ b/tests/unit/test_file_utils.py
@@ -5,6 +5,7 @@ import unittest
 from dataclasses import dataclass
 from typing import Callable
 from unittest import mock
+import pytest
 from gn3.fs_helpers import extract_uploaded_file
 from gn3.fs_helpers import get_dir_hash
 from gn3.fs_helpers import jsonfile_to_dict
@@ -21,17 +22,20 @@ class MockFile:
 class TestFileUtils(unittest.TestCase):
     """Test cases for procedures defined in fs_helpers.py"""
 
+    @pytest.mark.unit_test
     def test_get_dir_hash(self):
         """Test that a directory is hashed correctly"""
         test_dir = os.path.join(os.path.dirname(__file__), "test_data")
         self.assertEqual("3aeafab7d53b4f76d223366ae7ee9738",
                          get_dir_hash(test_dir))
 
+    @pytest.mark.unit_test
     def test_get_dir_hash_non_existent_dir(self):
         """Test thata an error is raised when the dir does not exist"""
         self.assertRaises(FileNotFoundError, get_dir_hash,
                           "/non-existent-file")
 
+    @pytest.mark.unit_test
     def test_jsonfile_to_dict(self):
         """Test that a json file is parsed correctly""" ""
         json_file = os.path.join(os.path.dirname(__file__), "test_data",
@@ -39,12 +43,14 @@ class TestFileUtils(unittest.TestCase):
         self.assertEqual("Longer description",
                          jsonfile_to_dict(json_file).get("description"))
 
+    @pytest.mark.unit_test
     def test_jsonfile_to_dict_nonexistent_file(self):
         """Test that a ValueError is raised when the json file is
 non-existent"""
         self.assertRaises(FileNotFoundError, jsonfile_to_dict,
                           "/non-existent-dir")
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.fs_helpers.tarfile")
     @mock.patch("gn3.fs_helpers.secure_filename")
     def test_extract_uploaded_file(self, mock_file, mock_tarfile):
@@ -65,6 +71,7 @@ non-existent"""
         mock_file.assert_called_once_with("upload-data.tar.gz")
         self.assertEqual(result, {"status": 0, "token": "abcdef-abcdef"})
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.fs_helpers.secure_filename")
     def test_extract_uploaded_file_non_existent_gzip(self, mock_file):
         """Test that the right error message is returned when there is a problem
@@ -78,6 +85,7 @@ extracting the file"""
             "error": "gzip failed to unpack file"
         })
 
+    @pytest.mark.unit_test
     def test_cache_ipfs_file_cache_hit(self):
         """Test that the correct file location is returned if there's a cache hit"""
         # Create empty file
@@ -96,6 +104,7 @@ extracting the file"""
         os.rmdir(test_dir)
         self.assertEqual(file_loc, f"{test_dir}/genotype.txt")
 
+    @pytest.mark.unit_test
     @mock.patch("gn3.fs_helpers.ipfshttpclient")
     def test_cache_ipfs_file_cache_miss(self,
                                         mock_ipfs):
diff --git a/tests/unit/test_heatmaps.py b/tests/unit/test_heatmaps.py
index a88341b..8781d6f 100644
--- a/tests/unit/test_heatmaps.py
+++ b/tests/unit/test_heatmaps.py
@@ -1,6 +1,7 @@
 """Module contains tests for gn3.heatmaps.heatmaps"""
 from unittest import TestCase
 
+import pytest
 from numpy.testing import assert_allclose
 
 from gn3.heatmaps import (
@@ -27,6 +28,7 @@ slinked = (
 class TestHeatmap(TestCase):
     """Class for testing heatmap computation functions"""
 
+    @pytest.mark.unit_test
     def test_cluster_traits(self): # pylint: disable=R0201
         """
         Test that the clustering is working as expected.
@@ -76,11 +78,13 @@ class TestHeatmap(TestCase):
               1.7413442197913358, 0.33370067057028485, 1.3256191648260216,
               0.0)))
 
+    @pytest.mark.unit_test
     def test_compute_heatmap_order(self):
         """Test the orders."""
         self.assertEqual(
             compute_traits_order(slinked), (0, 2, 1, 7, 5, 9, 3, 6, 8, 4))
 
+    @pytest.mark.unit_test
     def test_retrieve_samples_and_values(self):
         """Test retrieval of samples and values."""
         for orders, slist, tdata, expected in [
@@ -106,6 +110,7 @@ class TestHeatmap(TestCase):
                 self.assertEqual(
                     retrieve_samples_and_values(orders, slist, tdata), expected)
 
+    @pytest.mark.unit_test
     def test_get_lrs_from_chr(self):
         """Check that function gets correct LRS values"""
         for trait, chromosome, expected in [
@@ -120,6 +125,7 @@ class TestHeatmap(TestCase):
             with self.subTest(trait=trait, chromosome=chromosome):
                 self.assertEqual(get_lrs_from_chr(trait, chromosome), expected)
 
+    @pytest.mark.unit_test
     def test_process_traits_data_for_heatmap(self):
         """Check for correct processing of data for heatmap generation."""
         self.assertEqual(
@@ -132,6 +138,7 @@ class TestHeatmap(TestCase):
              [[0.5, 0.579, 0.5],
               [0.5, 0.5, 0.5]]])
 
+    @pytest.mark.unit_test
     def test_get_loci_names(self):
         """Check that loci names are retrieved correctly."""
         for organised, expected in (