about summary refs log tree commit diff
path: root/tests/unit/computations
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2021-11-09 10:16:50 +0300
committerFrederick Muriuki Muriithi2021-11-09 10:16:50 +0300
commit89691df0d7ba096fb7a154aca3adf40f8dfaa8ae (patch)
tree8a6541abc927d554b6b57c1634763687d136db42 /tests/unit/computations
parentb38be1a60ea04e3087d36ff68e7b047922295d4e (diff)
downloadgenenetwork3-89691df0d7ba096fb7a154aca3adf40f8dfaa8ae.tar.gz
Implement remaining part of `partial_correlation_recursive` function
Issue:
https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/partial-correlations.gmi

* gn3/computations/partial_correlations.py: implement remaining portion of
  `partial_correlation_recursive` function.
* tests/unit/computations/test_partial_correlations.py: add parsing for new
  data format and update tests
Diffstat (limited to 'tests/unit/computations')
-rw-r--r--tests/unit/computations/test_partial_correlations.py82
1 files changed, 79 insertions, 3 deletions
diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py
index 981801a..28c9548 100644
--- a/tests/unit/computations/test_partial_correlations.py
+++ b/tests/unit/computations/test_partial_correlations.py
@@ -6,6 +6,8 @@ from unittest import TestCase
 import pandas
 
 from gn3.settings import ROUND_TO
+from gn3.function_helpers import compose
+from gn3.data_helpers import partition_by
 from gn3.computations.partial_correlations import (
     fix_samples,
     control_samples,
@@ -120,6 +122,80 @@ def parse_test_data_csv(filename):
         "result": round(float(line["result"]), ROUND_TO)
     } for line in lines)
 
+def parse_method(key_value):
+    """Parse the partial correlation method"""
+    key, value = key_value
+    if key == "method":
+        methods_dict = {"p": "pearson", "k": "kendall", "s": "spearman"}
+        return (key, methods_dict[value])
+    return key_value
+
+def parse_count(key_value):
+    """Parse the value of count into an integer"""
+    key, value = key_value
+    if key == "count":
+        return (key, int(value))
+    return key_value
+
+def parse_xyz(key_value):
+    """Parse the values of x, y, and z* items into sequences of floats"""
+    key, value = key_value
+    if (key in ("x", "y", "z")) or key.startswith("input.z"):
+        return (
+            key.replace("input", "").replace(".", ""),
+            tuple(float(val.strip("\n\t ")) for val in value.split(",")))
+    return key_value
+
+def parse_rm(key_value):
+    """Parse the rm value into a python True/False value."""
+    key, value = key_value
+    if key == "rm":
+        return (key, value == "TRUE")
+    return key_value
+
+def parse_result(key_value):
+    """Parse the result into a float value."""
+    key, value = key_value
+    if key == "result":
+        return (key, float(value))
+    return key_value
+
+parser_function = compose(
+    parse_result,
+    parse_rm,
+    parse_xyz,
+    parse_count,
+    parse_method,
+    lambda k_v: tuple(item.strip("\n\t ") for item in k_v),
+    lambda s: s.split(":"))
+
+def parse_input_line(line):
+    return tuple(
+        parser_function(item) for item in line if not item.startswith("------"))
+
+def merge_z(item):
+    without_z = {
+        key: val for key, val in item.items() if not key.startswith("z")}
+    return {
+        **without_z,
+        "z": item.get(
+            "z",
+            tuple(val for key, val in item.items() if key.startswith("z")))}
+
+def parse_input(lines):
+    return tuple(
+        merge_z(dict(item))
+        for item in (parse_input_line(line) for line in lines)
+        if len(item) != 0)
+
+def parse_test_data(filename):
+    with open("pcor_rec_blackbox_attempt.txt", newline="\n") as fl:
+        input_lines = partition_by(
+            lambda s: s.startswith("------"),
+            (line.strip("\n\t ") for line in fl.readlines()))
+
+    return parse_input(input_lines)
+
 class TestPartialCorrelations(TestCase):
     """Class for testing partial correlations computation functions"""
 
@@ -343,9 +419,9 @@ class TestPartialCorrelations(TestCase):
         Test that `partial_correlation_recursive` computes the appropriate
         correlation value.
         """
-        for sample in parse_test_data_csv(
+        for sample in parse_test_data(
                 ("tests/unit/computations/partial_correlations_test_data/"
-                 "pcor_rec_blackbox_test.csv")):
+                 "pcor_rec_blackbox_test.txt")):
             with self.subTest(
                     xdata=sample["x"], ydata=sample["y"], zdata=sample["z"],
                     method=sample["method"], omit_nones=sample["rm"]):
@@ -353,4 +429,4 @@ class TestPartialCorrelations(TestCase):
                     partial_correlation_recursive(
                         sample["x"], sample["y"], sample["z"],
                         method=sample["method"], omit_nones=sample["rm"]),
-                    sample["result"])
+                    round(sample["result"], ROUND_TO))