diff options
Diffstat (limited to 'tests/unit')
-rw-r--r-- | tests/unit/computations/test_partial_correlations.py | 111 | ||||
-rw-r--r-- | tests/unit/test_data_helpers.py | 15 | ||||
-rw-r--r-- | tests/unit/test_heatmaps.py | 8 |
3 files changed, 25 insertions, 109 deletions
diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py index 138155d..f77a066 100644 --- a/tests/unit/computations/test_partial_correlations.py +++ b/tests/unit/computations/test_partial_correlations.py @@ -1,14 +1,9 @@ """Module contains tests for gn3.partial_correlations""" -import csv from unittest import TestCase import pandas -from gn3.settings import ROUND_TO -from gn3.function_helpers import compose -from gn3.data_helpers import partition_by - from gn3.computations.partial_correlations import ( fix_samples, control_samples, @@ -99,102 +94,6 @@ dictified_control_samples = ( "BXD1": {"sample_name": "BXD1", "value": 7.77141, "variance": None}, "BXD2": {"sample_name": "BXD2", "value": 7.80944, "variance": None}}) -def parse_test_data_csv(filename): - """ - Parse test data csv files for R -> Python conversion of some functions. - """ - def __str__to_tuple(line, field): - return tuple(float(s.strip()) for s in line[field].split(",")) - - with open(filename, newline="\n") as csvfile: - reader = csv.DictReader(csvfile, delimiter=",", quotechar='"') - lines = tuple(row for row in reader) - - methods = {"p": "pearson", "s": "spearman", "k": "kendall"} - return tuple({ - **line, - "x": __str__to_tuple(line, "x"), - "y": __str__to_tuple(line, "y"), - "z": __str__to_tuple(line, "z"), - "method": methods[line["method"]], - "rm": line["rm"] == "TRUE", - "result": round(float(line["result"]), ROUND_TO) - } for line in lines) - -def parse_method(key_value): - """Parse the partial correlation method""" - key, value = key_value - if key == "method": - methods_dict = {"p": "pearson", "k": "kendall", "s": "spearman"} - return (key, methods_dict[value]) - return key_value - -def parse_count(key_value): - """Parse the value of count into an integer""" - key, value = key_value - if key == "count": - return (key, int(value)) - return key_value - -def parse_xyz(key_value): - """Parse the values of x, y, and z* items into sequences of floats""" - key, value = key_value - if (key in ("x", "y", "z")) or key.startswith("input.z"): - return ( - key.replace("input", "").replace(".", ""), - tuple(float(val.strip("\n\t ")) for val in value.split(","))) - return key_value - -def parse_rm(key_value): - """Parse the rm value into a python True/False value.""" - key, value = key_value - if key == "rm": - return (key, value == "TRUE") - return key_value - -def parse_result(key_value): - """Parse the result into a float value.""" - key, value = key_value - if key == "result": - return (key, float(value)) - return key_value - -parse_for_rec = compose( - parse_result, - parse_rm, - parse_xyz, - parse_count, - parse_method, - lambda k_v: tuple(item.strip("\n\t ") for item in k_v), - lambda s: s.split(":")) - -def parse_input_line(line, parser_function): - return tuple( - parser_function(item) for item in line if not item.startswith("------")) - -def merge_z(item): - without_z = { - key: val for key, val in item.items() if not key.startswith("z")} - return { - **without_z, - "z": item.get( - "z", - tuple(val for key, val in item.items() if key.startswith("z")))} - -def parse_input(lines, parser_function): - return tuple( - merge_z(dict(item)) - for item in (parse_input_line(line, parser_function) for line in lines) - if len(item) != 0) - -def parse_test_data(filename, parser_function): - with open(filename, newline="\n") as fl: - input_lines = partition_by( - lambda s: s.startswith("------"), - (line.strip("\n\t ") for line in fl.readlines())) - - return parse_input(input_lines, parser_function) - class TestPartialCorrelations(TestCase): """Class for testing partial correlations computation functions""" @@ -382,16 +281,16 @@ class TestPartialCorrelations(TestCase): Check that the function builds the correct data frame. """ for xdata, ydata, zdata, expected in ( - ((0.1, 1.1, 2.1), (2.1, 3.1, 4.1), (5.1, 6.1 ,7.1), + ((0.1, 1.1, 2.1), (2.1, 3.1, 4.1), (5.1, 6.1, 7.1), pandas.DataFrame({ "x": (0.1, 1.1, 2.1), "y": (2.1, 3.1, 4.1), - "z": (5.1, 6.1 ,7.1)})), + "z": (5.1, 6.1, 7.1)})), ((0.1, 1.1, 2.1), (2.1, 3.1, 4.1), - ((5.1, 6.1 ,7.1), (5.2, 6.2, 7.2), (5.3, 6.3, 7.3)), + ((5.1, 6.1, 7.1), (5.2, 6.2, 7.2), (5.3, 6.3, 7.3)), pandas.DataFrame({ "x": (0.1, 1.1, 2.1), "y": (2.1, 3.1, 4.1), - "z0": (5.1, 6.1, 7.1), "z1": (5.2, 6.2 ,7.2), - "z2": (5.3, 6.3 ,7.3)}))): + "z0": (5.1, 6.1, 7.1), "z1": (5.2, 6.2, 7.2), + "z2": (5.3, 6.3, 7.3)}))): with self.subTest(xdata=xdata, ydata=ydata, zdata=zdata): self.assertTrue( build_data_frame(xdata, ydata, zdata).equals(expected)) diff --git a/tests/unit/test_data_helpers.py b/tests/unit/test_data_helpers.py index 3f76344..88ea469 100644 --- a/tests/unit/test_data_helpers.py +++ b/tests/unit/test_data_helpers.py @@ -61,6 +61,21 @@ class TestDataHelpers(TestCase): expected) def test_partition_by(self): + """ + Test that `partition_by` groups the data using the given predicate + + Given: + - `part_fn`: a predicate funtion that return boolean True/False + - `items`: a sequence of items + When: + - the partitioning predicate function and the sequence of items are + passed to the `partition_by` function + Then: + - the result is a tuple, with sub-tuples containing the data in the + original sequence. Each sub-tuple is a partition, ending as soon as + the next value in the sequence, when passed to `part_fn`, returns + boolean `True`. + """ for part_fn, items, expected in ( (lambda s: s.startswith("----"), ("------", "a", "b", "-----", "c", "----", "d", "e", "---", diff --git a/tests/unit/test_heatmaps.py b/tests/unit/test_heatmaps.py index e4c929d..69e1c3c 100644 --- a/tests/unit/test_heatmaps.py +++ b/tests/unit/test_heatmaps.py @@ -1,6 +1,8 @@ """Module contains tests for gn3.heatmaps.heatmaps""" from unittest import TestCase -from numpy.testing import assert_allclose + +from numpy import allclose + from gn3.heatmaps import ( cluster_traits, get_loci_names, @@ -40,7 +42,7 @@ class TestHeatmap(TestCase): (6.84118, 7.08432, 7.59844, 7.08229, 7.26774, 7.24991), (9.45215, 10.6943, 8.64719, 10.1592, 7.75044, 8.78615), (7.04737, 6.87185, 7.58586, 6.92456, 6.84243, 7.36913)] - assert_allclose( + self.assertTrue(allclose( cluster_traits(traits_data_list), ((0.0, 0.20337048635536847, 0.16381088984330505, 1.7388553629398245, 1.5025235756329178, 0.6952839500255574, 1.271661230252733, @@ -72,7 +74,7 @@ class TestHeatmap(TestCase): (0.7934461515867415, 0.4497104244247795, 0.7127042590637039, 0.9313185954797953, 1.1683723389247052, 0.23451785425383564, 1.7413442197913358, 0.33370067057028485, 1.3256191648260216, - 0.0))) + 0.0)))) def test_compute_heatmap_order(self): """Test the orders.""" |