about summary refs log tree commit diff
path: root/tests/unit/computations/test_partial_correlations.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unit/computations/test_partial_correlations.py')
-rw-r--r--tests/unit/computations/test_partial_correlations.py61
1 files changed, 60 insertions, 1 deletions
diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py
index f7217a9..c5c35d1 100644
--- a/tests/unit/computations/test_partial_correlations.py
+++ b/tests/unit/computations/test_partial_correlations.py
@@ -1,5 +1,6 @@
 """Module contains tests for gn3.partial_correlations"""
 
+import csv
 from unittest import TestCase
 from gn3.computations.partial_correlations import (
     fix_samples,
@@ -7,7 +8,9 @@ from gn3.computations.partial_correlations import (
     dictify_by_samples,
     tissue_correlation,
     find_identical_traits,
-    good_dataset_samples_indexes)
+    partial_correlation_matrix,
+    good_dataset_samples_indexes,
+    partial_correlation_recursive)
 
 sampleslist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"]
 control_traits = (
@@ -90,6 +93,28 @@ dictified_control_samples = (
      "BXD1": {"sample_name": "BXD1", "value": 7.77141, "variance": None},
      "BXD2": {"sample_name": "BXD2", "value":  7.80944, "variance": None}})
 
+def parse_test_data_csv(filename):
+    """
+    Parse test data csv files for R -> Python conversion of some functions.
+    """
+    def __str__to_tuple(line, field):
+        return tuple(float(s.strip()) for s in line[field].split(","))
+
+    with open(filename, newline="\n") as csvfile:
+        reader = csv.DictReader(csvfile, delimiter=",", quotechar='"')
+        lines = tuple(row for row in reader)
+
+    methods = {"p": "pearson", "s": "spearman", "k": "kendall"}
+    return tuple({
+        **line,
+        "x": __str__to_tuple(line, "x"),
+        "y": __str__to_tuple(line, "y"),
+        "z": __str__to_tuple(line, "z"),
+        "method": methods[line["method"]],
+        "rm": line["rm"] == "TRUE",
+        "result": float(line["result"])
+    } for line in lines)
+
 class TestPartialCorrelations(TestCase):
     """Class for testing partial correlations computation functions"""
 
@@ -271,3 +296,37 @@ class TestPartialCorrelations(TestCase):
                 ("a", "e", "i", "k"),
                 ("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l")),
             (0, 4, 8, 10))
+
+    def test_partial_correlation_matrix(self):
+        """
+        Test that `partial_correlation_matrix` computes the appropriate
+        correlation value.
+        """
+        for sample in parse_test_data_csv(
+                ("tests/unit/computations/partial_correlations_test_data/"
+                 "pcor_mat_blackbox_test.csv")):
+            with self.subTest(
+                    xdata=sample["x"], ydata=sample["y"], zdata=sample["z"],
+                    method=sample["method"], omit_nones=sample["rm"]):
+                self.assertEqual(
+                    partial_correlation_matrix(
+                        sample["x"], sample["y"], sample["z"],
+                        method=sample["method"], omit_nones=sample["rm"]),
+                    sample["result"])
+
+    def test_partial_correlation_recursive(self):
+        """
+        Test that `partial_correlation_recursive` computes the appropriate
+        correlation value.
+        """
+        for sample in parse_test_data_csv(
+                ("tests/unit/computations/partial_correlations_test_data/"
+                 "pcor_rec_blackbox_test.csv")):
+            with self.subTest(
+                    xdata=sample["x"], ydata=sample["y"], zdata=sample["z"],
+                    method=sample["method"], omit_nones=sample["rm"]):
+                self.assertEqual(
+                    partial_correlation_recursive(
+                        sample["x"], sample["y"], sample["z"],
+                        method=sample["method"], omit_nones=sample["rm"]),
+                    sample["result"])