aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gn3/api/correlation.py39
-rw-r--r--gn3/computations/correlations.py104
-rw-r--r--gn3/db_utils.py24
-rw-r--r--gn3/settings.py2
-rw-r--r--guix.scm2
-rw-r--r--requirements.txt6
-rw-r--r--tests/integration/test_correlation.py27
-rw-r--r--tests/unit/computations/test_correlation.py131
-rw-r--r--tests/unit/test_db_utils.py37
9 files changed, 290 insertions, 82 deletions
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index 53ea6a7..2339088 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -1,6 +1,4 @@
"""Endpoints for running correlations"""
-from unittest import mock
-
from flask import jsonify
from flask import Blueprint
from flask import request
@@ -8,11 +6,31 @@ from flask import request
from gn3.computations.correlations import compute_all_sample_correlation
from gn3.computations.correlations import compute_all_lit_correlation
from gn3.computations.correlations import compute_all_tissue_correlation
-
+from gn3.computations.correlations import map_shared_keys_to_values
+from gn3.db_utils import database_connector
correlation = Blueprint("correlation", __name__)
+@correlation.route("/sample_x/<string:corr_method>", methods=["POST"])
+def compute_sample_integration(corr_method="pearson"):
+ """temporary api to help integrate genenetwork2 to genenetwork3 """
+
+ correlation_input = request.get_json()
+
+ target_samplelist = correlation_input.get("target_samplelist")
+ target_data_values = correlation_input.get("target_dataset")
+ this_trait_data = correlation_input.get("trait_data")
+
+ results = map_shared_keys_to_values(target_samplelist, target_data_values)
+
+ correlation_results = compute_all_sample_correlation(corr_method=corr_method,
+ this_trait=this_trait_data,
+ target_dataset=results)
+
+ return jsonify(correlation_results)
+
+
@correlation.route("/sample_r/<string:corr_method>", methods=["POST"])
def compute_sample_r(corr_method="pearson"):
"""correlation endpoint for computing sample r correlations\
@@ -22,11 +40,11 @@ def compute_sample_r(corr_method="pearson"):
# xtodo move code below to compute_all_sampl correlation
this_trait_data = correlation_input.get("this_trait")
- target_datasets = correlation_input.get("target_dataset")
+ target_dataset_data = correlation_input.get("target_dataset")
correlation_results = compute_all_sample_correlation(corr_method=corr_method,
this_trait=this_trait_data,
- target_dataset=target_datasets)
+ target_dataset=target_dataset_data)
return jsonify({
"corr_results": correlation_results
@@ -39,13 +57,16 @@ def compute_lit_corr(species=None, gene_id=None):
are fetched from the database this is the only case where the db\
might be needed for actual computing of the correlation results"""
- database_instance = mock.Mock()
+ conn, _cursor_object = database_connector()
target_traits_gene_ids = request.get_json()
+ target_trait_gene_list = list(target_traits_gene_ids.items())
lit_corr_results = compute_all_lit_correlation(
- database_instance=database_instance, trait_lists=target_traits_gene_ids,
+ conn=conn, trait_lists=target_trait_gene_list,
species=species, gene_id=gene_id)
+ conn.close()
+
return jsonify(lit_corr_results)
@@ -54,10 +75,10 @@ def compute_tissue_corr(corr_method="pearson"):
"""api endpoint fr doing tissue correlation"""
tissue_input_data = request.get_json()
primary_tissue_dict = tissue_input_data["primary_tissue"]
- target_tissues_dict_list = tissue_input_data["target_tissues"]
+ target_tissues_dict = tissue_input_data["target_tissues_dict"]
results = compute_all_tissue_correlation(primary_tissue_dict=primary_tissue_dict,
- target_tissues_dict_list=target_tissues_dict_list,
+ target_tissues_data=target_tissues_dict,
corr_method=corr_method)
return jsonify(results)
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
index dc2f8d3..26b7294 100644
--- a/gn3/computations/correlations.py
+++ b/gn3/computations/correlations.py
@@ -12,10 +12,30 @@ def compute_sum(rhs: int, lhs: int) -> int:
return rhs + lhs
+def map_shared_keys_to_values(target_sample_keys: List, target_sample_vals: dict)-> List:
+ """Function to construct target dataset data items given commoned shared\
+ keys and trait samplelist values for example given keys >>>>>>>>>>\
+ ["BXD1", "BXD2", "BXD5", "BXD6", "BXD8", "BXD9"] and value object as\
+ "HCMA:_AT": [4.1, 5.6, 3.2, 1.1, 4.4, 2.2],TXD_AT": [6.2, 5.7, 3.6, 1.5, 4.2, 2.3]}\
+ return results should be a list of dicts mapping the shared keys to the trait values"""
+ target_dataset_data = []
+
+ for trait_id, sample_values in target_sample_vals.items():
+ target_trait_dict = dict(zip(target_sample_keys, sample_values))
+
+ target_trait = {
+ "trait_id": trait_id,
+ "trait_sample_data": target_trait_dict
+ }
+
+ target_dataset_data.append(target_trait)
+
+ return target_dataset_data
+
+
def normalize_values(a_values: List,
b_values: List) -> Tuple[List[float], List[float], int]:
"""Trim two lists of values to contain only the values they both share
-
Given two lists of sample values, trim each list so that it contains only
the samples that contain a value in both lists. Also returns the number of
such samples.
@@ -175,7 +195,7 @@ def tissue_correlation_for_trait_list(
"""
- # ax :todo assertion that lenggth one one target tissue ==primary_tissue
+ # ax :todo assertion that length one one target tissue ==primary_tissue
(tissue_corr_coeffient,
p_value) = compute_corr_p_value(primary_values=primary_tissue_vals,
@@ -192,11 +212,11 @@ def tissue_correlation_for_trait_list(
def fetch_lit_correlation_data(
- database,
+ conn,
input_mouse_gene_id: Optional[str],
gene_id: str,
mouse_gene_id: Optional[str] = None) -> Tuple[str, float]:
- """given input trait mouse gene id and mouse gene id fetch the lit\
+ """Given input trait mouse gene id and mouse gene id fetch the lit\
corr_data"""
if mouse_gene_id is not None and ";" not in mouse_gene_id:
query = """
@@ -208,15 +228,19 @@ def fetch_lit_correlation_data(
query_values = (str(mouse_gene_id), str(input_mouse_gene_id))
- results = database.execute(query_formatter(query,
- *query_values)).fetchone()
+ cursor = conn.cursor()
+
+ cursor.execute(query_formatter(query,
+ *query_values))
+ results = cursor.fetchone()
lit_corr_results = None
if results is not None:
lit_corr_results = results
else:
- lit_corr_results = database.execute(
- query_formatter(query,
- *tuple(reversed(query_values)))).fetchone()
+ cursor = conn.cursor()
+ cursor.execute(query_formatter(query,
+ *tuple(reversed(query_values))))
+ lit_corr_results = cursor.fetchone()
lit_results = (gene_id, lit_corr_results.val)\
if lit_corr_results else (gene_id, 0)
return lit_results
@@ -225,7 +249,7 @@ def fetch_lit_correlation_data(
def lit_correlation_for_trait_list(
- database,
+ conn,
target_trait_lists: List,
species: Optional[str] = None,
trait_gene_id: Optional[str] = None) -> List:
@@ -233,43 +257,45 @@ def lit_correlation_for_trait_list(
output is float for lit corr results """
fetched_lit_corr_results = []
- this_trait_mouse_gene_id = map_to_mouse_gene_id(database=database,
+ this_trait_mouse_gene_id = map_to_mouse_gene_id(conn=conn,
species=species,
gene_id=trait_gene_id)
- for trait in target_trait_lists:
- target_trait_gene_id = trait.get("gene_id")
+ for (trait_name, target_trait_gene_id) in target_trait_lists:
+ corr_results = {}
if target_trait_gene_id:
target_mouse_gene_id = map_to_mouse_gene_id(
- database=database,
+ conn=conn,
species=species,
gene_id=target_trait_gene_id)
fetched_corr_data = fetch_lit_correlation_data(
- database=database,
+ conn=conn,
input_mouse_gene_id=this_trait_mouse_gene_id,
gene_id=target_trait_gene_id,
mouse_gene_id=target_mouse_gene_id)
dict_results = dict(zip(("gene_id", "lit_corr"),
fetched_corr_data))
- fetched_lit_corr_results.append(dict_results)
+ corr_results[trait_name] = dict_results
+ fetched_lit_corr_results.append(corr_results)
return fetched_lit_corr_results
def query_formatter(query_string: str, *query_values):
- """formatter query string given the unformatted query string\
+ """Formatter query string given the unformatted query string\
and the respectibe values.Assumes number of placeholders is
equal to the number of query values """
+ # xtodo escape sql queries
results = query_string % (query_values)
return results
-def map_to_mouse_gene_id(database, species: Optional[str],
+def map_to_mouse_gene_id(conn, species: Optional[str],
gene_id: Optional[str]) -> Optional[str]:
- """given a species which is not mouse map the gene_id\
+ """Given a species which is not mouse map the gene_id\
to respective mouse gene id"""
# AK:xtodo move the code for checking nullity out of thing functions bug
# while method for string
@@ -278,28 +304,29 @@ def map_to_mouse_gene_id(database, species: Optional[str],
if species == "mouse":
return gene_id
+ cursor = conn.cursor()
query = """SELECT mouse
FROM GeneIDXRef
WHERE '%s' = '%s'"""
query_values = (species, gene_id)
-
- results = database.execute(query_formatter(query,
- *query_values)).fetchone()
+ cursor.execute(query_formatter(query,
+ *query_values))
+ results = cursor.fetchone()
mouse_gene_id = results.mouse if results is not None else None
return mouse_gene_id
-def compute_all_lit_correlation(database_instance, trait_lists: List,
+def compute_all_lit_correlation(conn, trait_lists: List,
species: str, gene_id):
"""Function that acts as an abstraction for
lit_correlation_for_trait_list"""
# xtodo to be refactored
lit_results = lit_correlation_for_trait_list(
- database=database_instance,
+ conn=conn,
target_trait_lists=trait_lists,
species=species,
trait_gene_id=gene_id)
@@ -308,18 +335,22 @@ def compute_all_lit_correlation(database_instance, trait_lists: List,
def compute_all_tissue_correlation(primary_tissue_dict: dict,
- target_tissues_dict_list: List,
+ target_tissues_data: dict,
corr_method: str):
"""Function acts as an abstraction for tissue_correlation_for_trait_list\
- required input are target tissue object and primary tissue trait
+ required input are target tissue object and primary tissue trait\
+ target tissues data contains the trait_symbol_dict and symbol_tissue_vals
"""
tissues_results = {}
primary_tissue_vals = primary_tissue_dict["tissue_values"]
+ traits_symbol_dict = target_tissues_data["trait_symbol_dict"]
+ symbol_tissue_vals_dict = target_tissues_data["symbol_tissue_vals_dict"]
- target_tissues_list = target_tissues_dict_list
+ target_tissues_list = process_trait_symbol_dict(
+ traits_symbol_dict, symbol_tissue_vals_dict)
for target_tissue_obj in target_tissues_list:
trait_id = target_tissue_obj.get("trait_id")
@@ -334,3 +365,22 @@ def compute_all_tissue_correlation(primary_tissue_dict: dict,
tissues_results[trait_id] = tissue_result
return tissues_results
+
+
+def process_trait_symbol_dict(trait_symbol_dict, symbol_tissue_vals_dict) -> List:
+ """Method for processing trait symbol\
+ dict given the symbol tissue values """
+ traits_tissue_vals = []
+
+ for (trait, symbol) in trait_symbol_dict.items():
+ if symbol is not None:
+ target_symbol = symbol.lower()
+ if target_symbol in symbol_tissue_vals_dict:
+ trait_tissue_val = symbol_tissue_vals_dict[target_symbol]
+ target_tissue_dict = {"trait_id": trait,
+ "symbol": target_symbol,
+ "tissue_values": trait_tissue_val}
+
+ traits_tissue_vals.append(target_tissue_dict)
+
+ return traits_tissue_vals
diff --git a/gn3/db_utils.py b/gn3/db_utils.py
new file mode 100644
index 0000000..34c5bf0
--- /dev/null
+++ b/gn3/db_utils.py
@@ -0,0 +1,24 @@
+"""module contains all db related stuff"""
+from typing import Tuple
+from urllib.parse import urlparse
+import MySQLdb as mdb # type: ignore
+from gn3.settings import SQL_URI
+
+
+def parse_db_url() -> Tuple:
+ """function to parse SQL_URI env variable note:there\
+ is a default value for SQL_URI so a tuple result is\
+ always expected"""
+ parsed_db = urlparse(SQL_URI)
+ return (parsed_db.hostname, parsed_db.username,
+ parsed_db.password, parsed_db.path[1:])
+
+
+def database_connector()->Tuple:
+ """function to create db connector"""
+ host, user, passwd, db_name = parse_db_url()
+ conn = mdb.connect(host, user, passwd, db_name)
+ cursor = conn.cursor()
+
+ return (conn, cursor)
+ \ No newline at end of file
diff --git a/gn3/settings.py b/gn3/settings.py
index 2836581..e77a977 100644
--- a/gn3/settings.py
+++ b/gn3/settings.py
@@ -12,6 +12,6 @@ REDIS_JOB_QUEUE = "GN3::job-queue"
TMPDIR = os.environ.get("TMPDIR", tempfile.gettempdir())
# SQL confs
-SQLALCHEMY_DATABASE_URI = "mysql://kabui:1234@localhost/test"
+SQL_URI = os.environ.get("SQL_URI", "mysql://kabui:1234@localhost/db_webqtl")
SECRET_KEY = "password"
SQLALCHEMY_TRACK_MODIFICATIONS = False
diff --git a/guix.scm b/guix.scm
index 4f4e3b2..a810a12 100644
--- a/guix.scm
+++ b/guix.scm
@@ -78,7 +78,7 @@
("python-pylint" python-pylint)
("python-redis" ,python-redis)
("python-scipy" ,python-scipy)
- ("python-scipy" ,python-scipy)
+ ("python-mysqlclient" ,python-mysqlclient)
("python-sqlalchemy-stubs" ,python-sqlalchemy-stubs)))
(build-system python-build-system)
(home-page "https://github.com/genenetwork/genenetwork3")
diff --git a/requirements.txt b/requirements.txt
index 4586f96..e581297 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,12 @@
astroid==2.5
base58==2.1.0
bcrypt==3.1.7
+certifi==2020.12.5
cffi==1.14.5
+chardet==4.0.0
click==7.1.2
Flask==1.1.2
+idna==2.10
ipfshttpclient==0.7.0
isort==4.3.21
itsdangerous==1.1.0
@@ -14,16 +17,19 @@ mccabe==0.6.1
multiaddr==0.0.9
mypy==0.790
mypy-extensions==0.4.3
+mysqlclient==2.0.3
netaddr==0.8.0
numpy==1.20.1
pycparser==2.20
pylint==2.5.3
redis==3.5.3
+requests==2.25.1
scipy==1.6.0
six==1.15.0
toml==0.10.2
typed-ast==1.4.2
typing-extensions==3.7.4.3
+urllib3==1.26.4
varint==1.0.2
Werkzeug==1.0.1
wrapt==1.12.1
diff --git a/tests/integration/test_correlation.py b/tests/integration/test_correlation.py
index 488a8a4..e67f58d 100644
--- a/tests/integration/test_correlation.py
+++ b/tests/integration/test_correlation.py
@@ -10,10 +10,6 @@ class CorrelationIntegrationTest(TestCase):
def setUp(self):
self.app = create_app().test_client()
- def test_fail(self):
- """initial method for class that fails"""
- self.assertEqual(2, 2)
-
@mock.patch("gn3.api.correlation.compute_all_sample_correlation")
def test_sample_r_correlation(self, mock_compute_samples):
"""Test /api/correlation/sample_r/{method}"""
@@ -66,13 +62,17 @@ class CorrelationIntegrationTest(TestCase):
self.assertEqual(response.get_json(), api_response)
@mock.patch("gn3.api.correlation.compute_all_lit_correlation")
- def test_lit_correlation(self, mock_compute_corr):
+ @mock.patch("gn3.api.correlation.database_connector")
+ def test_lit_correlation(self, database_connector, mock_compute_corr):
"""Test api/correlation/lit_corr/{species}/{gene_id}"""
mock_compute_corr.return_value = []
- post_data = [{"gene_id": 8, "lit_corr": 1}, {
- "gene_id": 12, "lit_corr": 0.3}]
+ database_connector.return_value = (mock.Mock(), mock.Mock())
+
+ post_data = {"1426678_at": "68031",
+ "1426679_at": "68036",
+ "1426680_at": "74777"}
response = self.app.post(
"/api/correlation/lit_corr/mouse/16", json=post_data, follow_redirects=True)
@@ -85,13 +85,20 @@ class CorrelationIntegrationTest(TestCase):
"""Test api/correlation/tissue_corr/{corr_method}"""
mock_tissue_corr.return_value = {}
+ target_trait_symbol_dict = {
+ "1418702_a_at": "Bxdc1", "1412_at": "Bxdc2"}
+ symbol_tissue_dict = {
+ "bxdc1": [12, 21.1, 11.4, 16.7], "bxdc2": [12, 20.1, 12.4, 1.1]}
+
primary_dict = {"trait_id": "1449593_at", "tissue_values": [1, 2, 3]}
- target_tissue_dict_list = [
- {"trait_id": "1449593_at", "tissue_values": [1, 2, 3]}]
+ target_tissue_data = {
+ "trait_symbol_dict": target_trait_symbol_dict,
+ "symbol_tissue_vals_dict": symbol_tissue_dict
+ }
tissue_corr_input_data = {"primary_tissue": primary_dict,
- "target_tissues": target_tissue_dict_list}
+ "target_tissues_dict": target_tissue_data}
response = self.app.post("/api/correlation/tissue_corr/spearman",
json=tissue_corr_input_data, follow_redirects=True)
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py
index 84b9330..52d1f60 100644
--- a/tests/unit/computations/test_correlation.py
+++ b/tests/unit/computations/test_correlation.py
@@ -18,6 +18,8 @@ from gn3.computations.correlations import query_formatter
from gn3.computations.correlations import map_to_mouse_gene_id
from gn3.computations.correlations import compute_all_lit_correlation
from gn3.computations.correlations import compute_all_tissue_correlation
+from gn3.computations.correlations import map_shared_keys_to_values
+from gn3.computations.correlations import process_trait_symbol_dict
class QueryableMixin:
@@ -27,6 +29,10 @@ class QueryableMixin:
"""base method for execute"""
raise NotImplementedError()
+ def cursor(self):
+ """method for creating db cursor"""
+ raise NotImplementedError()
+
def fetchone(self):
"""base method for fetching one iten"""
raise NotImplementedError()
@@ -46,37 +52,39 @@ class IllegalOperationError(Exception):
class DataBase(QueryableMixin):
"""Class for creating db object"""
- def __init__(self):
+ def __init__(self, expected_results=None, password="1234", db_name=None):
+ """expects the expectede results value to be an array"""
+ self.password = password
+ self.db_name = db_name
self.__query_options = None
- self.__results = None
+ self.results_generator(expected_results)
def execute(self, query_options):
"""method to execute an sql query"""
self.__query_options = query_options
- self.results_generator()
+ return 1
+
+ def cursor(self):
+ """method for creating db cursor"""
return self
def fetchone(self):
"""method to fetch single item from the db query"""
if self.__results is None:
- raise IllegalOperationError()
+ return None
return self.__results[0]
def fetchall(self):
"""method for fetching all items from db query"""
if self.__results is None:
- raise IllegalOperationError()
+ return None
return self.__results
- def results_generator(self, expected_results=None):
+ def results_generator(self, expected_results):
"""private method for generating mock results"""
- if expected_results is None:
- self.__results = [namedtuple("lit_coeff", "val")(x*0.1)
- for x in range(1, 4)]
- else:
- self.__results = expected_results
+ self.__results = expected_results
class TestCorrelation(TestCase):
@@ -236,21 +244,23 @@ class TestCorrelation(TestCase):
"""fetch results from db call for lit correlation given a trait list\
after doing correlation"""
- target_trait_lists = [{"gene_id": 15},
- {"gene_id": 17},
- {"gene_id": 11}]
+ target_trait_lists = [("1426679_at", 15),
+ ("1426702_at", 17),
+ ("1426682_at", 11)]
mock_mouse_gene_id.side_effect = [12, 11, 18, 16, 20]
- database_instance = namedtuple("database", "execute")("fetchone")
+ conn = DataBase()
fetch_lit_data.side_effect = [(15, 9), (17, 8), (11, 12)]
lit_results = lit_correlation_for_trait_list(
- database=database_instance, target_trait_lists=target_trait_lists,
+ conn=conn, target_trait_lists=target_trait_lists,
species="rat", trait_gene_id="12")
- expected_results = [{"gene_id": 15, "lit_corr": 9}, {
- "gene_id": 17, "lit_corr": 8}, {"gene_id": 11, "lit_corr": 12}]
+ expected_results = [{"1426679_at": {"gene_id": 15, "lit_corr": 9}},
+ {"1426702_at": {
+ "gene_id": 17, "lit_corr": 8}},
+ {"1426682_at": {"gene_id": 11, "lit_corr": 12}}]
self.assertEqual(lit_results, expected_results)
@@ -258,8 +268,8 @@ class TestCorrelation(TestCase):
"""test for fetching lit correlation data from\
the database where the input and mouse geneid are none"""
- database_instance = DataBase()
- results = fetch_lit_correlation_data(database=database_instance,
+ conn = DataBase()
+ results = fetch_lit_correlation_data(conn=conn,
gene_id="1",
input_mouse_gene_id=None,
mouse_gene_id=None)
@@ -270,10 +280,12 @@ class TestCorrelation(TestCase):
"""test for fetching lit corr coefficent givent the input\
input trait mouse gene id and mouse gene id"""
- database_instance = DataBase()
+ expected_db_results = [namedtuple("lit_coeff", "val")(x*0.1)
+ for x in range(1, 4)]
+ database_instance = DataBase(expected_results=expected_db_results)
expected_results = ("1", 0.1)
- lit_results = fetch_lit_correlation_data(database=database_instance,
+ lit_results = fetch_lit_correlation_data(conn=database_instance,
gene_id="1",
input_mouse_gene_id="20",
mouse_gene_id="15")
@@ -283,10 +295,8 @@ class TestCorrelation(TestCase):
def test_query_lit_correlation_for_db_empty(self):
"""test that corr coeffient returned is 0 given the\
db value if corr coefficient is empty"""
- database_instance = mock.Mock()
- database_instance.execute.return_value.fetchone.return_value = None
-
- lit_results = fetch_lit_correlation_data(database=database_instance,
+ database_instance = DataBase()
+ lit_results = fetch_lit_correlation_data(conn=database_instance,
input_mouse_gene_id="12",
gene_id="16",
mouse_gene_id="12")
@@ -336,13 +346,15 @@ class TestCorrelation(TestCase):
database_results = [namedtuple("mouse_id", "mouse")(val)
for val in range(12, 20)]
results = []
-
- database_instance.execute.return_value.fetchone.side_effect = database_results
+ cursor = mock.Mock()
+ cursor.execute.return_value = 1
+ cursor.fetchone.side_effect = database_results
+ database_instance.cursor.return_value = cursor
expected_results = [12, None, 13, 14]
for (species, gene_id) in test_data:
mouse_gene_id_results = map_to_mouse_gene_id(
- database=database_instance, species=species, gene_id=gene_id)
+ conn=database_instance, species=species, gene_id=gene_id)
results.append(mouse_gene_id_results)
self.assertEqual(results, expected_results)
@@ -361,7 +373,7 @@ class TestCorrelation(TestCase):
mock_lit_corr.side_effect = expected_mocked_lit_results
lit_correlation_results = compute_all_lit_correlation(
- database_instance=database, trait_lists=[{"gene_id": 11}],
+ conn=database, trait_lists=[{"gene_id": 11}],
species="rat", gene_id=12)
expected_results = {
@@ -371,15 +383,26 @@ class TestCorrelation(TestCase):
self.assertEqual(lit_correlation_results, expected_results)
@mock.patch("gn3.computations.correlations.tissue_correlation_for_trait_list")
- def test_compute_all_tissue_correlation(self, mock_tissue_corr):
+ @mock.patch("gn3.computations.correlations.process_trait_symbol_dict")
+ def test_compute_all_tissue_correlation(self, process_trait_symbol, mock_tissue_corr):
"""test for compute all tissue corelation which abstracts
api calling the tissue_correlation for trait_list"""
primary_tissue_dict = {"trait_id": "1419792_at",
"tissue_values": [1, 2, 3, 4, 5]}
- target_tissue_dict = [{"trait_id": "1418702_a_at", "tissue_values": [1, 2, 3]},
- {"trait_id": "1412_at", "tissue_values": [1, 2, 3]}]
+ target_tissue_dict = [{"trait_id": "1418702_a_at",
+ "symbol": "zf", "tissue_values": [1, 2, 3]},
+ {"trait_id": "1412_at",
+ "symbol": "prkce", "tissue_values": [1, 2, 3]}]
+
+ process_trait_symbol.return_value = target_tissue_dict
+
+ target_trait_symbol = {"1418702_a_at": "Zf", "1412_at": "Prkce"}
+ target_symbol_tissue_vals = {"zf": [1, 2, 3], "prkce": [1, 2, 3]}
+
+ target_tissue_data = {"trait_symbol_dict": target_trait_symbol,
+ "symbol_tissue_vals_dict": target_symbol_tissue_vals}
mock_tissue_corr.side_effect = [{"tissue_corr": -0.5, "p_value": 0.9, "tissue_number": 3},
{"tissue_corr": 1.11, "p_value": 0.2, "tissue_number": 3}]
@@ -391,9 +414,49 @@ class TestCorrelation(TestCase):
results = compute_all_tissue_correlation(
primary_tissue_dict=primary_tissue_dict,
- target_tissues_dict_list=target_tissue_dict,
+ target_tissues_data=target_tissue_data,
corr_method="pearson")
+ process_trait_symbol.assert_called_once_with(
+ target_trait_symbol, target_symbol_tissue_vals)
self.assertEqual(mock_tissue_corr.call_count, 2)
self.assertEqual(results, expected_results)
+
+ def test_map_shared_keys_to_values(self):
+ """test helper function needed to integrate with genenenetwork2\
+ given a a samplelist containing dataset sampelist keys\
+ map that to given sample values """
+
+ dataset_sample_keys = ["BXD1", "BXD2", "BXD5"]
+
+ target_dataset_data = {"HCMA:_AT": [4.1, 5.6, 3.2],
+ "TXD_AT": [6.2, 5.7, 3.6, ]}
+
+ expected_results = [{"trait_id": "HCMA:_AT",
+ "trait_sample_data": {"BXD1": 4.1, "BXD2": 5.6, "BXD5": 3.2}},
+ {"trait_id": "TXD_AT",
+ "trait_sample_data": {"BXD1": 6.2, "BXD2": 5.7, "BXD5": 3.6}}]
+
+ results = map_shared_keys_to_values(
+ dataset_sample_keys, target_dataset_data)
+
+ self.assertEqual(results, expected_results)
+
+ def test_process_trait_symbol_dict(self):
+ """test for processing trait symbol dict\
+ and fetch tissue values from tissue value dict\
+ """
+ trait_symbol_dict = {"1452864_at": "Igsf10"}
+ tissue_values_dict = {"igsf10": [8.9615, 10.6375, 9.2795, 8.6605]}
+
+ expected_results = {
+ "trait_id": "1452864_at",
+ "symbol": "igsf10",
+ "tissue_values": [8.9615, 10.6375, 9.2795, 8.6605]
+ }
+
+ results = process_trait_symbol_dict(
+ trait_symbol_dict, tissue_values_dict)
+
+ self.assertEqual(results, [expected_results])
diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py
new file mode 100644
index 0000000..0f2de9e
--- /dev/null
+++ b/tests/unit/test_db_utils.py
@@ -0,0 +1,37 @@
+"""module contains test for db_utils"""
+
+from unittest import TestCase
+from unittest import mock
+
+from types import SimpleNamespace
+
+from gn3.db_utils import database_connector
+from gn3.db_utils import parse_db_url
+
+
+class TestDatabase(TestCase):
+ """class contains testd for db connection functions"""
+
+ @mock.patch("gn3.db_utils.mdb")
+ @mock.patch("gn3.db_utils.parse_db_url")
+ def test_database_connector(self, mock_db_parser, mock_sql):
+ """test for creating database connection"""
+ mock_db_parser.return_value = ("localhost", "guest", "4321", "users")
+ callable_cursor = lambda: SimpleNamespace(execute=3)
+ cursor_object = SimpleNamespace(cursor=callable_cursor)
+ mock_sql.connect.return_value = cursor_object
+ mock_sql.close.return_value = None
+ results = database_connector()
+
+ mock_sql.connect.assert_called_with(
+ "localhost", "guest", "4321", "users")
+ self.assertIsInstance(
+ results, tuple, "database not created successfully")
+
+ @mock.patch("gn3.db_utils.SQL_URI",
+ "mysql://username:4321@localhost/test")
+ def test_parse_db_url(self):
+ """test for parsing db_uri env variable"""
+ results = parse_db_url()
+ expected_results = ("localhost", "username", "4321", "test")
+ self.assertEqual(results, expected_results)