From 31ac939f58bf7b6d353ced995ca395376203b25f Mon Sep 17 00:00:00 2001 From: Alexander Kabui Date: Mon, 12 Apr 2021 09:54:12 +0300 Subject: Integrate correlation API - add new api for gn2-gn3 sample r integration - delete map for sample list to values - add db util file - add python msql-client dependency - add db for fetching lit correlation results - add unittests for db utils - add tests for db_utils - modify api for fetching lit correlation results - refactor Mock Database Connector and unittests - add sql url parser - add SQL URI env variable - refactor code for db utils - modify return data for lit correlation - refactor tissue correlation endpoint - replace db_instance with conn--- gn3/api/correlation.py | 39 +++++++-- gn3/computations/correlations.py | 104 ++++++++++++++++------ gn3/db_utils.py | 24 +++++ gn3/settings.py | 2 +- guix.scm | 2 +- requirements.txt | 6 ++ tests/integration/test_correlation.py | 27 +++--- tests/unit/computations/test_correlation.py | 131 ++++++++++++++++++++-------- tests/unit/test_db_utils.py | 37 ++++++++ 9 files changed, 290 insertions(+), 82 deletions(-) create mode 100644 gn3/db_utils.py create mode 100644 tests/unit/test_db_utils.py diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py index 53ea6a7..2339088 100644 --- a/gn3/api/correlation.py +++ b/gn3/api/correlation.py @@ -1,6 +1,4 @@ """Endpoints for running correlations""" -from unittest import mock - from flask import jsonify from flask import Blueprint from flask import request @@ -8,11 +6,31 @@ from flask import request from gn3.computations.correlations import compute_all_sample_correlation from gn3.computations.correlations import compute_all_lit_correlation from gn3.computations.correlations import compute_all_tissue_correlation - +from gn3.computations.correlations import map_shared_keys_to_values +from gn3.db_utils import database_connector correlation = Blueprint("correlation", __name__) +@correlation.route("/sample_x/", methods=["POST"]) +def compute_sample_integration(corr_method="pearson"): + """temporary api to help integrate genenetwork2 to genenetwork3 """ + + correlation_input = request.get_json() + + target_samplelist = correlation_input.get("target_samplelist") + target_data_values = correlation_input.get("target_dataset") + this_trait_data = correlation_input.get("trait_data") + + results = map_shared_keys_to_values(target_samplelist, target_data_values) + + correlation_results = compute_all_sample_correlation(corr_method=corr_method, + this_trait=this_trait_data, + target_dataset=results) + + return jsonify(correlation_results) + + @correlation.route("/sample_r/", methods=["POST"]) def compute_sample_r(corr_method="pearson"): """correlation endpoint for computing sample r correlations\ @@ -22,11 +40,11 @@ def compute_sample_r(corr_method="pearson"): # xtodo move code below to compute_all_sampl correlation this_trait_data = correlation_input.get("this_trait") - target_datasets = correlation_input.get("target_dataset") + target_dataset_data = correlation_input.get("target_dataset") correlation_results = compute_all_sample_correlation(corr_method=corr_method, this_trait=this_trait_data, - target_dataset=target_datasets) + target_dataset=target_dataset_data) return jsonify({ "corr_results": correlation_results @@ -39,13 +57,16 @@ def compute_lit_corr(species=None, gene_id=None): are fetched from the database this is the only case where the db\ might be needed for actual computing of the correlation results""" - database_instance = mock.Mock() + conn, _cursor_object = database_connector() target_traits_gene_ids = request.get_json() + target_trait_gene_list = list(target_traits_gene_ids.items()) lit_corr_results = compute_all_lit_correlation( - database_instance=database_instance, trait_lists=target_traits_gene_ids, + conn=conn, trait_lists=target_trait_gene_list, species=species, gene_id=gene_id) + conn.close() + return jsonify(lit_corr_results) @@ -54,10 +75,10 @@ def compute_tissue_corr(corr_method="pearson"): """api endpoint fr doing tissue correlation""" tissue_input_data = request.get_json() primary_tissue_dict = tissue_input_data["primary_tissue"] - target_tissues_dict_list = tissue_input_data["target_tissues"] + target_tissues_dict = tissue_input_data["target_tissues_dict"] results = compute_all_tissue_correlation(primary_tissue_dict=primary_tissue_dict, - target_tissues_dict_list=target_tissues_dict_list, + target_tissues_data=target_tissues_dict, corr_method=corr_method) return jsonify(results) diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py index dc2f8d3..26b7294 100644 --- a/gn3/computations/correlations.py +++ b/gn3/computations/correlations.py @@ -12,10 +12,30 @@ def compute_sum(rhs: int, lhs: int) -> int: return rhs + lhs +def map_shared_keys_to_values(target_sample_keys: List, target_sample_vals: dict)-> List: + """Function to construct target dataset data items given commoned shared\ + keys and trait samplelist values for example given keys >>>>>>>>>>\ + ["BXD1", "BXD2", "BXD5", "BXD6", "BXD8", "BXD9"] and value object as\ + "HCMA:_AT": [4.1, 5.6, 3.2, 1.1, 4.4, 2.2],TXD_AT": [6.2, 5.7, 3.6, 1.5, 4.2, 2.3]}\ + return results should be a list of dicts mapping the shared keys to the trait values""" + target_dataset_data = [] + + for trait_id, sample_values in target_sample_vals.items(): + target_trait_dict = dict(zip(target_sample_keys, sample_values)) + + target_trait = { + "trait_id": trait_id, + "trait_sample_data": target_trait_dict + } + + target_dataset_data.append(target_trait) + + return target_dataset_data + + def normalize_values(a_values: List, b_values: List) -> Tuple[List[float], List[float], int]: """Trim two lists of values to contain only the values they both share - Given two lists of sample values, trim each list so that it contains only the samples that contain a value in both lists. Also returns the number of such samples. @@ -175,7 +195,7 @@ def tissue_correlation_for_trait_list( """ - # ax :todo assertion that lenggth one one target tissue ==primary_tissue + # ax :todo assertion that length one one target tissue ==primary_tissue (tissue_corr_coeffient, p_value) = compute_corr_p_value(primary_values=primary_tissue_vals, @@ -192,11 +212,11 @@ def tissue_correlation_for_trait_list( def fetch_lit_correlation_data( - database, + conn, input_mouse_gene_id: Optional[str], gene_id: str, mouse_gene_id: Optional[str] = None) -> Tuple[str, float]: - """given input trait mouse gene id and mouse gene id fetch the lit\ + """Given input trait mouse gene id and mouse gene id fetch the lit\ corr_data""" if mouse_gene_id is not None and ";" not in mouse_gene_id: query = """ @@ -208,15 +228,19 @@ def fetch_lit_correlation_data( query_values = (str(mouse_gene_id), str(input_mouse_gene_id)) - results = database.execute(query_formatter(query, - *query_values)).fetchone() + cursor = conn.cursor() + + cursor.execute(query_formatter(query, + *query_values)) + results = cursor.fetchone() lit_corr_results = None if results is not None: lit_corr_results = results else: - lit_corr_results = database.execute( - query_formatter(query, - *tuple(reversed(query_values)))).fetchone() + cursor = conn.cursor() + cursor.execute(query_formatter(query, + *tuple(reversed(query_values)))) + lit_corr_results = cursor.fetchone() lit_results = (gene_id, lit_corr_results.val)\ if lit_corr_results else (gene_id, 0) return lit_results @@ -225,7 +249,7 @@ def fetch_lit_correlation_data( def lit_correlation_for_trait_list( - database, + conn, target_trait_lists: List, species: Optional[str] = None, trait_gene_id: Optional[str] = None) -> List: @@ -233,43 +257,45 @@ def lit_correlation_for_trait_list( output is float for lit corr results """ fetched_lit_corr_results = [] - this_trait_mouse_gene_id = map_to_mouse_gene_id(database=database, + this_trait_mouse_gene_id = map_to_mouse_gene_id(conn=conn, species=species, gene_id=trait_gene_id) - for trait in target_trait_lists: - target_trait_gene_id = trait.get("gene_id") + for (trait_name, target_trait_gene_id) in target_trait_lists: + corr_results = {} if target_trait_gene_id: target_mouse_gene_id = map_to_mouse_gene_id( - database=database, + conn=conn, species=species, gene_id=target_trait_gene_id) fetched_corr_data = fetch_lit_correlation_data( - database=database, + conn=conn, input_mouse_gene_id=this_trait_mouse_gene_id, gene_id=target_trait_gene_id, mouse_gene_id=target_mouse_gene_id) dict_results = dict(zip(("gene_id", "lit_corr"), fetched_corr_data)) - fetched_lit_corr_results.append(dict_results) + corr_results[trait_name] = dict_results + fetched_lit_corr_results.append(corr_results) return fetched_lit_corr_results def query_formatter(query_string: str, *query_values): - """formatter query string given the unformatted query string\ + """Formatter query string given the unformatted query string\ and the respectibe values.Assumes number of placeholders is equal to the number of query values """ + # xtodo escape sql queries results = query_string % (query_values) return results -def map_to_mouse_gene_id(database, species: Optional[str], +def map_to_mouse_gene_id(conn, species: Optional[str], gene_id: Optional[str]) -> Optional[str]: - """given a species which is not mouse map the gene_id\ + """Given a species which is not mouse map the gene_id\ to respective mouse gene id""" # AK:xtodo move the code for checking nullity out of thing functions bug # while method for string @@ -278,28 +304,29 @@ def map_to_mouse_gene_id(database, species: Optional[str], if species == "mouse": return gene_id + cursor = conn.cursor() query = """SELECT mouse FROM GeneIDXRef WHERE '%s' = '%s'""" query_values = (species, gene_id) - - results = database.execute(query_formatter(query, - *query_values)).fetchone() + cursor.execute(query_formatter(query, + *query_values)) + results = cursor.fetchone() mouse_gene_id = results.mouse if results is not None else None return mouse_gene_id -def compute_all_lit_correlation(database_instance, trait_lists: List, +def compute_all_lit_correlation(conn, trait_lists: List, species: str, gene_id): """Function that acts as an abstraction for lit_correlation_for_trait_list""" # xtodo to be refactored lit_results = lit_correlation_for_trait_list( - database=database_instance, + conn=conn, target_trait_lists=trait_lists, species=species, trait_gene_id=gene_id) @@ -308,18 +335,22 @@ def compute_all_lit_correlation(database_instance, trait_lists: List, def compute_all_tissue_correlation(primary_tissue_dict: dict, - target_tissues_dict_list: List, + target_tissues_data: dict, corr_method: str): """Function acts as an abstraction for tissue_correlation_for_trait_list\ - required input are target tissue object and primary tissue trait + required input are target tissue object and primary tissue trait\ + target tissues data contains the trait_symbol_dict and symbol_tissue_vals """ tissues_results = {} primary_tissue_vals = primary_tissue_dict["tissue_values"] + traits_symbol_dict = target_tissues_data["trait_symbol_dict"] + symbol_tissue_vals_dict = target_tissues_data["symbol_tissue_vals_dict"] - target_tissues_list = target_tissues_dict_list + target_tissues_list = process_trait_symbol_dict( + traits_symbol_dict, symbol_tissue_vals_dict) for target_tissue_obj in target_tissues_list: trait_id = target_tissue_obj.get("trait_id") @@ -334,3 +365,22 @@ def compute_all_tissue_correlation(primary_tissue_dict: dict, tissues_results[trait_id] = tissue_result return tissues_results + + +def process_trait_symbol_dict(trait_symbol_dict, symbol_tissue_vals_dict) -> List: + """Method for processing trait symbol\ + dict given the symbol tissue values """ + traits_tissue_vals = [] + + for (trait, symbol) in trait_symbol_dict.items(): + if symbol is not None: + target_symbol = symbol.lower() + if target_symbol in symbol_tissue_vals_dict: + trait_tissue_val = symbol_tissue_vals_dict[target_symbol] + target_tissue_dict = {"trait_id": trait, + "symbol": target_symbol, + "tissue_values": trait_tissue_val} + + traits_tissue_vals.append(target_tissue_dict) + + return traits_tissue_vals diff --git a/gn3/db_utils.py b/gn3/db_utils.py new file mode 100644 index 0000000..34c5bf0 --- /dev/null +++ b/gn3/db_utils.py @@ -0,0 +1,24 @@ +"""module contains all db related stuff""" +from typing import Tuple +from urllib.parse import urlparse +import MySQLdb as mdb # type: ignore +from gn3.settings import SQL_URI + + +def parse_db_url() -> Tuple: + """function to parse SQL_URI env variable note:there\ + is a default value for SQL_URI so a tuple result is\ + always expected""" + parsed_db = urlparse(SQL_URI) + return (parsed_db.hostname, parsed_db.username, + parsed_db.password, parsed_db.path[1:]) + + +def database_connector()->Tuple: + """function to create db connector""" + host, user, passwd, db_name = parse_db_url() + conn = mdb.connect(host, user, passwd, db_name) + cursor = conn.cursor() + + return (conn, cursor) + \ No newline at end of file diff --git a/gn3/settings.py b/gn3/settings.py index 2836581..e77a977 100644 --- a/gn3/settings.py +++ b/gn3/settings.py @@ -12,6 +12,6 @@ REDIS_JOB_QUEUE = "GN3::job-queue" TMPDIR = os.environ.get("TMPDIR", tempfile.gettempdir()) # SQL confs -SQLALCHEMY_DATABASE_URI = "mysql://kabui:1234@localhost/test" +SQL_URI = os.environ.get("SQL_URI", "mysql://kabui:1234@localhost/db_webqtl") SECRET_KEY = "password" SQLALCHEMY_TRACK_MODIFICATIONS = False diff --git a/guix.scm b/guix.scm index 4f4e3b2..a810a12 100644 --- a/guix.scm +++ b/guix.scm @@ -78,7 +78,7 @@ ("python-pylint" python-pylint) ("python-redis" ,python-redis) ("python-scipy" ,python-scipy) - ("python-scipy" ,python-scipy) + ("python-mysqlclient" ,python-mysqlclient) ("python-sqlalchemy-stubs" ,python-sqlalchemy-stubs))) (build-system python-build-system) (home-page "https://github.com/genenetwork/genenetwork3") diff --git a/requirements.txt b/requirements.txt index 4586f96..e581297 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,12 @@ astroid==2.5 base58==2.1.0 bcrypt==3.1.7 +certifi==2020.12.5 cffi==1.14.5 +chardet==4.0.0 click==7.1.2 Flask==1.1.2 +idna==2.10 ipfshttpclient==0.7.0 isort==4.3.21 itsdangerous==1.1.0 @@ -14,16 +17,19 @@ mccabe==0.6.1 multiaddr==0.0.9 mypy==0.790 mypy-extensions==0.4.3 +mysqlclient==2.0.3 netaddr==0.8.0 numpy==1.20.1 pycparser==2.20 pylint==2.5.3 redis==3.5.3 +requests==2.25.1 scipy==1.6.0 six==1.15.0 toml==0.10.2 typed-ast==1.4.2 typing-extensions==3.7.4.3 +urllib3==1.26.4 varint==1.0.2 Werkzeug==1.0.1 wrapt==1.12.1 diff --git a/tests/integration/test_correlation.py b/tests/integration/test_correlation.py index 488a8a4..e67f58d 100644 --- a/tests/integration/test_correlation.py +++ b/tests/integration/test_correlation.py @@ -10,10 +10,6 @@ class CorrelationIntegrationTest(TestCase): def setUp(self): self.app = create_app().test_client() - def test_fail(self): - """initial method for class that fails""" - self.assertEqual(2, 2) - @mock.patch("gn3.api.correlation.compute_all_sample_correlation") def test_sample_r_correlation(self, mock_compute_samples): """Test /api/correlation/sample_r/{method}""" @@ -66,13 +62,17 @@ class CorrelationIntegrationTest(TestCase): self.assertEqual(response.get_json(), api_response) @mock.patch("gn3.api.correlation.compute_all_lit_correlation") - def test_lit_correlation(self, mock_compute_corr): + @mock.patch("gn3.api.correlation.database_connector") + def test_lit_correlation(self, database_connector, mock_compute_corr): """Test api/correlation/lit_corr/{species}/{gene_id}""" mock_compute_corr.return_value = [] - post_data = [{"gene_id": 8, "lit_corr": 1}, { - "gene_id": 12, "lit_corr": 0.3}] + database_connector.return_value = (mock.Mock(), mock.Mock()) + + post_data = {"1426678_at": "68031", + "1426679_at": "68036", + "1426680_at": "74777"} response = self.app.post( "/api/correlation/lit_corr/mouse/16", json=post_data, follow_redirects=True) @@ -85,13 +85,20 @@ class CorrelationIntegrationTest(TestCase): """Test api/correlation/tissue_corr/{corr_method}""" mock_tissue_corr.return_value = {} + target_trait_symbol_dict = { + "1418702_a_at": "Bxdc1", "1412_at": "Bxdc2"} + symbol_tissue_dict = { + "bxdc1": [12, 21.1, 11.4, 16.7], "bxdc2": [12, 20.1, 12.4, 1.1]} + primary_dict = {"trait_id": "1449593_at", "tissue_values": [1, 2, 3]} - target_tissue_dict_list = [ - {"trait_id": "1449593_at", "tissue_values": [1, 2, 3]}] + target_tissue_data = { + "trait_symbol_dict": target_trait_symbol_dict, + "symbol_tissue_vals_dict": symbol_tissue_dict + } tissue_corr_input_data = {"primary_tissue": primary_dict, - "target_tissues": target_tissue_dict_list} + "target_tissues_dict": target_tissue_data} response = self.app.post("/api/correlation/tissue_corr/spearman", json=tissue_corr_input_data, follow_redirects=True) diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py index 84b9330..52d1f60 100644 --- a/tests/unit/computations/test_correlation.py +++ b/tests/unit/computations/test_correlation.py @@ -18,6 +18,8 @@ from gn3.computations.correlations import query_formatter from gn3.computations.correlations import map_to_mouse_gene_id from gn3.computations.correlations import compute_all_lit_correlation from gn3.computations.correlations import compute_all_tissue_correlation +from gn3.computations.correlations import map_shared_keys_to_values +from gn3.computations.correlations import process_trait_symbol_dict class QueryableMixin: @@ -27,6 +29,10 @@ class QueryableMixin: """base method for execute""" raise NotImplementedError() + def cursor(self): + """method for creating db cursor""" + raise NotImplementedError() + def fetchone(self): """base method for fetching one iten""" raise NotImplementedError() @@ -46,37 +52,39 @@ class IllegalOperationError(Exception): class DataBase(QueryableMixin): """Class for creating db object""" - def __init__(self): + def __init__(self, expected_results=None, password="1234", db_name=None): + """expects the expectede results value to be an array""" + self.password = password + self.db_name = db_name self.__query_options = None - self.__results = None + self.results_generator(expected_results) def execute(self, query_options): """method to execute an sql query""" self.__query_options = query_options - self.results_generator() + return 1 + + def cursor(self): + """method for creating db cursor""" return self def fetchone(self): """method to fetch single item from the db query""" if self.__results is None: - raise IllegalOperationError() + return None return self.__results[0] def fetchall(self): """method for fetching all items from db query""" if self.__results is None: - raise IllegalOperationError() + return None return self.__results - def results_generator(self, expected_results=None): + def results_generator(self, expected_results): """private method for generating mock results""" - if expected_results is None: - self.__results = [namedtuple("lit_coeff", "val")(x*0.1) - for x in range(1, 4)] - else: - self.__results = expected_results + self.__results = expected_results class TestCorrelation(TestCase): @@ -236,21 +244,23 @@ class TestCorrelation(TestCase): """fetch results from db call for lit correlation given a trait list\ after doing correlation""" - target_trait_lists = [{"gene_id": 15}, - {"gene_id": 17}, - {"gene_id": 11}] + target_trait_lists = [("1426679_at", 15), + ("1426702_at", 17), + ("1426682_at", 11)] mock_mouse_gene_id.side_effect = [12, 11, 18, 16, 20] - database_instance = namedtuple("database", "execute")("fetchone") + conn = DataBase() fetch_lit_data.side_effect = [(15, 9), (17, 8), (11, 12)] lit_results = lit_correlation_for_trait_list( - database=database_instance, target_trait_lists=target_trait_lists, + conn=conn, target_trait_lists=target_trait_lists, species="rat", trait_gene_id="12") - expected_results = [{"gene_id": 15, "lit_corr": 9}, { - "gene_id": 17, "lit_corr": 8}, {"gene_id": 11, "lit_corr": 12}] + expected_results = [{"1426679_at": {"gene_id": 15, "lit_corr": 9}}, + {"1426702_at": { + "gene_id": 17, "lit_corr": 8}}, + {"1426682_at": {"gene_id": 11, "lit_corr": 12}}] self.assertEqual(lit_results, expected_results) @@ -258,8 +268,8 @@ class TestCorrelation(TestCase): """test for fetching lit correlation data from\ the database where the input and mouse geneid are none""" - database_instance = DataBase() - results = fetch_lit_correlation_data(database=database_instance, + conn = DataBase() + results = fetch_lit_correlation_data(conn=conn, gene_id="1", input_mouse_gene_id=None, mouse_gene_id=None) @@ -270,10 +280,12 @@ class TestCorrelation(TestCase): """test for fetching lit corr coefficent givent the input\ input trait mouse gene id and mouse gene id""" - database_instance = DataBase() + expected_db_results = [namedtuple("lit_coeff", "val")(x*0.1) + for x in range(1, 4)] + database_instance = DataBase(expected_results=expected_db_results) expected_results = ("1", 0.1) - lit_results = fetch_lit_correlation_data(database=database_instance, + lit_results = fetch_lit_correlation_data(conn=database_instance, gene_id="1", input_mouse_gene_id="20", mouse_gene_id="15") @@ -283,10 +295,8 @@ class TestCorrelation(TestCase): def test_query_lit_correlation_for_db_empty(self): """test that corr coeffient returned is 0 given the\ db value if corr coefficient is empty""" - database_instance = mock.Mock() - database_instance.execute.return_value.fetchone.return_value = None - - lit_results = fetch_lit_correlation_data(database=database_instance, + database_instance = DataBase() + lit_results = fetch_lit_correlation_data(conn=database_instance, input_mouse_gene_id="12", gene_id="16", mouse_gene_id="12") @@ -336,13 +346,15 @@ class TestCorrelation(TestCase): database_results = [namedtuple("mouse_id", "mouse")(val) for val in range(12, 20)] results = [] - - database_instance.execute.return_value.fetchone.side_effect = database_results + cursor = mock.Mock() + cursor.execute.return_value = 1 + cursor.fetchone.side_effect = database_results + database_instance.cursor.return_value = cursor expected_results = [12, None, 13, 14] for (species, gene_id) in test_data: mouse_gene_id_results = map_to_mouse_gene_id( - database=database_instance, species=species, gene_id=gene_id) + conn=database_instance, species=species, gene_id=gene_id) results.append(mouse_gene_id_results) self.assertEqual(results, expected_results) @@ -361,7 +373,7 @@ class TestCorrelation(TestCase): mock_lit_corr.side_effect = expected_mocked_lit_results lit_correlation_results = compute_all_lit_correlation( - database_instance=database, trait_lists=[{"gene_id": 11}], + conn=database, trait_lists=[{"gene_id": 11}], species="rat", gene_id=12) expected_results = { @@ -371,15 +383,26 @@ class TestCorrelation(TestCase): self.assertEqual(lit_correlation_results, expected_results) @mock.patch("gn3.computations.correlations.tissue_correlation_for_trait_list") - def test_compute_all_tissue_correlation(self, mock_tissue_corr): + @mock.patch("gn3.computations.correlations.process_trait_symbol_dict") + def test_compute_all_tissue_correlation(self, process_trait_symbol, mock_tissue_corr): """test for compute all tissue corelation which abstracts api calling the tissue_correlation for trait_list""" primary_tissue_dict = {"trait_id": "1419792_at", "tissue_values": [1, 2, 3, 4, 5]} - target_tissue_dict = [{"trait_id": "1418702_a_at", "tissue_values": [1, 2, 3]}, - {"trait_id": "1412_at", "tissue_values": [1, 2, 3]}] + target_tissue_dict = [{"trait_id": "1418702_a_at", + "symbol": "zf", "tissue_values": [1, 2, 3]}, + {"trait_id": "1412_at", + "symbol": "prkce", "tissue_values": [1, 2, 3]}] + + process_trait_symbol.return_value = target_tissue_dict + + target_trait_symbol = {"1418702_a_at": "Zf", "1412_at": "Prkce"} + target_symbol_tissue_vals = {"zf": [1, 2, 3], "prkce": [1, 2, 3]} + + target_tissue_data = {"trait_symbol_dict": target_trait_symbol, + "symbol_tissue_vals_dict": target_symbol_tissue_vals} mock_tissue_corr.side_effect = [{"tissue_corr": -0.5, "p_value": 0.9, "tissue_number": 3}, {"tissue_corr": 1.11, "p_value": 0.2, "tissue_number": 3}] @@ -391,9 +414,49 @@ class TestCorrelation(TestCase): results = compute_all_tissue_correlation( primary_tissue_dict=primary_tissue_dict, - target_tissues_dict_list=target_tissue_dict, + target_tissues_data=target_tissue_data, corr_method="pearson") + process_trait_symbol.assert_called_once_with( + target_trait_symbol, target_symbol_tissue_vals) self.assertEqual(mock_tissue_corr.call_count, 2) self.assertEqual(results, expected_results) + + def test_map_shared_keys_to_values(self): + """test helper function needed to integrate with genenenetwork2\ + given a a samplelist containing dataset sampelist keys\ + map that to given sample values """ + + dataset_sample_keys = ["BXD1", "BXD2", "BXD5"] + + target_dataset_data = {"HCMA:_AT": [4.1, 5.6, 3.2], + "TXD_AT": [6.2, 5.7, 3.6, ]} + + expected_results = [{"trait_id": "HCMA:_AT", + "trait_sample_data": {"BXD1": 4.1, "BXD2": 5.6, "BXD5": 3.2}}, + {"trait_id": "TXD_AT", + "trait_sample_data": {"BXD1": 6.2, "BXD2": 5.7, "BXD5": 3.6}}] + + results = map_shared_keys_to_values( + dataset_sample_keys, target_dataset_data) + + self.assertEqual(results, expected_results) + + def test_process_trait_symbol_dict(self): + """test for processing trait symbol dict\ + and fetch tissue values from tissue value dict\ + """ + trait_symbol_dict = {"1452864_at": "Igsf10"} + tissue_values_dict = {"igsf10": [8.9615, 10.6375, 9.2795, 8.6605]} + + expected_results = { + "trait_id": "1452864_at", + "symbol": "igsf10", + "tissue_values": [8.9615, 10.6375, 9.2795, 8.6605] + } + + results = process_trait_symbol_dict( + trait_symbol_dict, tissue_values_dict) + + self.assertEqual(results, [expected_results]) diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py new file mode 100644 index 0000000..0f2de9e --- /dev/null +++ b/tests/unit/test_db_utils.py @@ -0,0 +1,37 @@ +"""module contains test for db_utils""" + +from unittest import TestCase +from unittest import mock + +from types import SimpleNamespace + +from gn3.db_utils import database_connector +from gn3.db_utils import parse_db_url + + +class TestDatabase(TestCase): + """class contains testd for db connection functions""" + + @mock.patch("gn3.db_utils.mdb") + @mock.patch("gn3.db_utils.parse_db_url") + def test_database_connector(self, mock_db_parser, mock_sql): + """test for creating database connection""" + mock_db_parser.return_value = ("localhost", "guest", "4321", "users") + callable_cursor = lambda: SimpleNamespace(execute=3) + cursor_object = SimpleNamespace(cursor=callable_cursor) + mock_sql.connect.return_value = cursor_object + mock_sql.close.return_value = None + results = database_connector() + + mock_sql.connect.assert_called_with( + "localhost", "guest", "4321", "users") + self.assertIsInstance( + results, tuple, "database not created successfully") + + @mock.patch("gn3.db_utils.SQL_URI", + "mysql://username:4321@localhost/test") + def test_parse_db_url(self): + """test for parsing db_uri env variable""" + results = parse_db_url() + expected_results = ("localhost", "username", "4321", "test") + self.assertEqual(results, expected_results) -- cgit v1.2.3