aboutsummaryrefslogtreecommitdiff
path: root/tests/unit
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unit')
-rw-r--r--tests/unit/computations/test_partial_correlations.py111
-rw-r--r--tests/unit/test_data_helpers.py15
-rw-r--r--tests/unit/test_heatmaps.py8
3 files changed, 25 insertions, 109 deletions
diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py
index 138155d..f77a066 100644
--- a/tests/unit/computations/test_partial_correlations.py
+++ b/tests/unit/computations/test_partial_correlations.py
@@ -1,14 +1,9 @@
"""Module contains tests for gn3.partial_correlations"""
-import csv
from unittest import TestCase
import pandas
-from gn3.settings import ROUND_TO
-from gn3.function_helpers import compose
-from gn3.data_helpers import partition_by
-
from gn3.computations.partial_correlations import (
fix_samples,
control_samples,
@@ -99,102 +94,6 @@ dictified_control_samples = (
"BXD1": {"sample_name": "BXD1", "value": 7.77141, "variance": None},
"BXD2": {"sample_name": "BXD2", "value": 7.80944, "variance": None}})
-def parse_test_data_csv(filename):
- """
- Parse test data csv files for R -> Python conversion of some functions.
- """
- def __str__to_tuple(line, field):
- return tuple(float(s.strip()) for s in line[field].split(","))
-
- with open(filename, newline="\n") as csvfile:
- reader = csv.DictReader(csvfile, delimiter=",", quotechar='"')
- lines = tuple(row for row in reader)
-
- methods = {"p": "pearson", "s": "spearman", "k": "kendall"}
- return tuple({
- **line,
- "x": __str__to_tuple(line, "x"),
- "y": __str__to_tuple(line, "y"),
- "z": __str__to_tuple(line, "z"),
- "method": methods[line["method"]],
- "rm": line["rm"] == "TRUE",
- "result": round(float(line["result"]), ROUND_TO)
- } for line in lines)
-
-def parse_method(key_value):
- """Parse the partial correlation method"""
- key, value = key_value
- if key == "method":
- methods_dict = {"p": "pearson", "k": "kendall", "s": "spearman"}
- return (key, methods_dict[value])
- return key_value
-
-def parse_count(key_value):
- """Parse the value of count into an integer"""
- key, value = key_value
- if key == "count":
- return (key, int(value))
- return key_value
-
-def parse_xyz(key_value):
- """Parse the values of x, y, and z* items into sequences of floats"""
- key, value = key_value
- if (key in ("x", "y", "z")) or key.startswith("input.z"):
- return (
- key.replace("input", "").replace(".", ""),
- tuple(float(val.strip("\n\t ")) for val in value.split(",")))
- return key_value
-
-def parse_rm(key_value):
- """Parse the rm value into a python True/False value."""
- key, value = key_value
- if key == "rm":
- return (key, value == "TRUE")
- return key_value
-
-def parse_result(key_value):
- """Parse the result into a float value."""
- key, value = key_value
- if key == "result":
- return (key, float(value))
- return key_value
-
-parse_for_rec = compose(
- parse_result,
- parse_rm,
- parse_xyz,
- parse_count,
- parse_method,
- lambda k_v: tuple(item.strip("\n\t ") for item in k_v),
- lambda s: s.split(":"))
-
-def parse_input_line(line, parser_function):
- return tuple(
- parser_function(item) for item in line if not item.startswith("------"))
-
-def merge_z(item):
- without_z = {
- key: val for key, val in item.items() if not key.startswith("z")}
- return {
- **without_z,
- "z": item.get(
- "z",
- tuple(val for key, val in item.items() if key.startswith("z")))}
-
-def parse_input(lines, parser_function):
- return tuple(
- merge_z(dict(item))
- for item in (parse_input_line(line, parser_function) for line in lines)
- if len(item) != 0)
-
-def parse_test_data(filename, parser_function):
- with open(filename, newline="\n") as fl:
- input_lines = partition_by(
- lambda s: s.startswith("------"),
- (line.strip("\n\t ") for line in fl.readlines()))
-
- return parse_input(input_lines, parser_function)
-
class TestPartialCorrelations(TestCase):
"""Class for testing partial correlations computation functions"""
@@ -382,16 +281,16 @@ class TestPartialCorrelations(TestCase):
Check that the function builds the correct data frame.
"""
for xdata, ydata, zdata, expected in (
- ((0.1, 1.1, 2.1), (2.1, 3.1, 4.1), (5.1, 6.1 ,7.1),
+ ((0.1, 1.1, 2.1), (2.1, 3.1, 4.1), (5.1, 6.1, 7.1),
pandas.DataFrame({
"x": (0.1, 1.1, 2.1), "y": (2.1, 3.1, 4.1),
- "z": (5.1, 6.1 ,7.1)})),
+ "z": (5.1, 6.1, 7.1)})),
((0.1, 1.1, 2.1), (2.1, 3.1, 4.1),
- ((5.1, 6.1 ,7.1), (5.2, 6.2, 7.2), (5.3, 6.3, 7.3)),
+ ((5.1, 6.1, 7.1), (5.2, 6.2, 7.2), (5.3, 6.3, 7.3)),
pandas.DataFrame({
"x": (0.1, 1.1, 2.1), "y": (2.1, 3.1, 4.1),
- "z0": (5.1, 6.1, 7.1), "z1": (5.2, 6.2 ,7.2),
- "z2": (5.3, 6.3 ,7.3)}))):
+ "z0": (5.1, 6.1, 7.1), "z1": (5.2, 6.2, 7.2),
+ "z2": (5.3, 6.3, 7.3)}))):
with self.subTest(xdata=xdata, ydata=ydata, zdata=zdata):
self.assertTrue(
build_data_frame(xdata, ydata, zdata).equals(expected))
diff --git a/tests/unit/test_data_helpers.py b/tests/unit/test_data_helpers.py
index 3f76344..88ea469 100644
--- a/tests/unit/test_data_helpers.py
+++ b/tests/unit/test_data_helpers.py
@@ -61,6 +61,21 @@ class TestDataHelpers(TestCase):
expected)
def test_partition_by(self):
+ """
+ Test that `partition_by` groups the data using the given predicate
+
+ Given:
+ - `part_fn`: a predicate funtion that return boolean True/False
+ - `items`: a sequence of items
+ When:
+ - the partitioning predicate function and the sequence of items are
+ passed to the `partition_by` function
+ Then:
+ - the result is a tuple, with sub-tuples containing the data in the
+ original sequence. Each sub-tuple is a partition, ending as soon as
+ the next value in the sequence, when passed to `part_fn`, returns
+ boolean `True`.
+ """
for part_fn, items, expected in (
(lambda s: s.startswith("----"),
("------", "a", "b", "-----", "c", "----", "d", "e", "---",
diff --git a/tests/unit/test_heatmaps.py b/tests/unit/test_heatmaps.py
index e4c929d..69e1c3c 100644
--- a/tests/unit/test_heatmaps.py
+++ b/tests/unit/test_heatmaps.py
@@ -1,6 +1,8 @@
"""Module contains tests for gn3.heatmaps.heatmaps"""
from unittest import TestCase
-from numpy.testing import assert_allclose
+
+from numpy import allclose
+
from gn3.heatmaps import (
cluster_traits,
get_loci_names,
@@ -40,7 +42,7 @@ class TestHeatmap(TestCase):
(6.84118, 7.08432, 7.59844, 7.08229, 7.26774, 7.24991),
(9.45215, 10.6943, 8.64719, 10.1592, 7.75044, 8.78615),
(7.04737, 6.87185, 7.58586, 6.92456, 6.84243, 7.36913)]
- assert_allclose(
+ self.assertTrue(allclose(
cluster_traits(traits_data_list),
((0.0, 0.20337048635536847, 0.16381088984330505, 1.7388553629398245,
1.5025235756329178, 0.6952839500255574, 1.271661230252733,
@@ -72,7 +74,7 @@ class TestHeatmap(TestCase):
(0.7934461515867415, 0.4497104244247795, 0.7127042590637039,
0.9313185954797953, 1.1683723389247052, 0.23451785425383564,
1.7413442197913358, 0.33370067057028485, 1.3256191648260216,
- 0.0)))
+ 0.0))))
def test_compute_heatmap_order(self):
"""Test the orders."""