about summary refs log tree commit diff
path: root/tests/unit/computations/test_partial_correlations.py
diff options
context:
space:
mode:
authorMuriithi Frederick Muriuki2021-11-20 18:06:32 +0300
committerGitHub2021-11-20 18:06:32 +0300
commit92d5766f5514181cd6aa82fc0d0f225666e892cb (patch)
tree74840e8e2118e24e1f49eb780ca0bbf24704e510 /tests/unit/computations/test_partial_correlations.py
parentabc0d36f39c691652fee81bce808d625fc368e72 (diff)
parent08c81b8892060353bb7fb15555875f03bbdcb46e (diff)
downloadgenenetwork3-92d5766f5514181cd6aa82fc0d0f225666e892cb.tar.gz
Merge pull request #56 from genenetwork/partial-correlations
Partial correlations
Diffstat (limited to 'tests/unit/computations/test_partial_correlations.py')
-rw-r--r--tests/unit/computations/test_partial_correlations.py92
1 files changed, 27 insertions, 65 deletions
diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py
index 83cb9d9..3e1b6e1 100644
--- a/tests/unit/computations/test_partial_correlations.py
+++ b/tests/unit/computations/test_partial_correlations.py
@@ -1,16 +1,18 @@
 """Module contains tests for gn3.partial_correlations"""
 
-import csv
-from unittest import TestCase, skip
+from unittest import TestCase
+
+import pandas
+from numpy.testing import assert_allclose
+
 from gn3.computations.partial_correlations import (
     fix_samples,
     control_samples,
+    build_data_frame,
     dictify_by_samples,
     tissue_correlation,
     find_identical_traits,
-    partial_correlation_matrix,
-    good_dataset_samples_indexes,
-    partial_correlation_recursive)
+    good_dataset_samples_indexes)
 
 sampleslist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"]
 control_traits = (
@@ -93,29 +95,6 @@ dictified_control_samples = (
      "BXD1": {"sample_name": "BXD1", "value": 7.77141, "variance": None},
      "BXD2": {"sample_name": "BXD2", "value":  7.80944, "variance": None}})
 
-def parse_test_data_csv(filename):
-    """
-    Parse test data csv files for R -> Python conversion of some functions.
-    """
-    def __str__to_tuple(line, field):
-        return tuple(float(s.strip()) for s in line[field].split(","))
-
-    with open(filename, newline="\n") as csvfile:
-        reader = csv.DictReader(csvfile, delimiter=",", quotechar='"')
-        lines = tuple(row for row in reader)
-
-    methods = {"p": "pearson", "s": "spearman", "k": "kendall"}
-    return tuple({
-        **line,
-        "x": __str__to_tuple(line, "x"),
-        "y": __str__to_tuple(line, "y"),
-        "z": __str__to_tuple(line, "z"),
-        "method": methods[line["method"]],
-        "rm": line["rm"] == "TRUE",
-        "result": float(line["result"])
-    } for line in lines)
-
-
 class TestPartialCorrelations(TestCase):
     """Class for testing partial correlations computation functions"""
 
@@ -272,7 +251,7 @@ class TestPartialCorrelations(TestCase):
                 with self.assertRaises(error, msg=error_msg):
                     tissue_correlation(primary, target, method)
 
-    def test_tissue_correlation(self):
+    def test_tissue_correlation(self): # pylint: disable=R0201
         """
         Test that the correct correlation values are computed for the given:
         - primary trait
@@ -281,11 +260,11 @@ class TestPartialCorrelations(TestCase):
         """
         for primary, target, method, expected in (
                 ((12.34, 18.36, 42.51), (37.25, 46.25, 46.56), "pearson",
-                 (0.6761779253, 0.5272701134)),
+                 (0.6761779252651052, 0.5272701133657985)),
                 ((1, 2, 3, 4, 5), (5, 6, 7, 8, 7), "spearman",
-                 (0.8207826817, 0.0885870053))):
+                 (0.8207826816681233, 0.08858700531354381))):
             with self.subTest(primary=primary, target=target, method=method):
-                self.assertEqual(
+                assert_allclose(
                     tissue_correlation(primary, target, method), expected)
 
     def test_good_dataset_samples_indexes(self):
@@ -298,38 +277,21 @@ class TestPartialCorrelations(TestCase):
                 ("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l")),
             (0, 4, 8, 10))
 
-    @skip
-    def test_partial_correlation_matrix(self):
+    def test_build_data_frame(self):
         """
-        Test that `partial_correlation_matrix` computes the appropriate
-        correlation value.
+        Check that the function builds the correct data frame.
         """
-        for sample in parse_test_data_csv(
-                ("tests/unit/computations/partial_correlations_test_data/"
-                 "pcor_mat_blackbox_test.csv")):
-            with self.subTest(
-                    xdata=sample["x"], ydata=sample["y"], zdata=sample["z"],
-                    method=sample["method"], omit_nones=sample["rm"]):
-                self.assertEqual(
-                    partial_correlation_matrix(
-                        sample["x"], sample["y"], sample["z"],
-                        method=sample["method"], omit_nones=sample["rm"]),
-                    sample["result"])
-
-    @skip
-    def test_partial_correlation_recursive(self):
-        """
-        Test that `partial_correlation_recursive` computes the appropriate
-        correlation value.
-        """
-        for sample in parse_test_data_csv(
-                ("tests/unit/computations/partial_correlations_test_data/"
-                 "pcor_rec_blackbox_test.csv")):
-            with self.subTest(
-                    xdata=sample["x"], ydata=sample["y"], zdata=sample["z"],
-                    method=sample["method"], omit_nones=sample["rm"]):
-                self.assertEqual(
-                    partial_correlation_recursive(
-                        sample["x"], sample["y"], sample["z"],
-                        method=sample["method"], omit_nones=sample["rm"]),
-                    sample["result"])
+        for xdata, ydata, zdata, expected in (
+                ((0.1, 1.1, 2.1), (2.1, 3.1, 4.1), (5.1, 6.1, 7.1),
+                 pandas.DataFrame({
+                     "x": (0.1, 1.1, 2.1), "y": (2.1, 3.1, 4.1),
+                     "z": (5.1, 6.1, 7.1)})),
+                ((0.1, 1.1, 2.1), (2.1, 3.1, 4.1),
+                 ((5.1, 6.1, 7.1), (5.2, 6.2, 7.2), (5.3, 6.3, 7.3)),
+                 pandas.DataFrame({
+                     "x": (0.1, 1.1, 2.1), "y": (2.1, 3.1, 4.1),
+                     "z0": (5.1, 6.1, 7.1), "z1": (5.2, 6.2, 7.2),
+                     "z2": (5.3, 6.3, 7.3)}))):
+            with self.subTest(xdata=xdata, ydata=ydata, zdata=zdata):
+                self.assertTrue(
+                    build_data_frame(xdata, ydata, zdata).equals(expected))