aboutsummaryrefslogtreecommitdiff
"""Module contains the tests for correlation"""
from unittest import TestCase
from unittest import mock
from collections import namedtuple

import pytest
from numpy.testing import assert_almost_equal

from gn3.computations.correlations import normalize_values
from gn3.computations.correlations import compute_sample_r_correlation
from gn3.computations.correlations import compute_one_sample_correlation
from gn3.computations.correlations import filter_shared_sample_keys

from gn3.computations.correlations import tissue_correlation_for_trait
from gn3.computations.correlations import lit_correlation_for_trait
from gn3.computations.correlations import fetch_lit_correlation_data
from gn3.computations.correlations import query_formatter
from gn3.computations.correlations import map_to_mouse_gene_id
from gn3.computations.correlations import compute_all_lit_correlation
from gn3.computations.correlations import compute_tissue_correlation
from gn3.computations.correlations import map_shared_keys_to_values
from gn3.computations.correlations import process_trait_symbol_dict
from gn3.computations.correlations2 import compute_correlation


class QueryableMixin:
    """base class for db call"""

    def execute(self, query_options):
        """base method for execute"""
        raise NotImplementedError()

    def cursor(self):
        """method for creating db cursor"""
        raise NotImplementedError()

    def fetchone(self):
        """base method for fetching one iten"""
        raise NotImplementedError()

    def fetchall(self):
        """base method for fetch all items"""
        raise NotImplementedError()


class IllegalOperationError(Exception):
    """custom error to raise illegal operation in db"""

    def __init__(self):
        super().__init__("Operation not permitted!")


class DataBase(QueryableMixin):
    """Class for creating db object"""

    def __init__(self, expected_results=None, password="1234", db_name=None):
        """expects the expectede results value to be an array"""
        self.password = password
        self.db_name = db_name
        self.__query_options = None  # pylint: disable=[W0238]
        self.results_generator(expected_results)

    def execute(self, query_options):
        """method to execute an sql query"""
        self.__query_options = query_options  # pylint: disable=[W0238]
        return 1

    def cursor(self):
        """method for creating db cursor"""
        return self

    def fetchone(self):
        """method to fetch single item from the db query"""
        if self.__results is None:
            return None

        return self.__results[0]

    def fetchall(self):
        """method for fetching all items from db query"""
        if self.__results is None:
            return None
        return self.__results

    def results_generator(self, expected_results):
        """private method  for generating mock results"""

        self.__results = expected_results


