# path: root/tests/unit/computations/test_partial_correlations.py
# blob: 28c95485ccc2a95c6d4cec35c0f51288b608389d
"""Module contains tests for gn3.partial_correlations"""

import csv
from unittest import TestCase

import pandas

from gn3.settings import ROUND_TO
from gn3.function_helpers import compose
from gn3.data_helpers import partition_by
from gn3.computations.partial_correlations import (
    fix_samples,
    control_samples,
    build_data_frame,
    dictify_by_samples,
    tissue_correlation,
    find_identical_traits,
    partial_correlation_matrix,
    good_dataset_samples_indexes,
    partial_correlation_recursive)

# Sample names used as the second argument to `control_samples` in
# `test_control_samples`.
sampleslist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"]
# Three control traits used as the first argument to `control_samples`:
#   1. every key is in `sampleslist` and every value is set;
#   2. two keys ("B6cC3-21" and "BXD21") are not in `sampleslist`, although
#      their "sample_name" entries are;
#   3. three samples ("BXD12", "BXD16" and "BXD19") have a value of None.
# `test_control_samples` expects 6, 4 and 3 samples respectively from these.
control_traits = (
    {
        "mysqlid": 36688172,
        "data": {
            "B6cC3-1": {
                "sample_name": "B6cC3-1", "value": 7.51879, "variance": None,
                "ndata": None},
            "BXD1": {
                "sample_name": "BXD1", "value": 7.77141, "variance": None,
                "ndata": None},
            "BXD12": {
                "sample_name": "BXD12", "value": 8.39265, "variance": None,
                "ndata": None},
            "BXD16": {
                "sample_name": "BXD16", "value": 8.17443, "variance": None,
                "ndata": None},
            "BXD19": {
                "sample_name": "BXD19", "value": 8.30401, "variance": None,
                "ndata": None},
            "BXD2": {
                "sample_name": "BXD2", "value": 7.80944, "variance": None,
                "ndata": None}}},
    {
        "mysqlid": 36688172,
        "data": {
            "B6cC3-21": {
                "sample_name": "B6cC3-1", "value": 7.51879, "variance": None,
                "ndata": None},
            "BXD21": {
                "sample_name": "BXD1", "value": 7.77141, "variance": None,
                "ndata": None},
            "BXD12": {
                "sample_name": "BXD12", "value": 8.39265, "variance": None,
                "ndata": None},
            "BXD16": {
                "sample_name": "BXD16", "value": 8.17443, "variance": None,
                "ndata": None},
            "BXD19": {
                "sample_name": "BXD19", "value": 8.30401, "variance": None,
                "ndata": None},
            "BXD2": {
                "sample_name": "BXD2", "value": 7.80944, "variance": None,
                "ndata": None}}},
    {
        "mysqlid": 36688172,
        "data": {
            "B6cC3-1": {
                "sample_name": "B6cC3-1", "value": 7.51879, "variance": None,
                "ndata": None},
            "BXD1": {
                "sample_name": "BXD1", "value": 7.77141, "variance": None,
                "ndata": None},
            "BXD12": {
                "sample_name": "BXD12", "value": None, "variance": None,
                "ndata": None},
            "BXD16": {
                "sample_name": "BXD16", "value": None, "variance": None,
                "ndata": None},
            "BXD19": {
                "sample_name": "BXD19", "value": None, "variance": None,
                "ndata": None},
            "BXD2": {
                "sample_name": "BXD2", "value": 7.80944, "variance": None,
                "ndata": None}}})

# Expected result of `dictify_by_samples` in `test_dictify_by_samples`; also
# passed as the control samples argument in `test_fix_samples`.
dictified_control_samples = (
    {"B6cC3-1": {"sample_name": "B6cC3-1", "value": 7.51879, "variance": None},
     "BXD1": {"sample_name": "BXD1", "value": 7.77141, "variance": None},
     "BXD12": {"sample_name": "BXD12", "value": 8.39265, "variance": None},
     "BXD16": {"sample_name": "BXD16", "value": 8.17443, "variance": None},
     "BXD19": {"sample_name": "BXD19", "value": 8.30401, "variance": None},
     "BXD2": {"sample_name": "BXD2", "value": 7.80944, "variance": None}},
    {"BXD12": {"sample_name": "BXD12", "value": 8.39265, "variance": None},
     "BXD16": {"sample_name": "BXD16", "value": 8.17443, "variance": None},
     "BXD19": {"sample_name": "BXD19", "value": 8.30401, "variance": None},
     "BXD2": {"sample_name": "BXD2", "value": 7.80944, "variance": None}},
    {"B6cC3-1": {"sample_name": "B6cC3-1", "value": 7.51879, "variance": None},
     "BXD1": {"sample_name": "BXD1", "value": 7.77141, "variance": None},
     "BXD2": {"sample_name": "BXD2", "value":  7.80944, "variance": None}})

