From 89691df0d7ba096fb7a154aca3adf40f8dfaa8ae Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Tue, 9 Nov 2021 10:16:50 +0300 Subject: Implement remaining part of `partial_correlation_recursive` function Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/partial-correlations.gmi * gn3/computations/partial_correlations.py: implement remaining portion of `partial_correlation_recursive` function. * tests/unit/computations/test_partial_correlations.py: add parsing for new data format and update tests --- gn3/computations/partial_correlations.py | 27 ++++++- .../unit/computations/test_partial_correlations.py | 82 +++++++++++++++++++++- 2 files changed, 105 insertions(+), 4 deletions(-) diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py index 9d73197..07a67be 100644 --- a/gn3/computations/partial_correlations.py +++ b/gn3/computations/partial_correlations.py @@ -332,4 +332,29 @@ def partial_correlation_recursive( (corrs["rxy"] - corrs["rxz"] * corrs["ryz"]) / (math.sqrt(1 - corrs["rxz"]**2) * math.sqrt(1 - corrs["ryz"]**2))), ROUND_TO) - return round(0, ROUND_TO) + + remaining_cols = [ + colname for colname, series in data.items() + if colname not in ("x", "y", "z0") + ] + + new_xdata = tuple(data["x"]) + new_ydata = tuple(data["y"]) + zc = tuple( + tuple(row_series[1]) + for row_series in data[remaining_cols].iterrows()) + + rxy_zc = partial_correlation_recursive( + new_xdata, new_ydata, zc, method=method, + omit_nones=omit_nones) + rxz0_zc = partial_correlation_recursive( + new_xdata, tuple(data["z0"]), zc, method=method, + omit_nones=omit_nones) + ryz0_zc = partial_correlation_recursive( + new_ydata, tuple(data["z0"]), zc, method=method, + omit_nones=omit_nones) + + return round( + ((rxy_zc - rxz0_zc * ryz0_zc) /( + math.sqrt(1 - rxz0_zc**2) * math.sqrt(1 - ryz0_zc**2))), + ROUND_TO) diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py index 981801a..28c9548 100644 --- a/tests/unit/computations/test_partial_correlations.py +++ b/tests/unit/computations/test_partial_correlations.py @@ -6,6 +6,8 @@ from unittest import TestCase import pandas from gn3.settings import ROUND_TO +from gn3.function_helpers import compose +from gn3.data_helpers import partition_by from gn3.computations.partial_correlations import ( fix_samples, control_samples, @@ -120,6 +122,80 @@ def parse_test_data_csv(filename): "result": round(float(line["result"]), ROUND_TO) } for line in lines) +def parse_method(key_value): + """Parse the partial correlation method""" + key, value = key_value + if key == "method": + methods_dict = {"p": "pearson", "k": "kendall", "s": "spearman"} + return (key, methods_dict[value]) + return key_value + +def parse_count(key_value): + """Parse the value of count into an integer""" + key, value = key_value + if key == "count": + return (key, int(value)) + return key_value + +def parse_xyz(key_value): + """Parse the values of x, y, and z* items into sequences of floats""" + key, value = key_value + if (key in ("x", "y", "z")) or key.startswith("input.z"): + return ( + key.replace("input", "").replace(".", ""), + tuple(float(val.strip("\n\t ")) for val in value.split(","))) + return key_value + +def parse_rm(key_value): + """Parse the rm value into a python True/False value.""" + key, value = key_value + if key == "rm": + return (key, value == "TRUE") + return key_value + +def parse_result(key_value): + """Parse the result into a float value.""" + key, value = key_value + if key == "result": + return (key, float(value)) + return key_value + +parser_function = compose( + parse_result, + parse_rm, + parse_xyz, + parse_count, + parse_method, + lambda k_v: tuple(item.strip("\n\t ") for item in k_v), + lambda s: s.split(":")) + +def parse_input_line(line): + return tuple( + parser_function(item) for item in line if not item.startswith("------")) + +def merge_z(item): + without_z = { + key: val for key, val in item.items() if not key.startswith("z")} + return { + **without_z, + "z": item.get( + "z", + tuple(val for key, val in item.items() if key.startswith("z")))} + +def parse_input(lines): + return tuple( + merge_z(dict(item)) + for item in (parse_input_line(line) for line in lines) + if len(item) != 0) + +def parse_test_data(filename): + with open("pcor_rec_blackbox_attempt.txt", newline="\n") as fl: + input_lines = partition_by( + lambda s: s.startswith("------"), + (line.strip("\n\t ") for line in fl.readlines())) + + return parse_input(input_lines) + class TestPartialCorrelations(TestCase): """Class for testing partial correlations computation functions""" @@ -343,9 +419,9 @@ class TestPartialCorrelations(TestCase): Test that `partial_correlation_recursive` computes the appropriate correlation value. """ - for sample in parse_test_data_csv( + for sample in parse_test_data( ("tests/unit/computations/partial_correlations_test_data/" - "pcor_rec_blackbox_test.csv")): + "pcor_rec_blackbox_test.txt")): with self.subTest( xdata=sample["x"], ydata=sample["y"], zdata=sample["z"], method=sample["method"], omit_nones=sample["rm"]): @@ -353,4 +429,4 @@ class TestPartialCorrelations(TestCase): partial_correlation_recursive( sample["x"], sample["y"], sample["z"], method=sample["method"], omit_nones=sample["rm"]), - sample["result"]) + round(sample["result"], ROUND_TO)) -- cgit v1.2.3