class TestCorrelation(TestCase):
    """Class for testing correlation functions"""

    @pytest.mark.unit_test
    def test_normalize_values(self):
        """Function to test normalizing values """

        test_data = [
            [[2.3, None, None, 3.2, 4.1, 5], [3.4, 7.2, 1.3, None, 6.2, 4.1],
             [(2.3, 4.1, 5), (3.4, 6.2, 4.1)]],
            [[2.3, None, 1.3, None], [None, None, None, 1.2], []],
            [[], [], []]
        ]

        for a_values, b_values, expected_result in test_data:
            with self.subTest(a_values=a_values, b_values=b_values):
                results = normalize_values(a_values, b_values)
                self.assertEqual(list(zip(*list(results))), expected_result)

    @pytest.mark.unit_test
    @mock.patch("gn3.computations.correlations.compute_corr_coeff_p_value")
    @mock.patch("gn3.computations.correlations.normalize_values")
    def test_compute_sample_r_correlation(self, norm_vals, compute_corr):
        """Test for doing sample correlation gets the cor\
        and p value and rho value using pearson correlation
        """
        primary_values = [2.3, 4.1, 5, 4.2, None, None, 4, 1.2, 1.1]
        target_values = [3.4, 6.2, 4, 1.1, 1.2, None, 8, 1.1, 2.1]

        norm_vals.return_value = iter(
            [(2.3, 3.4), (4.1, 6.2), (5, 4), (4.2, 1.1), (4, 8), (1.2, 1.1), (1.1, 2.1)])

        compute_corr.return_value = (0.8, 0.21)

        bicor_results = compute_sample_r_correlation(trait_name="1412_at",
                                                     corr_method="bicor",
                                                     trait_vals=primary_values,
                                                     target_samples_vals=target_values)

        self.assertEqual(bicor_results, ("1412_at", 0.8, 0.21, 7))

    @pytest.mark.unit_test
    def test_filter_shared_sample_keys(self):
        """Function to  tests shared key between two dicts"""

        this_samplelist = {
            "C57BL/6J": "6.638",
            "DBA/2J": "6.266",
            "B6D2F1": "6.494",
            "D2B6F1": "6.565",
            "BXD2": "6.456"
        }

        target_samplelist = {
            "DBA/2J": "1.23",
            "D2B6F1": "6.565",
            "BXD2": "6.456"

        }

        filtered_target_samplelist = ("1.23", "6.565", "6.456")
        filtered_this_samplelist = ("6.266", "6.565", "6.456")

        results = filter_shared_sample_keys(
            this_samplelist=this_samplelist, target_samplelist=target_samplelist)

        self.assertEqual(list(zip(*list(results))), [filtered_this_samplelist,
                                                     filtered_target_samplelist])

    @pytest.mark.unit_test
    @mock.patch("gn3.computations.correlations.compute_sample_r_correlation")
    @mock.patch("gn3.computations.correlations.filter_shared_sample_keys")
    def test_compute_one_sample(self, filter_shared_samples, sample_r_corr):
        """Given target dataset compute all sample r correlation"""

        filter_shared_samples.return_value = [iter(val) for val in [(
            "1.23", "6.266"), ("6.565", "6.565"), ("6.456", "6.456")]]

        sample_r_corr.return_value = (["1419792_at", -1.0, 0.9, 6])

        this_trait_data = {
            "trait_id": "1455376_at",
            "trait_sample_data": {
                "C57BL/6J": "6.638",
                "DBA/2J": "6.266",
                "B6D2F1": "6.494",
                "D2B6F1": "6.565",
                "BXD2": "6.456"
            }}

        traits_dataset = [
            {
                "trait_id": "1419792_at",
                "trait_sample_data": {
                    "DBA/2J": "1.23",
                    "D2B6F1": "6.565",
                    "BXD2": "6.456"
                }
            }
        ]

        sample_all_results = [{"1419792_at": {"corr_coefficient": -1.0,
                                              "p_value": 0.9,
                                              "num_overlap": 6}}]

        self.assertEqual(
            compute_one_sample_correlation(
                this_trait_data["trait_sample_data"],
                traits_dataset[0], "pearson"),
            sample_all_results[0])
        sample_r_corr.assert_called_once_with(
            trait_name='1419792_at',
            corr_method="pearson", trait_vals=('1.23', '6.565', '6.456'),
            target_samples_vals=('6.266', '6.565', '6.456'))

    @pytest.mark.unit_test
    @mock.patch("gn3.computations.correlations.compute_corr_coeff_p_value")
    def test_tissue_correlation_for_trait(self, mock_compute_corr_coeff):
        """Test given a primary tissue values for a trait  and and a list of\
        target tissues for traits  do the tissue correlation for them
        """

        primary_tissue_values = [1.1, 1.5, 2.3]
        target_tissues_values = [1, 2, 3]
        mock_compute_corr_coeff.side_effect = [(0.4, 0.9), (-0.2, 0.91)]
        expected_tissue_results = {"1456_at": {"tissue_corr": 0.4,
                                               "tissue_p_val": 0.9, "tissue_number": 3}}
        tissue_results = tissue_correlation_for_trait(
            primary_tissue_values, target_tissues_values,
            corr_method="pearson", trait_id="1456_at",
            compute_corr_p_value=mock_compute_corr_coeff)

        self.assertEqual(tissue_results, expected_tissue_results)

    @pytest.mark.unit_test
    @mock.patch("gn3.computations.correlations.fetch_lit_correlation_data")
    @mock.patch("gn3.computations.correlations.map_to_mouse_gene_id")
    def test_lit_correlation_for_trait(self, mock_mouse_gene_id, fetch_lit_data):
        """Fetch results from  db call for lit correlation given a trait list\
        after doing correlation
        """

        target_trait_lists = [("1426679_at", 15),
                              ("1426702_at", 17),
                              ("1426682_at", 11)]
        mock_mouse_gene_id.side_effect = [12, 11, 18, 16, 20]

        conn = DataBase()

        fetch_lit_data.side_effect = [(15, 9), (17, 8), (11, 12)]

        lit_results = lit_correlation_for_trait(
            conn=conn, target_trait_lists=target_trait_lists,
            species="rat", trait_gene_id="12")

        expected_results = [{"1426679_at": {"gene_id": 15, "lit_corr": 9}},
                            {"1426702_at": {
                                "gene_id": 17, "lit_corr": 8}},
                            {"1426682_at": {"gene_id": 11, "lit_corr": 12}}]

        self.assertEqual(lit_results, expected_results)

    @pytest.mark.unit_test
    def test_fetch_lit_correlation_data(self):
        """Test for fetching lit correlation data from\
        the database where the input and mouse geneid are none
        """

        conn = DataBase()
        results = fetch_lit_correlation_data(conn=conn,
                                             gene_id="1",
                                             input_mouse_gene_id=None,
                                             mouse_gene_id=None)

        self.assertEqual(results, ("1", None))

    @pytest.mark.unit_test
    def test_fetch_lit_correlation_data_db_query(self):
        """Test for fetching lit corr coefficent givent the input\
         input trait mouse gene id and mouse gene id
        """

        expected_db_results = [[x*0.1]
                               for x in range(1, 4)]
        conn = DataBase(expected_results=expected_db_results)
        expected_results = ("1", 0.1)

        lit_results = fetch_lit_correlation_data(conn=conn,
                                                 gene_id="1",
                                                 input_mouse_gene_id="20",
                                                 mouse_gene_id="15")

        self.assertEqual(expected_results, lit_results)

    @pytest.mark.unit_test
    def test_query_lit_correlation_for_db_empty(self):
        """Test that corr coeffient returned is None given the\
        db value if corr coefficient is empty
        """
        conn = mock.Mock()
        conn.cursor.return_value = DataBase()
        conn.execute.return_value.fetchone.return_value = ""

        self.assertEqual(fetch_lit_correlation_data(conn=conn,
                                                    input_mouse_gene_id="12",
                                                    gene_id="16",
                                                    mouse_gene_id="12"), ("16", None))

    @pytest.mark.unit_test
    def test_query_formatter(self):
        """Test for formatting a query given the query string and also the\
        values
        """
        query = """
        SELECT VALUE
        FROM  LCorr
        WHERE GeneId1='%s' and
        GeneId2='%s'
        """

        expected_formatted_query = """
        SELECT VALUE
        FROM  LCorr
        WHERE GeneId1='20' and
        GeneId2='15'
        """

        mouse_gene_id = "20"
        input_mouse_gene_id = "15"

        query_values = (mouse_gene_id, input_mouse_gene_id)

        formatted_query = query_formatter(query, *query_values)

        self.assertEqual(formatted_query, expected_formatted_query)

    @pytest.mark.unit_test
    def test_query_formatter_no_query_values(self):
        """Test for formatting a query where there are no\
        string placeholder
        """
        query = """SELECT * FROM  USERS"""
        formatted_query = query_formatter(query)

        self.assertEqual(formatted_query, query)

    @pytest.mark.unit_test
    def test_map_to_mouse_gene_id(self):
        """Test for converting a gene id to mouse geneid\
        given a species which is not mouse
        """
        conn = mock.Mock()
        test_data = [("Human", 14), (None, 9), ("Mouse", 15), ("Rat", 14)]

        database_results = [namedtuple("mouse_id", "mouse")(val)
                            for val in range(12, 20)]
        results = []
        cursor = mock.Mock()
        cursor.execute.return_value = 1
        cursor.fetchone.side_effect = database_results
        conn.cursor.return_value = cursor
        expected_results = [12, None, 13, 14]
        for (species, gene_id) in test_data:

            mouse_gene_id_results = map_to_mouse_gene_id(
                conn=conn, species=species, gene_id=gene_id)
            results.append(mouse_gene_id_results)

        self.assertEqual(results, expected_results)

    @pytest.mark.unit_test
    @mock.patch("gn3.computations.correlations.lit_correlation_for_trait")
    def test_compute_all_lit_correlation(self, mock_lit_corr):
        """Test for compute all lit correlation which acts\
        as an abstraction for lit_correlation_for_trait
        and is used in the api/correlation/lit
        """

        conn = mock.Mock()

        expected_mocked_lit_results = [{"1412_at": {"gene_id": 11, "lit_corr": 0.9}}, {"1412_a": {
            "gene_id": 17, "lit_corr": 0.48}}]

        mock_lit_corr.return_value = expected_mocked_lit_results

        lit_correlation_results = compute_all_lit_correlation(
            conn=conn, trait_lists=[("1412_at", 11), ("1412_a", 121)],
            species="rat", gene_id=12)

        self.assertEqual(lit_correlation_results, expected_mocked_lit_results)

    @pytest.mark.unit_test
    @mock.patch("gn3.computations.correlations.tissue_correlation_for_trait")
    @mock.patch("gn3.computations.correlations.process_trait_symbol_dict")
    def test_compute_all_tissue_correlation(self, process_trait_symbol, mock_tissue_corr):
        """Test for compute all tissue corelation which abstracts
        api calling the tissue_correlation for trait_list"""

        primary_tissue_dict = {"trait_id": "1419792_at",
                               "tissue_values": [1, 2, 3, 4, 5]}

        target_tissue_dict = [{"trait_id": "1418702_a_at",
                               "symbol": "zf", "tissue_values": [1, 2, 3]},
                              {"trait_id": "1412_at",
                               "symbol": "prkce", "tissue_values": [1, 2, 3]}]

        process_trait_symbol.return_value = target_tissue_dict

        target_trait_symbol = {"1418702_a_at": "Zf", "1412_at": "Prkce"}
        target_symbol_tissue_vals = {"zf": [1, 2, 3], "prkce": [1, 2, 3]}

        target_tissue_data = {"trait_symbol_dict": target_trait_symbol,
                              "symbol_tissue_vals_dict": target_symbol_tissue_vals}

        mock_tissue_corr.side_effect = [{"1418702_a_at": {"tissue_corr": -0.5, "tissue_p_val": 0.9,
                                                          "tissue_number": 3}},
                                        {"1412_at": {"tissue_corr": 1.11, "tissue_p_val": 0.2,
                                                     "tissue_number": 3}}]

        expected_results = [{"1412_at":
                             {"tissue_corr": 1.11, "tissue_p_val": 0.2, "tissue_number": 3}},
                            {"1418702_a_at":
                             {"tissue_corr": -0.5, "tissue_p_val": 0.9, "tissue_number": 3}}]

        results = compute_tissue_correlation(
            primary_tissue_dict=primary_tissue_dict,
            target_tissues_data=target_tissue_data,
            corr_method="pearson")
        process_trait_symbol.assert_called_once_with(
            target_trait_symbol, target_symbol_tissue_vals)

        self.assertEqual(mock_tissue_corr.call_count, 2)

        self.assertEqual(results, expected_results)

    @pytest.mark.unit_test
    def test_map_shared_keys_to_values(self):
        """test helper function needed to integrate with genenenetwork2\
        given a a samplelist containing dataset sampelist keys\
        map that to given sample values """

        dataset_sample_keys = ["BXD1", "BXD2", "BXD5"]

        target_dataset_data = {"HCMA:_AT": [4.1, 5.6, 3.2],
                               "TXD_AT": [6.2, 5.7, 3.6, ]}

        expected_results = [{"trait_id": "HCMA:_AT",
                             "trait_sample_data": {"BXD1": 4.1, "BXD2": 5.6, "BXD5": 3.2}},
                            {"trait_id": "TXD_AT",
                             "trait_sample_data": {"BXD1": 6.2, "BXD2": 5.7, "BXD5": 3.6}}]

        results = map_shared_keys_to_values(
            dataset_sample_keys, target_dataset_data)

        self.assertEqual(results, expected_results)

    @pytest.mark.unit_test
    def test_process_trait_symbol_dict(self):
        """test for processing trait symbol dict\
        and fetch tissue values from tissue value dict\
        """
        trait_symbol_dict = {"1452864_at": "Igsf10"}
        tissue_values_dict = {"igsf10": [8.9615, 10.6375, 9.2795, 8.6605]}

        expected_results = {
            "trait_id": "1452864_at",
            "symbol": "igsf10",
            "tissue_values": [8.9615, 10.6375, 9.2795, 8.6605]
        }

        results = process_trait_symbol_dict(
            trait_symbol_dict, tissue_values_dict)

        self.assertEqual(results, [expected_results])

    @pytest.mark.unit_test
    def test_compute_correlation(self):
        """Test that the new correlation function works the same as the original
        from genenetwork1."""
        for dbdata, userdata, expected in [
                [[None, None, None, None, None, None, None, None, None, None],
                 [None, None, None, None, None, None, None, None, None, None],
                 (0.0, 0)],
                [[None, None, None, None, None, None, None, None, None, 0],
                 [None, None, None, None, None, None, None, None, None, None],
                 (0.0, 0)],
                [[None, None, None, None, None, None, None, None, None, 0],
                 [None, None, None, None, None, None, None, None, None, 0],
                 (0.0, 1)],
                [[9.3, 2.2, 5.4, 7.2, 6.4, 7.6, 3.8, 1.8, 8.4, 0.2],
                 [0.6, 3.97, 5.82, 8.21, 1.65, 4.55, 6.72, 9.5, 7.33, 2.34],
                 (-0.12720361919462056, 10)],
                [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                 [None, None, None, None, 2, None, None, 3, None, None],
                 (0.0, 2)]]:
            with self.subTest(dbdata=dbdata, userdata=userdata):
                actual = compute_correlation(dbdata, userdata)
                with self.subTest("correlation coefficient"):
                    assert_almost_equal(actual[0], expected[0])
                with self.subTest("overlap"):
                    self.assertEqual(actual[1], expected[1])