def parse_test_data_csv(filename):
    """
    Parse test data csv files for R -> Python conversion of some functions.

    Every row of the file at `filename` must provide the columns `x`, `y`,
    `z`, `method`, `rm` and `result`. Any other column is passed through
    unchanged as a string.

    Returns a tuple with one dict per row, where:
        - `x`, `y` and `z` are tuples of floats parsed from comma-separated
          text,
        - `method` is the full method name for the `p`/`s`/`k` abbreviation,
        - `rm` is True only when the column holds the text `TRUE`,
        - `result` is a float rounded to `ROUND_TO` decimal places.

    Raises KeyError if a required column is missing or the method
    abbreviation is unknown, and ValueError if a number cannot be parsed.
    """
    def _to_float_tuple(line, field):
        return tuple(float(s.strip()) for s in line[field].split(","))

    # The csv module does its own newline handling, so its documentation asks
    # for files to be opened with newline="".
    with open(filename, newline="", encoding="utf-8") as csvfile:
        reader = csv.DictReader(csvfile, delimiter=",", quotechar='"')
        # Read every row while the file is still open.
        lines = tuple(reader)

    methods = {"p": "pearson", "s": "spearman", "k": "kendall"}
    return tuple({
        **line,
        "x": _to_float_tuple(line, "x"),
        "y": _to_float_tuple(line, "y"),
        "z": _to_float_tuple(line, "z"),
        "method": methods[line["method"]],
        "rm": line["rm"] == "TRUE",
        "result": round(float(line["result"]), ROUND_TO)
    } for line in lines)

def parse_method(key_value):
    """Expand the abbreviated correlation method of a `method` item"""
    name, raw = key_value
    if name != "method":
        # Every other item passes through untouched.
        return key_value
    full_names = {"p": "pearson", "k": "kendall", "s": "spearman"}
    return (name, full_names[raw])

def parse_count(key_value):
    """Convert the value of a `count` item into an integer"""
    name, raw = key_value
    return (name, int(raw)) if name == "count" else key_value

def parse_xyz(key_value):
    """
    Convert the comma-separated text of x, y, z and input.z* items into a
    tuple of floats, dropping "input" and "." from the item's key.
    """
    name, raw = key_value
    if not (name in ("x", "y", "z") or name.startswith("input.z")):
        return key_value
    short_name = name.replace("input", "").replace(".", "")
    numbers = tuple(float(piece.strip("\n\t ")) for piece in raw.split(","))
    return (short_name, numbers)

def parse_rm(key_value):
    """Turn the text of an `rm` item into a boolean: True only for "TRUE"."""
    name, raw = key_value
    if name != "rm":
        return key_value
    return (name, raw == "TRUE")

def parse_result(key_value):
    """Convert the value of a `result` item into a float"""
    name, raw = key_value
    return (name, float(raw)) if name == "result" else key_value

# Parser for a single line of the text test data: the line is split on ":",
# both parts are stripped of surrounding whitespace, and the resulting pair is
# handed through the `parse_*` converters, each of which only acts on the key
# it knows about.
# NOTE(review): the string-splitting lambda is listed last but has to run
# first, so this presumably relies on `compose` applying its arguments
# right-to-left -- confirm against `gn3.function_helpers.compose`.
# NOTE(review): a line with no ":" or with more than one ":" does not yield a
# two-item pair, which the `parse_*` converters cannot unpack.
parser_function = compose(
    parse_result,
    parse_rm,
    parse_xyz,
    parse_count,
    parse_method,
    lambda k_v: tuple(item.strip("\n\t ") for item in k_v),
    lambda s: s.split(":"))

