"""Module contains the tests for correlation"""
from unittest import TestCase
from unittest import mock
from collections import namedtuple
import pytest
from numpy.testing import assert_almost_equal
from gn3.computations.correlations import normalize_values
from gn3.computations.correlations import compute_sample_r_correlation
from gn3.computations.correlations import compute_one_sample_correlation
from gn3.computations.correlations import filter_shared_sample_keys
from gn3.computations.correlations import tissue_correlation_for_trait
from gn3.computations.correlations import lit_correlation_for_trait
from gn3.computations.correlations import fetch_lit_correlation_data
from gn3.computations.correlations import query_formatter
from gn3.computations.correlations import map_to_mouse_gene_id
from gn3.computations.correlations import compute_all_lit_correlation
from gn3.computations.correlations import compute_tissue_correlation
from gn3.computations.correlations import map_shared_keys_to_values
from gn3.computations.correlations import process_trait_symbol_dict
from gn3.computations.correlations2 import compute_correlation
class QueryableMixin:
    """Interface for the database stand-ins used by these tests.

    Every method here raises ``NotImplementedError``; concrete subclasses
    are expected to override the ones they need.
    """

    def cursor(self):
        """Create a db cursor (not implemented on the base class)."""
        raise NotImplementedError

    def execute(self, query_options):
        """Execute a query (not implemented on the base class)."""
        raise NotImplementedError

    def fetchone(self):
        """Fetch a single item (not implemented on the base class)."""
        raise NotImplementedError

    def fetchall(self):
        """Fetch all items (not implemented on the base class)."""
        raise NotImplementedError
class IllegalOperationError(Exception):
    """Error signalling that an illegal db operation was attempted."""

    def __init__(self):
        message = "Operation not permitted!"
        super().__init__(message)
class DataBase(QueryableMixin):
    """In-memory stand-in for a database connection.

    The object acts as its own cursor and, instead of running queries,
    hands back the canned ``expected_results`` it was given.
    """

    def __init__(self, expected_results=None, password="1234", db_name=None):
        """Store the connection settings and the canned results.

        ``expected_results`` is expected to be a sequence of rows, or
        ``None`` when the fake database has nothing to return.
        """
        self.password = password
        self.db_name = db_name
        self.__query_options = None  # pylint: disable=[W0238]
        # Declare the attribute here so it always exists, then populate it
        # through the public setter.
        self.__results = None
        self.results_generator(expected_results)

    def execute(self, query_options):
        """Record the query options and return ``1``; no query is run."""
        self.__query_options = query_options  # pylint: disable=[W0238]
        return 1

    def cursor(self):
        """Return this object, which doubles as its own cursor."""
        return self

    def fetchone(self):
        """Return the first canned row, or ``None`` when there is none."""
        # Guard against an empty result set as well as ``None``: indexing
        # an empty sequence would raise IndexError, whereas DB-API cursors
        # return ``None`` when no row is available.
        if self.__results is None or len(self.__results) == 0:
            return None
        return self.__results[0]

    def fetchall(self):
        """Return all canned rows (``None`` if no results were set)."""
        return self.__results

    def results_generator(self, expected_results):
        """Set the canned results that the fetch methods hand back."""
        self.__results = expected_results
