about summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2021-11-18 11:59:53 +0300
committerFrederick Muriuki Muriithi2021-11-18 11:59:53 +0300
commit3dd5fbda7e08999b6470cfe1fbbd19d767adea9b (patch)
treefed4d0ae18d8d39a35184c7e9d80bd942c9f37a3 /tests
parent21fbbfd599c841f082d88ddfc5f4cb362e1eb869 (diff)
downloadgenenetwork3-3dd5fbda7e08999b6470cfe1fbbd19d767adea9b.tar.gz
Fix some linting errors
Issue:
https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/partial-correlations.gmi

* Fix some obvious linting errors and remove obsolete code
Diffstat (limited to 'tests')
-rw-r--r--tests/unit/computations/test_partial_correlations.py111
-rw-r--r--tests/unit/test_data_helpers.py15
-rw-r--r--tests/unit/test_heatmaps.py8
3 files changed, 25 insertions, 109 deletions
diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py
index 138155d..f77a066 100644
--- a/tests/unit/computations/test_partial_correlations.py
+++ b/tests/unit/computations/test_partial_correlations.py
@@ -1,14 +1,9 @@
 """Module contains tests for gn3.partial_correlations"""
 
-import csv
 from unittest import TestCase
 
 import pandas
 
-from gn3.settings import ROUND_TO
-from gn3.function_helpers import compose
-from gn3.data_helpers import partition_by
-
 from gn3.computations.partial_correlations import (
     fix_samples,
     control_samples,
@@ -99,102 +94,6 @@ dictified_control_samples = (
      "BXD1": {"sample_name": "BXD1", "value": 7.77141, "variance": None},
      "BXD2": {"sample_name": "BXD2", "value":  7.80944, "variance": None}})
 
-def parse_test_data_csv(filename):
-    """
-    Parse test data csv files for R -> Python conversion of some functions.
-    """
-    def __str__to_tuple(line, field):
-        return tuple(float(s.strip()) for s in line[field].split(","))
-
-    with open(filename, newline="\n") as csvfile:
-        reader = csv.DictReader(csvfile, delimiter=",", quotechar='"')
-        lines = tuple(row for row in reader)
-
-    methods = {"p": "pearson", "s": "spearman", "k": "kendall"}
-    return tuple({
-        **line,
-        "x": __str__to_tuple(line, "x"),
-        "y": __str__to_tuple(line, "y"),
-        "z": __str__to_tuple(line, "z"),
-        "method": methods[line["method"]],
-        "rm": line["rm"] == "TRUE",
-        "result": round(float(line["result"]), ROUND_TO)
-    } for line in lines)
-
-def parse_method(key_value):
-    """Parse the partial correlation method"""
-    key, value = key_value
-    if key == "method":
-        methods_dict = {"p": "pearson", "k": "kendall", "s": "spearman"}
-        return (key, methods_dict[value])
-    return key_value
-
-def parse_count(key_value):
-    """Parse the value of count into an integer"""
-    key, value = key_value
-    if key == "count":
-        return (key, int(value))
-    return key_value
-
-def parse_xyz(key_value):
-    """Parse the values of x, y, and z* items into sequences of floats"""
-    key, value = key_value
-    if (key in ("x", "y", "z")) or key.startswith("input.z"):
-        return (
-            key.replace("input", "").replace(".", ""),
-            tuple(float(val.strip("\n\t ")) for val in value.split(",")))
-    return key_value
-
-def parse_rm(key_value):
-    """Parse the rm value into a python True/False value."""
-    key, value = key_value
-    if key == "rm":
-        return (key, value == "TRUE")
-    return key_value
-
-def parse_result(key_value):
-    """Parse the result into a float value."""
-    key, value = key_value
-    if key == "result":
-        return (key, float(value))
-    return key_value
-
-parse_for_rec = compose(
-    parse_result,
-    parse_rm,
-    parse_xyz,
-    parse_count,
-    parse_method,
-    lambda k_v: tuple(item.strip("\n\t ") for item in k_v),
-    lambda s: s.split(":"))
-
-def parse_input_line(line, parser_function):
-    return tuple(
-        parser_function(item) for item in line if not item.startswith("------"))
-
-def merge_z(item):
-    without_z = {
-        key: val for key, val in item.items() if not key.startswith("z")}
-    return {
-        **without_z,
-        "z": item.get(
-            "z",
-            tuple(val for key, val in item.items() if key.startswith("z")))}
-
-def parse_input(lines, parser_function):
-    return tuple(
-        merge_z(dict(item))
-        for item in (parse_input_line(line, parser_function) for line in lines)
-        if len(item) != 0)
-
-def parse_test_data(filename, parser_function):
-    with open(filename, newline="\n") as fl:
-        input_lines = partition_by(
-            lambda s: s.startswith("------"),
-            (line.strip("\n\t ") for line in fl.readlines()))
-
-    return parse_input(input_lines, parser_function)
-
 class TestPartialCorrelations(TestCase):
     """Class for testing partial correlations computation functions"""
 