def parse_input_line(line):
    """
    Run `parser_function` over every item in `line`, skipping any item that
    starts with "------", and return the results as a tuple.
    """
    parsed = []
    for item in line:
        if item.startswith("------"):
            continue
        parsed.append(parser_function(item))
    return tuple(parsed)

def merge_z(item):
    """
    Return a copy of `item` in which every key starting with "z" is replaced
    by a single "z" key.

    If `item` already has a "z" key its value is kept; otherwise "z" holds a
    tuple of the values of all the keys starting with "z", in their original
    order.
    """
    merged = {}
    z_values = []
    for key, val in item.items():
        if key.startswith("z"):
            z_values.append(val)
        else:
            merged[key] = val
    merged["z"] = item["z"] if "z" in item else tuple(z_values)
    return merged

def parse_input(lines):
    """
    Parse each group of lines in `lines` with `parse_input_line`, drop the
    groups that parse to nothing, and return a tuple with one dict per
    remaining group, passed through `merge_z`.
    """
    records = []
    for line in lines:
        pairs = parse_input_line(line)
        if len(pairs) == 0:
            continue
        records.append(merge_z(dict(pairs)))
    return tuple(records)

def parse_test_data(filename):
    """
    Parse the text test data in the file at `filename`.

    The lines of the file are stripped of surrounding whitespace, grouped
    with `partition_by` on the lines that start with "------", and the groups
    are then handed to `parse_input`, whose result is returned.
    """
    # Open the file the caller asked for, rather than a hard-coded name.
    # The default universal-newlines mode is used so that "\r\n" line endings
    # do not leave a stray "\r" at the end of the parsed values.
    with open(filename, encoding="utf-8") as fl:
        # `readlines()` loads the whole file while it is still open, so the
        # lines remain available after the `with` block has closed the file.
        input_lines = partition_by(
            lambda s: s.startswith("------"),
            (line.strip("\n\t ") for line in fl.readlines()))

    return parse_input(input_lines)

