aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Frederick Muriuki Muriithi 2021-11-09 10:16:50 +0300
committer Frederick Muriuki Muriithi 2021-11-09 10:16:50 +0300
commit 89691df0d7ba096fb7a154aca3adf40f8dfaa8ae (patch)
tree 8a6541abc927d554b6b57c1634763687d136db42
parent b38be1a60ea04e3087d36ff68e7b047922295d4e (diff)
download genenetwork3-89691df0d7ba096fb7a154aca3adf40f8dfaa8ae.tar.gz
Implement remaining part of `partial_correlation_recursive` function
Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/partial-correlations.gmi

* gn3/computations/partial_correlations.py: Implement the remaining portion of the `partial_correlation_recursive` function.
* tests/unit/computations/test_partial_correlations.py: Add parsing for the new test-data format and update the tests.
-rw-r--r-- gn3/computations/partial_correlations.py 27
-rw-r--r-- tests/unit/computations/test_partial_correlations.py 82
2 files changed, 105 insertions, 4 deletions
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 9d73197..07a67be 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -332,4 +332,29 @@ def partial_correlation_recursive(
(corrs["rxy"] - corrs["rxz"] * corrs["ryz"]) /
(math.sqrt(1 - corrs["rxz"]**2) *
math.sqrt(1 - corrs["ryz"]**2))), ROUND_TO)
- return round(0, ROUND_TO)
+
+ remaining_cols = [
+ colname for colname, series in data.items()
+ if colname not in ("x", "y", "z0")
+ ]
+
+ new_xdata = tuple(data["x"])
+ new_ydata = tuple(data["y"])
+ zc = tuple(
+ tuple(row_series[1])
+ for row_series in data[remaining_cols].iterrows())
+
+ rxy_zc = partial_correlation_recursive(
+ new_xdata, new_ydata, zc, method=method,
+ omit_nones=omit_nones)
+ rxz0_zc = partial_correlation_recursive(
+ new_xdata, tuple(data["z0"]), zc, method=method,
+ omit_nones=omit_nones)
+ ryz0_zc = partial_correlation_recursive(
+ new_ydata, tuple(data["z0"]), zc, method=method,
+ omit_nones=omit_nones)
+
+ return round(
+ ((rxy_zc - rxz0_zc * ryz0_zc) /(
+ math.sqrt(1 - rxz0_zc**2) * math.sqrt(1 - ryz0_zc**2))),
+ ROUND_TO)
diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py
index 981801a..28c9548 100644
--- a/tests/unit/computations/test_partial_correlations.py
+++ b/tests/unit/computations/test_partial_correlations.py
@@ -6,6 +6,8 @@ from unittest import TestCase
import pandas
from gn3.settings import ROUND_TO
+from gn3.function_helpers import compose
+from gn3.data_helpers import partition_by
from gn3.computations.partial_correlations import (
fix_samples,
control_samples,
@@ -120,6 +122,80 @@ def parse_test_data_csv(filename):
"result": round(float(line["result"]), ROUND_TO)
} for line in lines)
def parse_method(key_value):
    """Expand the single-letter correlation method code into its full name.

    `key_value` is a `(key, value)` pair. When the key is "method", the
    value ("p", "k" or "s") is replaced with "pearson", "kendall" or
    "spearman" respectively; an unrecognised code raises `KeyError`. Pairs
    with any other key are returned untouched.
    """
    name, code = key_value
    if name != "method":
        return key_value
    full_names = {"p": "pearson", "k": "kendall", "s": "spearman"}
    return (name, full_names[code])
+
def parse_count(key_value):
    """Convert the value of a "count" pair to an integer.

    `key_value` is a `(key, value)` pair; pairs whose key is not "count"
    are returned untouched.
    """
    name, raw = key_value
    return (name, int(raw)) if name == "count" else key_value
+
def parse_xyz(key_value):
    """Convert x, y and z* pairs into `(name, tuple_of_floats)` pairs.

    A pair is handled when its key is "x", "y" or "z", or when the key
    starts with "input.z". For those pairs, every "input" and "." is
    removed from the key, and the comma-separated value is turned into a
    tuple of floats (each element stripped of newlines, tabs and spaces
    first). All other pairs are returned untouched.
    """
    key, value = key_value
    if key not in ("x", "y", "z") and not key.startswith("input.z"):
        return key_value

    clean_key = key.replace("input", "").replace(".", "")
    numbers = tuple(
        float(element.strip("\n\t ")) for element in value.split(","))
    return (clean_key, numbers)