@@ -382,16 +281,16 @@ class TestPartialCorrelations(TestCase):
         Check that the function builds the correct data frame.
         """
         for xdata, ydata, zdata, expected in (
-                ((0.1, 1.1, 2.1), (2.1, 3.1, 4.1), (5.1, 6.1 ,7.1),
+                ((0.1, 1.1, 2.1), (2.1, 3.1, 4.1), (5.1, 6.1, 7.1),
                  pandas.DataFrame({
                      "x": (0.1, 1.1, 2.1), "y": (2.1, 3.1, 4.1),
-                     "z": (5.1, 6.1 ,7.1)})),
+                     "z": (5.1, 6.1, 7.1)})),
                 ((0.1, 1.1, 2.1), (2.1, 3.1, 4.1),
-                 ((5.1, 6.1 ,7.1), (5.2, 6.2, 7.2), (5.3, 6.3, 7.3)),
+                 ((5.1, 6.1, 7.1), (5.2, 6.2, 7.2), (5.3, 6.3, 7.3)),
                  pandas.DataFrame({
                      "x": (0.1, 1.1, 2.1), "y": (2.1, 3.1, 4.1),
-                     "z0": (5.1, 6.1, 7.1), "z1": (5.2, 6.2 ,7.2),
-                     "z2": (5.3, 6.3 ,7.3)}))):
+                     "z0": (5.1, 6.1, 7.1), "z1": (5.2, 6.2, 7.2),
+                     "z2": (5.3, 6.3, 7.3)}))):
             with self.subTest(xdata=xdata, ydata=ydata, zdata=zdata):
                 self.assertTrue(
                     build_data_frame(xdata, ydata, zdata).equals(expected))
diff --git a/tests/unit/test_data_helpers.py b/tests/unit/test_data_helpers.py
index 3f76344..88ea469 100644
--- a/tests/unit/test_data_helpers.py
+++ b/tests/unit/test_data_helpers.py
@@ -61,6 +61,21 @@ class TestDataHelpers(TestCase):
                     expected)
 
     def test_partition_by(self):
+        """
+        Test that `partition_by` groups the data using the given predicate
+
+        Given:
+          - `part_fn`: a predicate funtion that return boolean True/False
+          - `items`: a sequence of items
+        When:
+          - the partitioning predicate function and the sequence of items are
+            passed to the `partition_by` function
+        Then:
+          - the result is a tuple, with sub-tuples containing the data in the
+            original sequence. Each sub-tuple is a partition, ending as soon as
+            the next value in the sequence, when passed to `part_fn`, returns
+            boolean `True`.
+        """
         for part_fn, items, expected in (
                 (lambda s: s.startswith("----"),
                  ("------", "a", "b", "-----", "c", "----", "d", "e", "---",
diff --git a/tests/unit/test_heatmaps.py b/tests/unit/test_heatmaps.py
index e4c929d..69e1c3c 100644
--- a/tests/unit/test_heatmaps.py
+++ b/tests/unit/test_heatmaps.py
@@ -1,6 +1,8 @@
 """Module contains tests for gn3.heatmaps.heatmaps"""
 from unittest import TestCase
-from numpy.testing import assert_allclose
+
+from numpy import allclose
+
 from gn3.heatmaps import (
     cluster_traits,
     get_loci_names,
@@ -40,7 +42,7 @@ class TestHeatmap(TestCase):
             (6.84118, 7.08432, 7.59844, 7.08229, 7.26774, 7.24991),
             (9.45215, 10.6943, 8.64719, 10.1592, 7.75044, 8.78615),
             (7.04737, 6.87185, 7.58586, 6.92456, 6.84243, 7.36913)]
-        assert_allclose(
+        self.assertTrue(allclose(
             cluster_traits(traits_data_list),
             ((0.0, 0.20337048635536847, 0.16381088984330505, 1.7388553629398245,
               1.5025235756329178, 0.6952839500255574, 1.271661230252733,
@@ -72,7 +74,7 @@ class TestHeatmap(TestCase):
              (0.7934461515867415, 0.4497104244247795, 0.7127042590637039,
               0.9313185954797953, 1.1683723389247052, 0.23451785425383564,
               1.7413442197913358, 0.33370067057028485, 1.3256191648260216,
-              0.0)))
+              0.0))))
 
     def test_compute_heatmap_order(self):
         """Test the orders."""