class TestPartialCorrelations(TestCase):
    """Class for testing partial correlations computation functions"""

    def test_control_samples(self):
        """Test that the control_samples works as expected."""
        # The expected value is a 4-tuple: sample names, values, variances and
        # the number of samples, each with one entry per control trait.
        self.assertEqual(
            control_samples(control_traits, sampleslist),
            ((("B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"),
              ("BXD12", "BXD16", "BXD19", "BXD2"),
              ("B6cC3-1", "BXD1", "BXD2")),
             ((7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944),
              (8.39265, 8.17443, 8.30401, 7.80944),
              (7.51879, 7.77141, 7.80944)),
             ((None, None, None, None, None, None), (None, None, None, None),
              (None, None, None)),
             (6, 4, 3)))

    def test_dictify_by_samples(self):
        """
        Test that `dictify_by_samples` generates the appropriate dict

        Given:
            a sequence of sequences with sample names, values and variances, as
            in the output of `gn3.computations.partial_correlations.control_samples`
            or the output of `gn3.db.traits.export_informative`
        When:
            the sequence is passed as an argument into the
            `gn3.computations.partial_correlations.dictify_by_samples`
        Then:
            return a sequence of dicts with keys being the values of the sample
            names, and each of whose values being sub-dicts with the keys
            'sample_name', 'value' and 'variance' whose values correspond to the
            values passed in.
        """
        # The input is the same value that `test_control_samples` expects from
        # `control_samples`.
        self.assertEqual(
            dictify_by_samples(
                ((("B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"),
                  ("BXD12", "BXD16", "BXD19", "BXD2"),
                  ("B6cC3-1", "BXD1", "BXD2")),
                 ((7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944),
                  (8.39265, 8.17443, 8.30401, 7.80944),
                  (7.51879, 7.77141, 7.80944)),
                 ((None, None, None, None, None, None), (None, None, None, None),
                  (None, None, None)),
                 (6, 4, 3))),
            dictified_control_samples)

    def test_fix_samples(self):
        """
        Test that `fix_samples` returns only the common samples

        Given:
            - A primary trait
            - A sequence of control samples
        When:
            - The two arguments are passed to `fix_samples`
        Then:
            - Only the names of the samples present in the primary trait that
              are also present in ALL the control traits are present in the
              return value
            - Only the values of the samples present in the primary trait that
              are also present in ALL the control traits are present in the
              return value
            - ALL the values for ALL the control traits are present in the
              return value
            - Only the variances of the samples present in the primary trait
              that are also present in ALL the control traits are present in the
              return value
            - ALL the variances for ALL the control traits are present in the
              return value
            - The return value is a tuple of the above items, in the following
              order:
                ((sample_names, ...), (primary_trait_values, ...),
                 (control_traits_values, ...), (primary_trait_variances, ...),
                 (control_traits_variances, ...))
        """
        # "BXD2" is the only sample of the primary trait that appears in all
        # three dicts of `dictified_control_samples`.
        self.assertEqual(
            fix_samples(
                {"B6cC3-1": {"sample_name": "B6cC3-1", "value": 7.51879,
                             "variance": None},
                 "BXD1": {"sample_name": "BXD1", "value": 7.77141,
                          "variance": None},
                 "BXD2": {"sample_name": "BXD2", "value":  7.80944,
                          "variance": None}},
                dictified_control_samples),
            (("BXD2",), (7.80944,),
             (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944, 8.39265,
              8.17443, 8.30401, 7.80944, 7.51879, 7.77141, 7.80944),
             (None,),
             (None, None, None, None, None, None, None, None, None, None, None,
              None, None)))

    def test_find_identical_traits(self):
        """
        Test `gn3.computations.partial_correlations.find_identical_traits`.

        Given:
            - the name of a primary trait
            - the value of a primary trait
            - a sequence of names of control traits
            - a sequence of values of control traits
        When:
            - the arguments above are passed to the `find_identical_traits`
              function
        Then:
            - Return ALL trait names that have the same value when up to three
              decimal places are considered
        """
        # Cases, in order: no identical traits; the primary trait matches a
        # control trait; two control traits match each other.
        for primn, primv, contn, contv, expected in (
                ("pt", 12.98395, ("ct0", "ct1", "ct2"),
                 (0.1234, 2.3456, 3.4567), tuple()),
                ("pt", 12.98395, ("ct0", "ct1", "ct2"),
                 (12.98354, 2.3456, 3.4567), ("pt", "ct0")),
                ("pt", 12.98395, ("ct0", "ct1", "ct2", "ct3"),
                 (0.1234, 2.3456, 0.1233, 4.5678), ("ct0", "ct2"))
        ):
            with self.subTest(
                    primary_name=primn, primary_value=primv,
                    control_names=contn, control_values=contv):
                self.assertEqual(
                    find_identical_traits(primn, primv, contn, contv), expected)

    def test_tissue_correlation_error(self):
        """
        Test that `tissue_correlation` raises specific exceptions for particular
        error conditions.
        """
        # The first four cases pass value sequences of different lengths; the
        # last one passes a method name that is not supported.
        for primary, target, method, error, error_msg in (
                ((1, 2, 3), (4, 5, 6, 7), "pearson",
                 AssertionError,
                 (
                     "The lengths of the `primary_trait_values` and "
                     "`target_trait_values` must be equal")),
                ((1, 2, 3), (4, 5, 6, 7), "spearman",
                 AssertionError,
                 (
                     "The lengths of the `primary_trait_values` and "
                     "`target_trait_values` must be equal")),
                ((1, 2, 3, 4), (5, 6, 7), "pearson",
                 AssertionError,
                 (
                     "The lengths of the `primary_trait_values` and "
                     "`target_trait_values` must be equal")),
                ((1, 2, 3, 4), (5, 6, 7), "spearman",
                 AssertionError,
                 (
                     "The lengths of the `primary_trait_values` and "
                     "`target_trait_values` must be equal")),
                ((1, 2, 3), (4, 5, 6), "nonexistentmethod",
                 AssertionError,
                 (
                     "Method must be one of: pearson, spearman"))):
            with self.subTest(primary=primary, target=target, method=method):
                # NOTE(review): `msg` is only the message reported when no
                # exception is raised; the text of the raised exception is not
                # compared with `error_msg`. `assertRaisesRegex` would be
                # needed to check it.
                with self.assertRaises(error, msg=error_msg):
                    tissue_correlation(primary, target, method)

    def test_tissue_correlation(self):
        """
        Test that the correct correlation values are computed for the given:
        - primary trait
        - target trait
        - method
        """
        # NOTE(review): the expected pairs are compared for exact equality, so
        # this presumably relies on `tissue_correlation` rounding its results
        # to ten decimal places -- confirm.
        for primary, target, method, expected in (
                ((12.34, 18.36, 42.51), (37.25, 46.25, 46.56), "pearson",
                 (0.6761779253, 0.5272701134)),
                ((1, 2, 3, 4, 5), (5, 6, 7, 8, 7), "spearman",
                 (0.8207826817, 0.0885870053))):
            with self.subTest(primary=primary, target=target, method=method):
                self.assertEqual(
                    tissue_correlation(primary, target, method), expected)

    def test_good_dataset_samples_indexes(self):
        """
        Test that `good_dataset_samples_indexes` returns correct indices.
        """
        # The expected indices are the positions of "a", "e", "i" and "k" in
        # the second argument.
        self.assertEqual(
            good_dataset_samples_indexes(
                ("a", "e", "i", "k"),
                ("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l")),
            (0, 4, 8, 10))

    def test_build_data_frame(self):
        """
        Check that the function builds the correct data frame.
        """
        # First case: a flat `zdata` gives a single "z" column.
        # Second case: a nested `zdata` gives columns "z0", "z1", ..., where
        # column "z<i>" holds the i-th element of every inner tuple.
        for xdata, ydata, zdata, expected in (
                ((0.1, 1.1, 2.1), (2.1, 3.1, 4.1), (5.1, 6.1 ,7.1),
                 pandas.DataFrame({
                     "x": (0.1, 1.1, 2.1), "y": (2.1, 3.1, 4.1),
                     "z": (5.1, 6.1 ,7.1)})),
                ((0.1, 1.1, 2.1), (2.1, 3.1, 4.1),
                 ((5.1, 6.1 ,7.1), (5.2, 6.2, 7.2), (5.3, 6.3, 7.3)),
                 pandas.DataFrame({
                     "x": (0.1, 1.1, 2.1), "y": (2.1, 3.1, 4.1),
                     "z0": (5.1, 5.2 ,5.3), "z1": (6.1, 6.2 ,6.3),
                     "z2": (7.1, 7.2 ,7.3)}))):
            with self.subTest(xdata=xdata, ydata=ydata, zdata=zdata):
                self.assertTrue(
                    build_data_frame(xdata, ydata, zdata).equals(expected))

    def test_partial_correlation_matrix(self):
        """
        Test that `partial_correlation_matrix` computes the appropriate
        correlation value.
        """
        # The data file path is relative, so it is resolved against the
        # directory the tests are run from.
        for sample in parse_test_data_csv(
                ("tests/unit/computations/partial_correlations_test_data/"
                 "pcor_mat_blackbox_test.csv")):
            with self.subTest(
                    xdata=sample["x"], ydata=sample["y"], zdata=sample["z"],
                    method=sample["method"], omit_nones=sample["rm"]):
                # `sample["result"]` was already rounded to `ROUND_TO` places
                # by `parse_test_data_csv`.
                self.assertEqual(
                    partial_correlation_matrix(
                        sample["x"], sample["y"], sample["z"],
                        method=sample["method"], omit_nones=sample["rm"]),
                    sample["result"])

    def test_partial_correlation_recursive(self):
        """
        Test that `partial_correlation_recursive` computes the appropriate
        correlation value.
        """
        for sample in parse_test_data(
                ("tests/unit/computations/partial_correlations_test_data/"
                 "pcor_rec_blackbox_test.txt")):
            with self.subTest(
                    xdata=sample["x"], ydata=sample["y"], zdata=sample["z"],
                    method=sample["method"], omit_nones=sample["rm"]):
                # `parse_result` leaves the expected value unrounded, so it is
                # rounded to `ROUND_TO` places here.
                self.assertEqual(
                    partial_correlation_recursive(
                        sample["x"], sample["y"], sample["z"],
                        method=sample["method"], omit_nones=sample["rm"]),
                    round(sample["result"], ROUND_TO))