class TestCorrelation(TestCase):
    """Unit tests for the functions in ``gn3.computations.correlations``.

    The last test also covers ``compute_correlation`` from
    ``gn3.computations.correlations2``.
    """

    @pytest.mark.unit_test
    def test_normalize_values(self):
        """Test ``normalize_values`` on inputs that contain ``None`` entries.

        The fixtures expect only the positions where both inputs hold a
        value to survive.
        """
        # Each entry is [a_values, b_values, expected], where ``expected`` is
        # the result of zipping the pairs that ``normalize_values`` yields.
        test_data = [
            [[2.3, None, None, 3.2, 4.1, 5], [3.4, 7.2, 1.3, None, 6.2, 4.1],
             [(2.3, 4.1, 5), (3.4, 6.2, 4.1)]],
            [[2.3, None, 1.3, None], [None, None, None, 1.2], []],
            [[], [], []]
        ]
        for a_values, b_values, expected_result in test_data:
            with self.subTest(a_values=a_values, b_values=b_values):
                results = normalize_values(a_values, b_values)
                self.assertEqual(list(zip(*list(results))), expected_result)

    @pytest.mark.unit_test
    @mock.patch("gn3.computations.correlations.compute_corr_coeff_p_value")
    @mock.patch("gn3.computations.correlations.normalize_values")
    def test_compute_sample_r_correlation(self, norm_vals, compute_corr):
        """Test the tuple returned by ``compute_sample_r_correlation``.

        Both the normalisation step and the coefficient computation are
        patched, so only the assembling of the result is checked.
        """
        # ``mock.patch`` decorators are applied bottom-up, so the first mock
        # argument (``norm_vals``) is the ``normalize_values`` patch.
        primary_values = [2.3, 4.1, 5, 4.2, None, None, 4, 1.2, 1.1]
        target_values = [3.4, 6.2, 4, 1.1, 1.2, None, 8, 1.1, 2.1]
        # Seven normalised pairs; the expected result below ends with 7.
        norm_vals.return_value = iter(
            [(2.3, 3.4), (4.1, 6.2), (5, 4), (4.2, 1.1), (4, 8), (1.2, 1.1), (1.1, 2.1)])
        compute_corr.return_value = (0.8, 0.21)
        bicor_results = compute_sample_r_correlation(trait_name="1412_at",
                                                     corr_method="bicor",
                                                     trait_vals=primary_values,
                                                     target_samples_vals=target_values)
        self.assertEqual(bicor_results, ("1412_at", 0.8, 0.21, 7))

    @pytest.mark.unit_test
    def test_filter_shared_sample_keys(self):
        """Test that only values for keys present in both dicts are kept."""
        this_samplelist = {
            "C57BL/6J": "6.638",
            "DBA/2J": "6.266",
            "B6D2F1": "6.494",
            "D2B6F1": "6.565",
            "BXD2": "6.456"
        }
        target_samplelist = {
            "DBA/2J": "1.23",
            "D2B6F1": "6.565",
            "BXD2": "6.456"
        }
        filtered_target_samplelist = ("1.23", "6.565", "6.456")
        filtered_this_samplelist = ("6.266", "6.565", "6.456")
        results = filter_shared_sample_keys(
            this_samplelist=this_samplelist, target_samplelist=target_samplelist)
        # Zipping the yielded pairs gives one tuple per sample list.
        self.assertEqual(list(zip(*list(results))), [filtered_this_samplelist,
                                                     filtered_target_samplelist])

    @pytest.mark.unit_test
    @mock.patch("gn3.computations.correlations.compute_sample_r_correlation")
    @mock.patch("gn3.computations.correlations.filter_shared_sample_keys")
    def test_compute_one_sample(self, filter_shared_samples, sample_r_corr):
        """Test ``compute_one_sample_correlation`` for one target trait.

        The shared-key filter and the sample correlation are both patched;
        the test checks the returned dict and the arguments forwarded to
        ``compute_sample_r_correlation``.
        """
        filter_shared_samples.return_value = [iter(val) for val in [(
            "1.23", "6.266"), ("6.565", "6.565"), ("6.456", "6.456")]]
        sample_r_corr.return_value = (["1419792_at", -1.0, 0.9, 6])
        this_trait_data = {
            "trait_id": "1455376_at",
            "trait_sample_data": {
                "C57BL/6J": "6.638",
                "DBA/2J": "6.266",
                "B6D2F1": "6.494",
                "D2B6F1": "6.565",
                "BXD2": "6.456"
            }}
        traits_dataset = [
            {
                "trait_id": "1419792_at",
                "trait_sample_data": {
                    "DBA/2J": "1.23",
                    "D2B6F1": "6.565",
                    "BXD2": "6.456"
                }
            }
        ]
        sample_all_results = [{"1419792_at": {"corr_coefficient": -1.0,
                                              "p_value": 0.9,
                                              "num_overlap": 6}}]
        self.assertEqual(
            compute_one_sample_correlation(
                this_trait_data["trait_sample_data"],
                traits_dataset[0], "pearson"),
            sample_all_results[0])
        # The value tuples are the mocked filter output, transposed.
        sample_r_corr.assert_called_once_with(
            trait_name='1419792_at',
            corr_method="pearson", trait_vals=('1.23', '6.565', '6.456'),
            target_samples_vals=('6.266', '6.565', '6.456'))

    @pytest.mark.unit_test
    @mock.patch("gn3.computations.correlations.compute_corr_coeff_p_value")
    def test_tissue_correlation_for_trait(self, mock_compute_corr_coeff):
        """Test the tissue correlation of a primary trait against one target.

        The coefficient function is a mock that is also passed in explicitly
        through ``compute_corr_p_value``.
        """
        primary_tissue_values = [1.1, 1.5, 2.3]
        target_tissues_values = [1, 2, 3]
        # Only the first pair is expected to show up in the result.
        mock_compute_corr_coeff.side_effect = [(0.4, 0.9), (-0.2, 0.91)]
        expected_tissue_results = {"1456_at": {"tissue_corr": 0.4,
                                               "tissue_p_val": 0.9, "tissue_number": 3}}
        tissue_results = tissue_correlation_for_trait(
            primary_tissue_values, target_tissues_values,
            corr_method="pearson", trait_id="1456_at",
            compute_corr_p_value=mock_compute_corr_coeff)
        self.assertEqual(tissue_results, expected_tissue_results)

    @pytest.mark.unit_test
    @mock.patch("gn3.computations.correlations.fetch_lit_correlation_data")
    @mock.patch("gn3.computations.correlations.map_to_mouse_gene_id")
    def test_lit_correlation_for_trait(self, mock_mouse_gene_id, fetch_lit_data):
        """Test the per-trait results of ``lit_correlation_for_trait``.

        Each target trait is expected to be paired with the gene id and lit
        correlation handed back by the patched ``fetch_lit_correlation_data``.
        """
        target_trait_lists = [("1426679_at", 15),
                              ("1426702_at", 17),
                              ("1426682_at", 11)]
        # NOTE(review): five ids are queued for three traits; how many are
        # consumed depends on the implementation -- confirm.
        mock_mouse_gene_id.side_effect = [12, 11, 18, 16, 20]
        conn = DataBase()
        fetch_lit_data.side_effect = [(15, 9), (17, 8), (11, 12)]
        lit_results = lit_correlation_for_trait(
            conn=conn, target_trait_lists=target_trait_lists,
            species="rat", trait_gene_id="12")
        expected_results = [{"1426679_at": {"gene_id": 15, "lit_corr": 9}},
                            {"1426702_at": {
                                "gene_id": 17, "lit_corr": 8}},
                            {"1426682_at": {"gene_id": 11, "lit_corr": 12}}]
        self.assertEqual(lit_results, expected_results)

    @pytest.mark.unit_test
    def test_fetch_lit_correlation_data(self):
        """Test fetching lit correlation data when both mouse gene ids are None.

        The expected result is the gene id paired with ``None``.
        """
        conn = DataBase()
        results = fetch_lit_correlation_data(conn=conn,
                                             gene_id="1",
                                             input_mouse_gene_id=None,
                                             mouse_gene_id=None)
        self.assertEqual(results, ("1", None))

    @pytest.mark.unit_test
    def test_fetch_lit_correlation_data_db_query(self):
        """Test fetching the lit correlation coefficient from canned db rows.

        Both the input trait mouse gene id and the mouse gene id are given.
        """
        # Three single-column rows; the expected coefficient (0.1) is the
        # value in the first row.
        expected_db_results = [[x*0.1]
                               for x in range(1, 4)]
        conn = DataBase(expected_results=expected_db_results)
        expected_results = ("1", 0.1)
        lit_results = fetch_lit_correlation_data(conn=conn,
                                                 gene_id="1",
                                                 input_mouse_gene_id="20",
                                                 mouse_gene_id="15")
        self.assertEqual(expected_results, lit_results)

    @pytest.mark.unit_test
    def test_query_lit_correlation_for_db_empty(self):
        """Test that the coefficient is None when the db has no value for it."""
        conn = mock.Mock()
        # The cursor is a ``DataBase`` without canned results, so its
        # ``fetchone`` returns None.
        conn.cursor.return_value = DataBase()
        conn.execute.return_value.fetchone.return_value = ""
        self.assertEqual(fetch_lit_correlation_data(conn=conn,
                                                    input_mouse_gene_id="12",
                                                    gene_id="16",
                                                    mouse_gene_id="12"), ("16", None))

    @pytest.mark.unit_test
    def test_query_formatter(self):
        """Test formatting a query string with the given values."""
        query = """
SELECT VALUE
FROM LCorr
WHERE GeneId1='%s' and
GeneId2='%s'
"""
        expected_formatted_query = """
SELECT VALUE
FROM LCorr
WHERE GeneId1='20' and
GeneId2='15'
"""
        mouse_gene_id = "20"
        input_mouse_gene_id = "15"
        query_values = (mouse_gene_id, input_mouse_gene_id)
        formatted_query = query_formatter(query, *query_values)
        self.assertEqual(formatted_query, expected_formatted_query)

    @pytest.mark.unit_test
    def test_query_formatter_no_query_values(self):
        """Test that a query without placeholders comes back unchanged."""
        query = """SELECT * FROM USERS"""
        formatted_query = query_formatter(query)
        self.assertEqual(formatted_query, query)

    @pytest.mark.unit_test
    def test_map_to_mouse_gene_id(self):
        """Test mapping a gene id to a mouse gene id for several species."""
        conn = mock.Mock()
        test_data = [("Human", 14), (None, 9), ("Mouse", 15), ("Rat", 14)]
        # Fake rows whose ``mouse`` field runs from 12 to 19; ``fetchone``
        # hands them out one per call, in order.
        database_results = [namedtuple("mouse_id", "mouse")(val)
                            for val in range(12, 20)]
        results = []
        cursor = mock.Mock()
        cursor.execute.return_value = 1
        cursor.fetchone.side_effect = database_results
        conn.cursor.return_value = cursor
        expected_results = [12, None, 13, 14]
        for (species, gene_id) in test_data:
            mouse_gene_id_results = map_to_mouse_gene_id(
                conn=conn, species=species, gene_id=gene_id)
            results.append(mouse_gene_id_results)
        self.assertEqual(results, expected_results)

    @pytest.mark.unit_test
    @mock.patch("gn3.computations.correlations.lit_correlation_for_trait")
    def test_compute_all_lit_correlation(self, mock_lit_corr):
        """Test ``compute_all_lit_correlation`` with the per-trait step patched.

        The function is expected to return what the patched
        ``lit_correlation_for_trait`` returns.
        """
        conn = mock.Mock()
        expected_mocked_lit_results = [{"1412_at": {"gene_id": 11, "lit_corr": 0.9}}, {"1412_a": {
            "gene_id": 17, "lit_corr": 0.48}}]
        mock_lit_corr.return_value = expected_mocked_lit_results
        lit_correlation_results = compute_all_lit_correlation(
            conn=conn, trait_lists=[("1412_at", 11), ("1412_a", 121)],
            species="rat", gene_id=12)
        self.assertEqual(lit_correlation_results, expected_mocked_lit_results)

    @pytest.mark.unit_test
    @mock.patch("gn3.computations.correlations.tissue_correlation_for_trait")
    @mock.patch("gn3.computations.correlations.process_trait_symbol_dict")
    def test_compute_all_tissue_correlation(self, process_trait_symbol, mock_tissue_corr):
        """Test ``compute_tissue_correlation`` over a list of target traits.

        Both the symbol processing and the per-trait correlation are patched.
        """
        primary_tissue_dict = {"trait_id": "1419792_at",
                               "tissue_values": [1, 2, 3, 4, 5]}
        target_tissue_dict = [{"trait_id": "1418702_a_at",
                               "symbol": "zf", "tissue_values": [1, 2, 3]},
                              {"trait_id": "1412_at",
                               "symbol": "prkce", "tissue_values": [1, 2, 3]}]
        process_trait_symbol.return_value = target_tissue_dict
        target_trait_symbol = {"1418702_a_at": "Zf", "1412_at": "Prkce"}
        target_symbol_tissue_vals = {"zf": [1, 2, 3], "prkce": [1, 2, 3]}
        target_tissue_data = {"trait_symbol_dict": target_trait_symbol,
                              "symbol_tissue_vals_dict": target_symbol_tissue_vals}
        mock_tissue_corr.side_effect = [{"1418702_a_at": {"tissue_corr": -0.5, "tissue_p_val": 0.9,
                                                          "tissue_number": 3}},
                                        {"1412_at": {"tissue_corr": 1.11, "tissue_p_val": 0.2,
                                                     "tissue_number": 3}}]
        # NOTE(review): the expected order is the reverse of the order the
        # mock yields; presumably the results get sorted -- confirm against
        # ``compute_tissue_correlation``.
        expected_results = [{"1412_at":
                             {"tissue_corr": 1.11, "tissue_p_val": 0.2, "tissue_number": 3}},
                            {"1418702_a_at":
                             {"tissue_corr": -0.5, "tissue_p_val": 0.9, "tissue_number": 3}}]
        results = compute_tissue_correlation(
            primary_tissue_dict=primary_tissue_dict,
            target_tissues_data=target_tissue_data,
            corr_method="pearson")
        process_trait_symbol.assert_called_once_with(
            target_trait_symbol, target_symbol_tissue_vals)
        self.assertEqual(mock_tissue_corr.call_count, 2)
        self.assertEqual(results, expected_results)

    @pytest.mark.unit_test
    def test_map_shared_keys_to_values(self):
        """Test the helper needed to integrate with genenetwork2.

        Given the dataset's sample keys, each trait's list of values is
        expected to be mapped onto those keys.
        """
        dataset_sample_keys = ["BXD1", "BXD2", "BXD5"]
        target_dataset_data = {"HCMA:_AT": [4.1, 5.6, 3.2],
                               "TXD_AT": [6.2, 5.7, 3.6, ]}
        expected_results = [{"trait_id": "HCMA:_AT",
                             "trait_sample_data": {"BXD1": 4.1, "BXD2": 5.6, "BXD5": 3.2}},
                            {"trait_id": "TXD_AT",
                             "trait_sample_data": {"BXD1": 6.2, "BXD2": 5.7, "BXD5": 3.6}}]
        results = map_shared_keys_to_values(
            dataset_sample_keys, target_dataset_data)
        self.assertEqual(results, expected_results)

    @pytest.mark.unit_test
    def test_process_trait_symbol_dict(self):
        """Test processing the trait symbol dict.

        The tissue values are fetched from the tissue values dict; the
        expected symbol is the lower-cased form of the input symbol.
        """
        trait_symbol_dict = {"1452864_at": "Igsf10"}
        tissue_values_dict = {"igsf10": [8.9615, 10.6375, 9.2795, 8.6605]}
        expected_results = {
            "trait_id": "1452864_at",
            "symbol": "igsf10",
            "tissue_values": [8.9615, 10.6375, 9.2795, 8.6605]
        }
        results = process_trait_symbol_dict(
            trait_symbol_dict, tissue_values_dict)
        self.assertEqual(results, [expected_results])

    @pytest.mark.unit_test
    def test_compute_correlation(self):
        """Test that the new correlation function works the same as the original
        from genenetwork1."""
        # Each case is [dbdata, userdata, (expected coefficient, expected overlap)].
        for dbdata, userdata, expected in [
                [[None, None, None, None, None, None, None, None, None, None],
                 [None, None, None, None, None, None, None, None, None, None],
                 (0.0, 0)],
                [[None, None, None, None, None, None, None, None, None, 0],
                 [None, None, None, None, None, None, None, None, None, None],
                 (0.0, 0)],
                [[None, None, None, None, None, None, None, None, None, 0],
                 [None, None, None, None, None, None, None, None, None, 0],
                 (0.0, 1)],
                [[9.3, 2.2, 5.4, 7.2, 6.4, 7.6, 3.8, 1.8, 8.4, 0.2],
                 [0.6, 3.97, 5.82, 8.21, 1.65, 4.55, 6.72, 9.5, 7.33, 2.34],
                 (-0.12720361919462056, 10)],
                [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                 [None, None, None, None, 2, None, None, 3, None, None],
                 (0.0, 2)]]:
            with self.subTest(dbdata=dbdata, userdata=userdata):
                actual = compute_correlation(dbdata, userdata)
                # The coefficient is a float, so compare it approximately.
                with self.subTest("correlation coefficient"):
                    assert_almost_equal(actual[0], expected[0])
                with self.subTest("overlap"):
                    self.assertEqual(actual[1], expected[1])