aboutsummaryrefslogtreecommitdiff
path: root/tests/unit/computations/test_partial_correlations.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unit/computations/test_partial_correlations.py')
-rw-r--r--tests/unit/computations/test_partial_correlations.py61
1 files changed, 60 insertions, 1 deletions
diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py
index f7217a9..c5c35d1 100644
--- a/tests/unit/computations/test_partial_correlations.py
+++ b/tests/unit/computations/test_partial_correlations.py
@@ -1,5 +1,6 @@
"""Module contains tests for gn3.partial_correlations"""
+import csv
from unittest import TestCase
from gn3.computations.partial_correlations import (
fix_samples,
@@ -7,7 +8,9 @@ from gn3.computations.partial_correlations import (
dictify_by_samples,
tissue_correlation,
find_identical_traits,
- good_dataset_samples_indexes)
+ partial_correlation_matrix,
+ good_dataset_samples_indexes,
+ partial_correlation_recursive)
sampleslist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"]
control_traits = (
@@ -90,6 +93,28 @@ dictified_control_samples = (
"BXD1": {"sample_name": "BXD1", "value": 7.77141, "variance": None},
"BXD2": {"sample_name": "BXD2", "value": 7.80944, "variance": None}})
+def parse_test_data_csv(filename):
+ """
+ Parse test data csv files for R -> Python conversion of some functions.
+ """
+ def __str__to_tuple(line, field):
+ return tuple(float(s.strip()) for s in line[field].split(","))
+
+ with open(filename, newline="\n") as csvfile:
+ reader = csv.DictReader(csvfile, delimiter=",", quotechar='"')
+ lines = tuple(row for row in reader)
+
+ methods = {"p": "pearson", "s": "spearman", "k": "kendall"}
+ return tuple({
+ **line,
+ "x": __str__to_tuple(line, "x"),
+ "y": __str__to_tuple(line, "y"),
+ "z": __str__to_tuple(line, "z"),
+ "method": methods[line["method"]],
+ "rm": line["rm"] == "TRUE",
+ "result": float(line["result"])
+ } for line in lines)
+
class TestPartialCorrelations(TestCase):
"""Class for testing partial correlations computation functions"""
@@ -271,3 +296,37 @@ class TestPartialCorrelations(TestCase):
("a", "e", "i", "k"),
("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l")),
(0, 4, 8, 10))
+
+ def test_partial_correlation_matrix(self):
+ """
+ Test that `partial_correlation_matrix` computes the appropriate
+ correlation value.
+ """
+ for sample in parse_test_data_csv(
+ ("tests/unit/computations/partial_correlations_test_data/"
+ "pcor_mat_blackbox_test.csv")):
+ with self.subTest(
+ xdata=sample["x"], ydata=sample["y"], zdata=sample["z"],
+ method=sample["method"], omit_nones=sample["rm"]):
+ self.assertEqual(
+ partial_correlation_matrix(
+ sample["x"], sample["y"], sample["z"],
+ method=sample["method"], omit_nones=sample["rm"]),
+ sample["result"])
+
+ def test_partial_correlation_recursive(self):
+ """
+ Test that `partial_correlation_recursive` computes the appropriate
+ correlation value.
+ """
+ for sample in parse_test_data_csv(
+ ("tests/unit/computations/partial_correlations_test_data/"
+ "pcor_rec_blackbox_test.csv")):
+ with self.subTest(
+ xdata=sample["x"], ydata=sample["y"], zdata=sample["z"],
+ method=sample["method"], omit_nones=sample["rm"]):
+ self.assertEqual(
+ partial_correlation_recursive(
+ sample["x"], sample["y"], sample["z"],
+ method=sample["method"], omit_nones=sample["rm"]),
+ sample["result"])