aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/unit/computations/test_partial_correlations.py82
1 files changed, 79 insertions, 3 deletions
diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py
index 981801a..28c9548 100644
--- a/tests/unit/computations/test_partial_correlations.py
+++ b/tests/unit/computations/test_partial_correlations.py
@@ -6,6 +6,8 @@ from unittest import TestCase
import pandas
from gn3.settings import ROUND_TO
+from gn3.function_helpers import compose
+from gn3.data_helpers import partition_by
from gn3.computations.partial_correlations import (
fix_samples,
control_samples,
@@ -120,6 +122,80 @@ def parse_test_data_csv(filename):
"result": round(float(line["result"]), ROUND_TO)
} for line in lines)
+def parse_method(key_value):
+ """Parse the partial correlation method"""
+ key, value = key_value
+ if key == "method":
+ methods_dict = {"p": "pearson", "k": "kendall", "s": "spearman"}
+ return (key, methods_dict[value])
+ return key_value
+
+def parse_count(key_value):
+ """Parse the value of count into an integer"""
+ key, value = key_value
+ if key == "count":
+ return (key, int(value))
+ return key_value
+
+def parse_xyz(key_value):
+ """Parse the values of x, y, and z* items into sequences of floats"""
+ key, value = key_value
+ if (key in ("x", "y", "z")) or key.startswith("input.z"):
+ return (
+ key.replace("input", "").replace(".", ""),
+ tuple(float(val.strip("\n\t ")) for val in value.split(",")))
+ return key_value
+
+def parse_rm(key_value):
+ """Parse the rm value into a python True/False value."""
+ key, value = key_value
+ if key == "rm":
+ return (key, value == "TRUE")
+ return key_value
+
+def parse_result(key_value):
+ """Parse the result into a float value."""
+ key, value = key_value
+ if key == "result":
+ return (key, float(value))
+ return key_value
+
+parser_function = compose(
+ parse_result,
+ parse_rm,
+ parse_xyz,
+ parse_count,
+ parse_method,
+ lambda k_v: tuple(item.strip("\n\t ") for item in k_v),
+ lambda s: s.split(":"))
+
+def parse_input_line(line):
+ return tuple(
+ parser_function(item) for item in line if not item.startswith("------"))
+
+def merge_z(item):
+ without_z = {
+ key: val for key, val in item.items() if not key.startswith("z")}
+ return {
+ **without_z,
+ "z": item.get(
+ "z",
+ tuple(val for key, val in item.items() if key.startswith("z")))}
+
+def parse_input(lines):
+ return tuple(
+ merge_z(dict(item))
+ for item in (parse_input_line(line) for line in lines)
+ if len(item) != 0)
+
+def parse_test_data(filename):
+ with open("pcor_rec_blackbox_attempt.txt", newline="\n") as fl:
+ input_lines = partition_by(
+ lambda s: s.startswith("------"),
+ (line.strip("\n\t ") for line in fl.readlines()))
+
+ return parse_input(input_lines)
+
class TestPartialCorrelations(TestCase):
"""Class for testing partial correlations computation functions"""
@@ -343,9 +419,9 @@ class TestPartialCorrelations(TestCase):
Test that `partial_correlation_recursive` computes the appropriate
correlation value.
"""
- for sample in parse_test_data_csv(
+ for sample in parse_test_data(
("tests/unit/computations/partial_correlations_test_data/"
- "pcor_rec_blackbox_test.csv")):
+ "pcor_rec_blackbox_test.txt")):
with self.subTest(
xdata=sample["x"], ydata=sample["y"], zdata=sample["z"],
method=sample["method"], omit_nones=sample["rm"]):
@@ -353,4 +429,4 @@ class TestPartialCorrelations(TestCase):
partial_correlation_recursive(
sample["x"], sample["y"], sample["z"],
method=sample["method"], omit_nones=sample["rm"]),
- sample["result"])
+ round(sample["result"], ROUND_TO))