diff options
Diffstat (limited to 'tests/unit')
-rw-r--r-- | tests/unit/computations/test_correlation.py | 131 | ||||
-rw-r--r-- | tests/unit/test_db_utils.py | 37 |
2 files changed, 134 insertions, 34 deletions
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py index 84b9330..52d1f60 100644 --- a/tests/unit/computations/test_correlation.py +++ b/tests/unit/computations/test_correlation.py @@ -18,6 +18,8 @@ from gn3.computations.correlations import query_formatter from gn3.computations.correlations import map_to_mouse_gene_id from gn3.computations.correlations import compute_all_lit_correlation from gn3.computations.correlations import compute_all_tissue_correlation +from gn3.computations.correlations import map_shared_keys_to_values +from gn3.computations.correlations import process_trait_symbol_dict class QueryableMixin: @@ -27,6 +29,10 @@ class QueryableMixin: """base method for execute""" raise NotImplementedError() + def cursor(self): + """method for creating db cursor""" + raise NotImplementedError() + def fetchone(self): """base method for fetching one iten""" raise NotImplementedError() @@ -46,37 +52,39 @@ class IllegalOperationError(Exception): class DataBase(QueryableMixin): """Class for creating db object""" - def __init__(self): + def __init__(self, expected_results=None, password="1234", db_name=None): + """expects the expectede results value to be an array""" + self.password = password + self.db_name = db_name self.__query_options = None - self.__results = None + self.results_generator(expected_results) def execute(self, query_options): """method to execute an sql query""" self.__query_options = query_options - self.results_generator() + return 1 + + def cursor(self): + """method for creating db cursor""" return self def fetchone(self): """method to fetch single item from the db query""" if self.__results is None: - raise IllegalOperationError() + return None return self.__results[0] def fetchall(self): """method for fetching all items from db query""" if self.__results is None: - raise IllegalOperationError() + return None return self.__results - def results_generator(self, expected_results=None): + def results_generator(self, expected_results): """private method for generating mock results""" - if expected_results is None: - self.__results = [namedtuple("lit_coeff", "val")(x*0.1) - for x in range(1, 4)] - else: - self.__results = expected_results + self.__results = expected_results class TestCorrelation(TestCase): @@ -236,21 +244,23 @@ class TestCorrelation(TestCase): """fetch results from db call for lit correlation given a trait list\ after doing correlation""" - target_trait_lists = [{"gene_id": 15}, - {"gene_id": 17}, - {"gene_id": 11}] + target_trait_lists = [("1426679_at", 15), + ("1426702_at", 17), + ("1426682_at", 11)] mock_mouse_gene_id.side_effect = [12, 11, 18, 16, 20] - database_instance = namedtuple("database", "execute")("fetchone") + conn = DataBase() fetch_lit_data.side_effect = [(15, 9), (17, 8), (11, 12)] lit_results = lit_correlation_for_trait_list( - database=database_instance, target_trait_lists=target_trait_lists, + conn=conn, target_trait_lists=target_trait_lists, species="rat", trait_gene_id="12") - expected_results = [{"gene_id": 15, "lit_corr": 9}, { - "gene_id": 17, "lit_corr": 8}, {"gene_id": 11, "lit_corr": 12}] + expected_results = [{"1426679_at": {"gene_id": 15, "lit_corr": 9}}, + {"1426702_at": { + "gene_id": 17, "lit_corr": 8}}, + {"1426682_at": {"gene_id": 11, "lit_corr": 12}}] self.assertEqual(lit_results, expected_results) @@ -258,8 +268,8 @@ class TestCorrelation(TestCase): """test for fetching lit correlation data from\ the database where the input and mouse geneid are none""" - database_instance = DataBase() - results = fetch_lit_correlation_data(database=database_instance, + conn = DataBase() + results = fetch_lit_correlation_data(conn=conn, gene_id="1", input_mouse_gene_id=None, mouse_gene_id=None) @@ -270,10 +280,12 @@ class TestCorrelation(TestCase): """test for fetching lit corr coefficent givent the input\ input trait mouse gene id and mouse gene id""" - database_instance = DataBase() + expected_db_results = [namedtuple("lit_coeff", "val")(x*0.1) + for x in range(1, 4)] + database_instance = DataBase(expected_results=expected_db_results) expected_results = ("1", 0.1) - lit_results = fetch_lit_correlation_data(database=database_instance, + lit_results = fetch_lit_correlation_data(conn=database_instance, gene_id="1", input_mouse_gene_id="20", mouse_gene_id="15") @@ -283,10 +295,8 @@ class TestCorrelation(TestCase): def test_query_lit_correlation_for_db_empty(self): """test that corr coeffient returned is 0 given the\ db value if corr coefficient is empty""" - database_instance = mock.Mock() - database_instance.execute.return_value.fetchone.return_value = None - - lit_results = fetch_lit_correlation_data(database=database_instance, + database_instance = DataBase() + lit_results = fetch_lit_correlation_data(conn=database_instance, input_mouse_gene_id="12", gene_id="16", mouse_gene_id="12") @@ -336,13 +346,15 @@ class TestCorrelation(TestCase): database_results = [namedtuple("mouse_id", "mouse")(val) for val in range(12, 20)] results = [] - - database_instance.execute.return_value.fetchone.side_effect = database_results + cursor = mock.Mock() + cursor.execute.return_value = 1 + cursor.fetchone.side_effect = database_results + database_instance.cursor.return_value = cursor expected_results = [12, None, 13, 14] for (species, gene_id) in test_data: mouse_gene_id_results = map_to_mouse_gene_id( - database=database_instance, species=species, gene_id=gene_id) + conn=database_instance, species=species, gene_id=gene_id) results.append(mouse_gene_id_results) self.assertEqual(results, expected_results) @@ -361,7 +373,7 @@ class TestCorrelation(TestCase): mock_lit_corr.side_effect = expected_mocked_lit_results lit_correlation_results = compute_all_lit_correlation( - database_instance=database, trait_lists=[{"gene_id": 11}], + conn=database, trait_lists=[{"gene_id": 11}], species="rat", gene_id=12) expected_results = { @@ -371,15 +383,26 @@ class TestCorrelation(TestCase): self.assertEqual(lit_correlation_results, expected_results) @mock.patch("gn3.computations.correlations.tissue_correlation_for_trait_list") - def test_compute_all_tissue_correlation(self, mock_tissue_corr): + @mock.patch("gn3.computations.correlations.process_trait_symbol_dict") + def test_compute_all_tissue_correlation(self, process_trait_symbol, mock_tissue_corr): """test for compute all tissue corelation which abstracts api calling the tissue_correlation for trait_list""" primary_tissue_dict = {"trait_id": "1419792_at", "tissue_values": [1, 2, 3, 4, 5]} - target_tissue_dict = [{"trait_id": "1418702_a_at", "tissue_values": [1, 2, 3]}, - {"trait_id": "1412_at", "tissue_values": [1, 2, 3]}] + target_tissue_dict = [{"trait_id": "1418702_a_at", + "symbol": "zf", "tissue_values": [1, 2, 3]}, + {"trait_id": "1412_at", + "symbol": "prkce", "tissue_values": [1, 2, 3]}] + + process_trait_symbol.return_value = target_tissue_dict + + target_trait_symbol = {"1418702_a_at": "Zf", "1412_at": "Prkce"} + target_symbol_tissue_vals = {"zf": [1, 2, 3], "prkce": [1, 2, 3]} + + target_tissue_data = {"trait_symbol_dict": target_trait_symbol, + "symbol_tissue_vals_dict": target_symbol_tissue_vals} mock_tissue_corr.side_effect = [{"tissue_corr": -0.5, "p_value": 0.9, "tissue_number": 3}, {"tissue_corr": 1.11, "p_value": 0.2, "tissue_number": 3}] @@ -391,9 +414,49 @@ class TestCorrelation(TestCase): results = compute_all_tissue_correlation( primary_tissue_dict=primary_tissue_dict, - target_tissues_dict_list=target_tissue_dict, + target_tissues_data=target_tissue_data, corr_method="pearson") + process_trait_symbol.assert_called_once_with( + target_trait_symbol, target_symbol_tissue_vals) self.assertEqual(mock_tissue_corr.call_count, 2) self.assertEqual(results, expected_results) + + def test_map_shared_keys_to_values(self): + """test helper function needed to integrate with genenenetwork2\ + given a a samplelist containing dataset sampelist keys\ + map that to given sample values """ + + dataset_sample_keys = ["BXD1", "BXD2", "BXD5"] + + target_dataset_data = {"HCMA:_AT": [4.1, 5.6, 3.2], + "TXD_AT": [6.2, 5.7, 3.6, ]} + + expected_results = [{"trait_id": "HCMA:_AT", + "trait_sample_data": {"BXD1": 4.1, "BXD2": 5.6, "BXD5": 3.2}}, + {"trait_id": "TXD_AT", + "trait_sample_data": {"BXD1": 6.2, "BXD2": 5.7, "BXD5": 3.6}}] + + results = map_shared_keys_to_values( + dataset_sample_keys, target_dataset_data) + + self.assertEqual(results, expected_results) + + def test_process_trait_symbol_dict(self): + """test for processing trait symbol dict\ + and fetch tissue values from tissue value dict\ + """ + trait_symbol_dict = {"1452864_at": "Igsf10"} + tissue_values_dict = {"igsf10": [8.9615, 10.6375, 9.2795, 8.6605]} + + expected_results = { + "trait_id": "1452864_at", + "symbol": "igsf10", + "tissue_values": [8.9615, 10.6375, 9.2795, 8.6605] + } + + results = process_trait_symbol_dict( + trait_symbol_dict, tissue_values_dict) + + self.assertEqual(results, [expected_results]) diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py new file mode 100644 index 0000000..0f2de9e --- /dev/null +++ b/tests/unit/test_db_utils.py @@ -0,0 +1,37 @@ +"""module contains test for db_utils""" + +from unittest import TestCase +from unittest import mock + +from types import SimpleNamespace + +from gn3.db_utils import database_connector +from gn3.db_utils import parse_db_url + + +class TestDatabase(TestCase): + """class contains testd for db connection functions""" + + @mock.patch("gn3.db_utils.mdb") + @mock.patch("gn3.db_utils.parse_db_url") + def test_database_connector(self, mock_db_parser, mock_sql): + """test for creating database connection""" + mock_db_parser.return_value = ("localhost", "guest", "4321", "users") + callable_cursor = lambda: SimpleNamespace(execute=3) + cursor_object = SimpleNamespace(cursor=callable_cursor) + mock_sql.connect.return_value = cursor_object + mock_sql.close.return_value = None + results = database_connector() + + mock_sql.connect.assert_called_with( + "localhost", "guest", "4321", "users") + self.assertIsInstance( + results, tuple, "database not created successfully") + + @mock.patch("gn3.db_utils.SQL_URI", + "mysql://username:4321@localhost/test") + def test_parse_db_url(self): + """test for parsing db_uri env variable""" + results = parse_db_url() + expected_results = ("localhost", "username", "4321", "test") + self.assertEqual(results, expected_results) |