+
def parse_rm(key_value):
    """Convert the value of an "rm" pair into a Python boolean.

    The value is `True` only when it is exactly the string "TRUE"; any
    other value gives `False`. Pairs whose key is not "rm" are returned
    untouched.
    """
    name, flag = key_value
    if name != "rm":
        return key_value
    return (name, flag == "TRUE")
+
def parse_result(key_value):
    """Convert the value of a "result" pair into a float.

    Pairs whose key is not "result" are returned untouched.
    """
    name, raw = key_value
    return (name, float(raw)) if name == "result" else key_value
+
def _split_key_value(text):
    """Split a raw entry string on every ":" character."""
    return text.split(":")


def _strip_items(k_v):
    """Strip newlines, tabs and spaces from each item of a split entry."""
    return tuple(item.strip("\n\t ") for item in k_v)


# Pipeline that turns one raw "key: value" string into a parsed pair.
# NOTE(review): presumably `compose` applies its arguments right-to-left, so
# the string is split first, then stripped, then offered to each `parse_*`
# function in turn -- confirm against `gn3.function_helpers.compose`.
parser_function = compose(
    parse_result,
    parse_rm,
    parse_xyz,
    parse_count,
    parse_method,
    _strip_items,
    _split_key_value)
+
def parse_input_line(line):
    """Parse every entry of `line`, skipping separator entries.

    `line` is iterated as a collection of strings. Entries that start with
    "------" are dropped; each remaining entry is passed through
    `parser_function`. The results are returned as a tuple, in order.
    """
    parsed = []
    for entry in line:
        if entry.startswith("------"):
            continue
        parsed.append(parser_function(entry))
    return tuple(parsed)
+
def merge_z(item):
    """Collapse every key starting with "z" into a single "z" entry.

    All keys that do not start with "z" are copied as they are. The
    resulting "z" entry is the value of the existing "z" key when `item`
    has one; otherwise it is a tuple of the values of all keys starting
    with "z", in the order they appear in `item`.
    """
    merged = {}
    z_values = []
    for key, val in item.items():
        if key.startswith("z"):
            z_values.append(val)
        else:
            merged[key] = val

    # An explicit "z" key takes precedence over the collected z* values.
    merged["z"] = item["z"] if "z" in item else tuple(z_values)
    return merged
+
def parse_input(lines):
    """Parse groups of raw input lines into a tuple of sample dictionaries.

    Each element of `lines` is handed to `parse_input_line`. Groups that
    produce no parsed pairs are skipped; every other group is converted to
    a dictionary and passed through `merge_z`.
    """
    parsed_groups = (parse_input_line(line) for line in lines)
    # `parse_input_line` returns a tuple, so truthiness is the emptiness check.
    return tuple(merge_z(dict(group)) for group in parsed_groups if group)
+
def parse_test_data(filename):
    """Read the test-data file `filename` and parse it into samples.

    The file's lines are stripped of newlines, tabs and spaces, grouped with
    `partition_by` using lines that start with "------" as the partitioning
    predicate, and the groups are then parsed by `parse_input`.

    The file named by `filename` is the one opened; previously a hard-coded
    "pcor_rec_blackbox_attempt.txt" was read and the argument was ignored.
    """
    with open(filename, newline="\n") as fl:
        # `fl.readlines()` is evaluated as soon as the generator expression
        # is created, so the data is fully read before the file is closed.
        input_lines = partition_by(
            lambda s: s.startswith("------"),
            (line.strip("\n\t ") for line in fl.readlines()))

    return parse_input(input_lines)
+
class TestPartialCorrelations(TestCase):
"""Class for testing partial correlations computation functions"""
@@ -343,9 +419,9 @@ class TestPartialCorrelations(TestCase):
Test that `partial_correlation_recursive` computes the appropriate
correlation value.
"""
- for sample in parse_test_data_csv(
+ for sample in parse_test_data(
("tests/unit/computations/partial_correlations_test_data/"
- "pcor_rec_blackbox_test.csv")):
+ "pcor_rec_blackbox_test.txt")):
with self.subTest(
xdata=sample["x"], ydata=sample["y"], zdata=sample["z"],
method=sample["method"], omit_nones=sample["rm"]):
@@ -353,4 +429,4 @@ class TestPartialCorrelations(TestCase):
partial_correlation_recursive(
sample["x"], sample["y"], sample["z"],
method=sample["method"], omit_nones=sample["rm"]),
- sample["result"])
+ round(sample["result"], ROUND_TO))