From 1a9d28e6db2140cc7b3491c6dbcf4fc8cd8c09b6 Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Tue, 17 Aug 2021 08:47:11 +0300 Subject: Add tests and fix errors caught with tests Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * gn3/computations/heatmap.py: fix errors * tests/unit/computations/test_heatmap.py: new tests Add new tests with the expected source data format, and expected results. Fix all errors that were caught by running the tests --- gn3/computations/heatmap.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index a0e778a..8a86fe8 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -34,11 +34,11 @@ def export_trait_data( """ def __export_all_types(tdata, strain): sample_data = [] - if tdata[strain]["val"]: - sample_data.append(tdata[strain]["val"]) + if tdata[strain]["value"]: + sample_data.append(tdata[strain]["value"]) if var_exists: - if tdata[strain].var: - sample_data.append(tdata[strain]["var"]) + if tdata[strain]["variance"]: + sample_data.append(tdata[strain]["variance"]) else: sample_data.append(None) if n_exists: @@ -58,15 +58,15 @@ def export_trait_data( def __exporter(accumulator, strain): # pylint: disable=[R0911] - if trait_data.has_key(strain): + if strain in trait_data["data"]: if dtype == "val": - return accumulator + (trait_data[strain]["val"], ) + return accumulator + (trait_data["data"][strain]["value"], ) if dtype == "var": - return accumulator + (trait_data[strain]["var"], ) + return accumulator + (trait_data["data"][strain]["variance"], ) if dtype == "N": - return trait_data[strain]["ndata"] + return accumulator + (trait_data["data"][strain]["ndata"], ) if dtype == "all": - return accumulator + __export_all_types(trait_data, strain) + return accumulator + __export_all_types(trait_data["data"], strain) raise KeyError("Type `%s` is incorrect" % dtype) if var_exists and n_exists: return accumulator + (None, None, None) -- cgit v1.2.3 From 41fc5136914548710529cbed7ef370dfb5b4a5c8 Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Tue, 17 Aug 2021 11:43:32 +0300 Subject: Test the clustering Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * gn3/computations/heatmap.py: Fix clustering bugs * tests/unit/computations/test_heatmap.py: Add new tests. Fix linting issues. Test and fix the clustering function. --- gn3/computations/heatmap.py | 14 ++-- tests/unit/computations/test_heatmap.py | 109 +++++++++++++++++++++++++++++--- 2 files changed, 106 insertions(+), 17 deletions(-) (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index 8a86fe8..3c35029 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -110,13 +110,13 @@ def cluster_traits(traits_data_list: Sequence[Dict]): https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L162 """ def __compute_corr(tdata_i, tdata_j): - if tdata_j[0] < tdata_i[0]: - corr_vals = compute_correlation(tdata_i, tdata_j) - corr = corr_vals[0] - if (1 - corr) < 0: - return 0.0 - return 1 - corr - return 0.0 + if tdata_i[0] == tdata_j[0]: + return 0.0 + corr_vals = compute_correlation(tdata_i[1], tdata_j[1]) + corr = corr_vals[0] + if (1 - corr) < 0: + return 0.0 + return 1 - corr def __cluster(tdata_i): return tuple( diff --git a/tests/unit/computations/test_heatmap.py b/tests/unit/computations/test_heatmap.py index 78303ae..650cb45 100644 --- a/tests/unit/computations/test_heatmap.py +++ b/tests/unit/computations/test_heatmap.py @@ -1,9 +1,38 @@ """Module contains tests for gn3.computations.heatmap""" from unittest import TestCase -from gn3.computations.heatmap import export_trait_data +from gn3.computations.heatmap import cluster_traits, export_trait_data strainlist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"] -trait_data = {"mysqlid": 36688172, "data": {"B6cC3-1": {"strain_name": "B6cC3-1", "value": 7.51879, "variance": None, "ndata": None}, "BXD1": {"strain_name": "BXD1", "value": 7.77141, "variance": None, "ndata": None}, "BXD12": {"strain_name": "BXD12", "value": 8.39265, "variance": None, "ndata": None}, "BXD16": {"strain_name": "BXD16", "value": 8.17443, "variance": None, "ndata": None}, "BXD19": {"strain_name": "BXD19", "value": 8.30401, "variance": None, "ndata": None}, "BXD2": {"strain_name": "BXD2", "value": 7.80944, "variance": None, "ndata": None}, "BXD21": {"strain_name": "BXD21", "value": 8.93809, "variance": None, "ndata": None}, "BXD24": {"strain_name": "BXD24", "value": 7.99415, "variance": None, "ndata": None}, "BXD27": {"strain_name": "BXD27", "value": 8.12177, "variance": None, "ndata": None}, "BXD28": {"strain_name": "BXD28", "value": 7.67688, "variance": None, "ndata": None}, "BXD32": {"strain_name": "BXD32", "value": 7.79062, "variance": None, "ndata": None}, "BXD39": {"strain_name": "BXD39", "value": 8.27641, "variance": None, "ndata": None}, "BXD40": {"strain_name": "BXD40", "value": 8.18012, "variance": None, "ndata": None}, "BXD42": {"strain_name": "BXD42", "value": 7.82433, "variance": None, "ndata": None}, "BXD6": {"strain_name": "BXD6", "value": 8.09718, "variance": None, "ndata": None}, "BXH14": {"strain_name": "BXH14", "value": 7.97475, "variance": None, "ndata": None}, "BXH19": {"strain_name": "BXH19", "value": 7.67223, "variance": None, "ndata": None}, "BXH2": {"strain_name": "BXH2", "value": 7.93622, "variance": None, "ndata": None}, "BXH22": {"strain_name": "BXH22", "value": 7.43692, "variance": None, "ndata": None}, "BXH4": {"strain_name": "BXH4", "value": 7.96336, "variance": None, "ndata": None}, "BXH6": {"strain_name": "BXH6", "value": 7.75132, "variance": None, "ndata": None}, "BXH7": {"strain_name": "BXH7", "value": 8.12927, "variance": None, "ndata": None}, "BXH8": {"strain_name": "BXH8", "value": 6.77338, "variance": None, "ndata": None}, "BXH9": {"strain_name": "BXH9", "value": 8.03836, "variance": None, "ndata": None}, "C3H/HeJ": {"strain_name": "C3H/HeJ", "value": 7.42795, "variance": None, "ndata": None}, "C57BL/6J": {"strain_name": "C57BL/6J", "value": 7.50606, "variance": None, "ndata": None}, "DBA/2J": {"strain_name": "DBA/2J", "value": 7.72588, "variance": None, "ndata": None}}} +trait_data = { + "mysqlid": 36688172, + "data": { + "B6cC3-1": {"strain_name": "B6cC3-1", "value": 7.51879, "variance": None, "ndata": None}, + "BXD1": {"strain_name": "BXD1", "value": 7.77141, "variance": None, "ndata": None}, + "BXD12": {"strain_name": "BXD12", "value": 8.39265, "variance": None, "ndata": None}, + "BXD16": {"strain_name": "BXD16", "value": 8.17443, "variance": None, "ndata": None}, + "BXD19": {"strain_name": "BXD19", "value": 8.30401, "variance": None, "ndata": None}, + "BXD2": {"strain_name": "BXD2", "value": 7.80944, "variance": None, "ndata": None}, + "BXD21": {"strain_name": "BXD21", "value": 8.93809, "variance": None, "ndata": None}, + "BXD24": {"strain_name": "BXD24", "value": 7.99415, "variance": None, "ndata": None}, + "BXD27": {"strain_name": "BXD27", "value": 8.12177, "variance": None, "ndata": None}, + "BXD28": {"strain_name": "BXD28", "value": 7.67688, "variance": None, "ndata": None}, + "BXD32": {"strain_name": "BXD32", "value": 7.79062, "variance": None, "ndata": None}, + "BXD39": {"strain_name": "BXD39", "value": 8.27641, "variance": None, "ndata": None}, + "BXD40": {"strain_name": "BXD40", "value": 8.18012, "variance": None, "ndata": None}, + "BXD42": {"strain_name": "BXD42", "value": 7.82433, "variance": None, "ndata": None}, + "BXD6": {"strain_name": "BXD6", "value": 8.09718, "variance": None, "ndata": None}, + "BXH14": {"strain_name": "BXH14", "value": 7.97475, "variance": None, "ndata": None}, + "BXH19": {"strain_name": "BXH19", "value": 7.67223, "variance": None, "ndata": None}, + "BXH2": {"strain_name": "BXH2", "value": 7.93622, "variance": None, "ndata": None}, + "BXH22": {"strain_name": "BXH22", "value": 7.43692, "variance": None, "ndata": None}, + "BXH4": {"strain_name": "BXH4", "value": 7.96336, "variance": None, "ndata": None}, + "BXH6": {"strain_name": "BXH6", "value": 7.75132, "variance": None, "ndata": None}, + "BXH7": {"strain_name": "BXH7", "value": 8.12927, "variance": None, "ndata": None}, + "BXH8": {"strain_name": "BXH8", "value": 6.77338, "variance": None, "ndata": None}, + "BXH9": {"strain_name": "BXH9", "value": 8.03836, "variance": None, "ndata": None}, + "C3H/HeJ": {"strain_name": "C3H/HeJ", "value": 7.42795, "variance": None, "ndata": None}, + "C57BL/6J": {"strain_name": "C57BL/6J", "value": 7.50606, "variance": None, "ndata": None}, + "DBA/2J": {"strain_name": "DBA/2J", "value": 7.72588, "variance": None, "ndata": None}}} class TestHeatmap(TestCase): """Class for testing heatmap computation functions""" @@ -29,10 +58,14 @@ class TestHeatmap(TestCase): argument and the different flags set up """ for dtype, vflag, nflag, expected in [ - ["val", False, False, (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], - ["val", False, True, (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], - ["val", True, False, (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], - ["val", True, True, (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["val", False, False, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["val", False, True, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["val", True, False, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["val", True, True, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], ["var", False, False, (None, None, None, None, None, None)], ["var", False, True, (None, None, None, None, None, None)], ["var", True, False, (None, None, None, None, None, None)], @@ -41,10 +74,17 @@ class TestHeatmap(TestCase): ["N", False, True, (None, None, None, None, None, None)], ["N", True, False, (None, None, None, None, None, None)], ["N", True, True, (None, None, None, None, None, None)], - ["all", False, False, (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], - ["all", False, True, (7.51879, None, 7.77141, None, 8.39265, None, 8.17443, None, 8.30401, None, 7.80944, None)], - ["all", True, False, (7.51879, None, 7.77141, None, 8.39265, None, 8.17443, None, 8.30401, None, 7.80944, None)], - ["all", True, True, (7.51879, None, None, 7.77141, None, None, 8.39265, None, None, 8.17443, None, None, 8.30401, None, None, 7.80944, None, None)] + ["all", False, False, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["all", False, True, + (7.51879, None, 7.77141, None, 8.39265, None, 8.17443, None, + 8.30401, None, 7.80944, None)], + ["all", True, False, + (7.51879, None, 7.77141, None, 8.39265, None, 8.17443, None, + 8.30401, None, 7.80944, None)], + ["all", True, True, + (7.51879, None, None, 7.77141, None, None, 8.39265, None, None, + 8.17443, None, None, 8.30401, None, None, 7.80944, None, None)] ]: with self.subTest(dtype=dtype, vflag=vflag, nflag=nflag): self.assertEqual( @@ -52,3 +92,52 @@ class TestHeatmap(TestCase): trait_data, strainlist, dtype=dtype, var_exists=vflag, n_exists=nflag), expected) + + def test_cluster_traits(self): + """ + Test that the clustering is working as expected. + """ + traits_data_list = [ + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944), + (6.1427, 6.50588, 7.73705, 6.68328, 7.49293, 7.27398), + (8.4211, 8.30581, 9.24076, 8.51173, 9.18455, 8.36077), + (10.0904, 10.6509, 9.36716, 9.91202, 8.57444, 10.5731), + (10.188, 9.76652, 9.54813, 9.05074, 9.52319, 9.10505), + (6.74676, 7.01029, 7.54169, 6.48574, 7.01427, 7.26815), + (6.39359, 6.85321, 5.78337, 7.11141, 6.22101, 6.16544), + (6.84118, 7.08432, 7.59844, 7.08229, 7.26774, 7.24991), + (9.45215, 10.6943, 8.64719, 10.1592, 7.75044, 8.78615), + (7.04737, 6.87185, 7.58586, 6.92456, 6.84243, 7.36913)] + self.assertEqual( + cluster_traits(traits_data_list), + ((0.0, 0.20337048635536847, 0.16381088984330505, 1.7388553629398245, + 1.5025235756329178, 0.6952839500255574, 1.271661230252733, + 0.2100487290977544, 1.4699690641062024, 0.7934461515867415), + (0.20337048635536847, 0.0, 0.2198321044997198, 1.5753041735592204, + 1.4815755944537086, 0.26087293140686374, 1.6939790104301427, + 0.06024619831474998, 1.7430082449189215, 0.4497104244247795), + (0.16381088984330505, 0.2198321044997198, 0.0, 1.9073926868549234, + 1.0396738891139845, 0.5278328671176757, 1.6275069061182947, + 0.2636503792482082, 1.739617877037615, 0.7127042590637039), + (1.7388553629398245, 1.5753041735592204, 1.9073926868549234, 0.0, + 0.9936846292920328, 1.1169999189889366, 0.6007483980555253, + 1.430209221053372, 0.25879514152086425, 0.9313185954797953), + (1.5025235756329178, 1.4815755944537086, 1.0396738891139845, + 0.9936846292920328, 0.0, 1.027827186339337, 1.1441743109173244, + 1.4122477962364253, 0.8968250491499363, 1.1683723389247052), + (0.6952839500255574, 0.26087293140686374, 0.5278328671176757, + 1.1169999189889366, 1.027827186339337, 0.0, 1.8420471110023269, + 0.19179284676938602, 1.4875072385631605, 0.23451785425383564), + (1.271661230252733, 1.6939790104301427, 1.6275069061182947, + 0.6007483980555253, 1.1441743109173244, 1.8420471110023269, 0.0, + 1.6540234785929928, 0.2140799896286565, 1.7413442197913358), + (0.2100487290977544, 0.06024619831474998, 0.2636503792482082, + 1.430209221053372, 1.4122477962364253, 0.19179284676938602, + 1.6540234785929928, 0.0, 1.5225640692832796, 0.33370067057028485), + (1.4699690641062024, 1.7430082449189215, 1.739617877037615, + 0.25879514152086425, 0.8968250491499363, 1.4875072385631605, + 0.2140799896286565, 1.5225640692832796, 0.0, 1.3256191648260216), + (0.7934461515867415, 0.4497104244247795, 0.7127042590637039, + 0.9313185954797953, 1.1683723389247052, 0.23451785425383564, + 1.7413442197913358, 0.33370067057028485, 1.3256191648260216, + 0.0))) -- cgit v1.2.3 From ded960e3d32e4d7ebe590deda27fc47175be73d9 Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Fri, 20 Aug 2021 13:21:31 +0300 Subject: Add tests for ordering and implement function Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * gn3/computations/heatmap.py: implement new ordering function * tests/unit/computations/test_heatmap.py: add new tests Implement the ordering function to migrate the setup of the `neworder` variable from GN1 to GN3. This migration is incomplete, since there is dependence on the return from the `web.webqtl.heatmap.Heatmap.draw` function in form of the `d_1` variable in some of the paths. The thing is, this `d_1` variable, and the `xoffset` variable seem to be used for laying out things on the drawn heatmap, and might actually end up not being needed for the new system using plotly, which has other ways of laying out things on the drawing. For now though, this commit "shims" the presence of these values until when the use of these variables is confirmed as present or absent in the new GN3 system. --- gn3/computations/heatmap.py | 28 ++++++++++++++++++++++++++++ tests/unit/computations/test_heatmap.py | 25 ++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 1 deletion(-) (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index 3c35029..1c86261 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -175,3 +175,31 @@ def heatmap_data(formd, search_result, conn: Any): "traits_list": traits_list, "traits_data_list": traits_data_list } + +def compute_heatmap_order( + slink_data, xoffset: int = 40, neworder: tuple = tuple()): + """ + Compute the data used for drawing the heatmap proper from `slink_data`. + + This function tries to reproduce the creation and update of the `neworder` + variable in + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L120 + and in the `web.webqtl.heatmap.Heatmap.draw` function in GN1 + """ + d_1 = (0, 0, 0) # returned from self.draw in lines 391 and 399. This is just a placeholder + + def __order_maker(norder, slnk_dt): + print("norder:{}, slnk_dt:{}".format(norder, slnk_dt)) + if isinstance(slnk_dt[0], int) and isinstance(slnk_dt[1], int): + return norder + ( + (xoffset+20, slnk_dt[0]), (xoffset + 40, slnk_dt[1])) + + if isinstance(slnk_dt[0], int): + return norder + ((xoffset + 20, slnk_dt[0]), ) + + if isinstance(slnk_dt[1], int): + return norder + ((xoffset + d_1[0] + 20, slnk_dt[1]), ) + + return __order_maker(__order_maker(norder, slnk_dt[0]), slnk_dt[1]) + + return __order_maker(neworder, slink_data) diff --git a/tests/unit/computations/test_heatmap.py b/tests/unit/computations/test_heatmap.py index 650cb45..14807bb 100644 --- a/tests/unit/computations/test_heatmap.py +++ b/tests/unit/computations/test_heatmap.py @@ -1,6 +1,9 @@ """Module contains tests for gn3.computations.heatmap""" from unittest import TestCase -from gn3.computations.heatmap import cluster_traits, export_trait_data +from gn3.computations.heatmap import ( + cluster_traits, + export_trait_data, + compute_heatmap_order) strainlist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"] trait_data = { @@ -34,6 +37,16 @@ trait_data = { "C57BL/6J": {"strain_name": "C57BL/6J", "value": 7.50606, "variance": None, "ndata": None}, "DBA/2J": {"strain_name": "DBA/2J", "value": 7.72588, "variance": None, "ndata": None}}} +slinked = ( + (((0, 2, 0.16381088984330505), + ((1, 7, 0.06024619831474998), 5, 0.19179284676938602), + 0.20337048635536847), + 9, + 0.23451785425383564), + ((3, (6, 8, 0.2140799896286565), 0.25879514152086425), + 4, 0.8968250491499363), + 0.9313185954797953) + class TestHeatmap(TestCase): """Class for testing heatmap computation functions""" @@ -141,3 +154,13 @@ class TestHeatmap(TestCase): 0.9313185954797953, 1.1683723389247052, 0.23451785425383564, 1.7413442197913358, 0.33370067057028485, 1.3256191648260216, 0.0))) + + def test_compute_heatmap_order(self): + """Test the orders.""" + for xoff, expected in [ + (40, ((60, 9), (60, 4))), + (30, ((50, 9), (50, 4))), + (20, ((40, 9), (40, 4)))]: + with self.subTest(xoffset=xoff): + self.assertEqual( + compute_heatmap_order(slinked, xoffset=xoff), expected) -- cgit v1.2.3 From 8b2c776771d2a70613a1e31d6e6671b612cfbafc Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Fri, 20 Aug 2021 14:10:45 +0300 Subject: Retrieve the strains with valid values Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * gn3/computations/heatmap.py: add function to get strains with values * tests/unit/computations/test_heatmap.py: new tests Add function to get the strains whose values are not `None` from the `trait_data` object passed in. This migrates https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L215-221 into a separate function that can handle that and be tested independently of any other code. --- gn3/computations/heatmap.py | 19 +++++++++++++++++++ tests/unit/computations/test_heatmap.py | 14 +++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index 1c86261..5a3c619 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -203,3 +203,22 @@ def compute_heatmap_order( return __order_maker(__order_maker(norder, slnk_dt[0]), slnk_dt[1]) return __order_maker(neworder, slink_data) + +def retrieve_strains_and_values(strainlist, trait_data): + """ + Get the strains and their corresponding values from `strainlist` and + `trait_data`. + + This migrates the code in + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L215-221 + """ + def __strains_and_values(acc, i): + if trait_data[i] is None: + return acc + if len(acc) == 0: + return ((strainlist[i], ), (trait_data[i], )) + _strains = acc[0] + _vals = acc[1] + return (_strains + (strainlist[i], ), _vals + (trait_data[i], )) + return reduce( + __strains_and_values, range(len(strainlist)), (tuple(), tuple())) diff --git a/tests/unit/computations/test_heatmap.py b/tests/unit/computations/test_heatmap.py index 14807bb..686288d 100644 --- a/tests/unit/computations/test_heatmap.py +++ b/tests/unit/computations/test_heatmap.py @@ -3,7 +3,8 @@ from unittest import TestCase from gn3.computations.heatmap import ( cluster_traits, export_trait_data, - compute_heatmap_order) + compute_heatmap_order, + retrieve_strains_and_values) strainlist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"] trait_data = { @@ -164,3 +165,14 @@ class TestHeatmap(TestCase): with self.subTest(xoffset=xoff): self.assertEqual( compute_heatmap_order(slinked, xoffset=xoff), expected) + + def test_retrieve_strains_and_values(self): + """Test retrieval of strains and values.""" + for slist, tdata, expected in [ + [["s1", "s2", "s3", "s4"], [9, None, 5, 4], + (("s1", "s3", "s4"), (9, 5, 4))], + [["s1", "s2", "s3", "s4", "s5"], [6, None, None, 4, None], + (("s1", "s4"), (6, 4))]]: + with self.subTest(strainlist=slist, traitdata=tdata): + self.assertEqual( + retrieve_strains_and_values(slist, tdata), expected) -- cgit v1.2.3 From 96af4e9e32ed167a8d70cf7761b709b1a37bb344 Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Fri, 20 Aug 2021 14:14:12 +0300 Subject: Fix typing issue(s) caught by mypy Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * gn3/computations/heatmap.py: Use `Sequence` type not `Iterator` type --- gn3/computations/heatmap.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index 5a3c619..c9c2b8a 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -156,8 +156,8 @@ def heatmap_data(formd, search_result, conn: Any): traits_details = [ __retrieve_traitlist_and_datalist(threshold, fullname) for fullname in search_result] - traits_list = map(lambda x: x[0], traits_details) - traits_data_list = map(lambda x: x[1], traits_details) + traits_list = tuple(x[0] for x in traits_details) + traits_data_list = tuple(x[1] for x in traits_details) return { "target_description_checked": formd.formdata.getvalue( -- cgit v1.2.3 From 557e482c88ba3d44ae7d278b7222f37fa043b4d0 Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Fri, 27 Aug 2021 15:47:52 +0300 Subject: Rework strains and trait values retrieval Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Rework the strains and values retrieval function to more closely correspond to the working of the original code in GN1 --- gn3/computations/heatmap.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index c9c2b8a..da13ceb 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -204,21 +204,28 @@ def compute_heatmap_order( return __order_maker(neworder, slink_data) -def retrieve_strains_and_values(strainlist, trait_data): +def retrieve_strains_and_values(orders, strainlist, traits_data_list): """ Get the strains and their corresponding values from `strainlist` and - `trait_data`. + `traits_data_list`. This migrates the code in https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L215-221 """ - def __strains_and_values(acc, i): - if trait_data[i] is None: - return acc - if len(acc) == 0: - return ((strainlist[i], ), (trait_data[i], )) - _strains = acc[0] - _vals = acc[1] - return (_strains + (strainlist[i], ), _vals + (trait_data[i], )) - return reduce( - __strains_and_values, range(len(strainlist)), (tuple(), tuple())) + # This feels nasty! There's a lot of mutation of values here, that might + # indicate something untoward in the design of this function and its + # dependents ==> Review + strains = [] + values = [] + rets = [] + for order in orders: + temp_val = traits_data_list[order[1]] + for i in range(len(strainlist)): + if temp_val[i] != None: + strains.append(strainlist[i]) + values.append(temp_val[i]) + rets.append([order, strains[:], values[:]]) + strains = [] + values = [] + + return rets -- cgit v1.2.3 From 1a3901b174d00af8fa7f5ae78b810de66024b5ab Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Fri, 27 Aug 2021 15:49:53 +0300 Subject: Export trait data to file Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Provide a function to export the given strains and traits data into a traits file for use with `rust-qtlreaper`. --- gn3/computations/heatmap.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index da13ceb..2f92048 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -229,3 +229,11 @@ def retrieve_strains_and_values(orders, strainlist, traits_data_list): values = [] return rets + +def generate_traits_file(strains, trait_values, traits_filename): + header = "Traits\t{}\n".format("\t".join(strains)) + data = [header] + [ + "T{}\t{}\n".format(i+1, "\t".join([str(i) for i in t])) + for i,t in enumerate(trait_values)] + with open(traits_filename, "w") as outfile: + outfile.writelines(data) -- cgit v1.2.3 From 28fde00ee2835d404157652548a4265be3accede Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Fri, 27 Aug 2021 15:51:27 +0300 Subject: Provide intermediate data in final results Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Seeing as not every requirement/feature has been migrated over at this time, this commit just provides all the intermediate data representations in the final return of the function for later use down the line. --- gn3/computations/heatmap.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index 2f92048..3e96ed2 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -149,22 +149,22 @@ def heatmap_data(formd, search_result, conn: Any): def __retrieve_traitlist_and_datalist(threshold, fullname): trait = retrieve_trait_info(threshold, fullname, conn) - return ( - trait, - export_trait_data(retrieve_trait_data(trait, conn), strainlist)) + return (trait, retrieve_trait_data(trait, conn)) traits_details = [ __retrieve_traitlist_and_datalist(threshold, fullname) for fullname in search_result] traits_list = tuple(x[0] for x in traits_details) - traits_data_list = tuple(x[1] for x in traits_details) + traits_data_list = [x[1] for x in traits_details] + exported_traits_data_list = tuple( + export_trait_data(td, strainlist) for x in traits_data_list) return { "target_description_checked": formd.formdata.getvalue( "targetDescriptionCheck", ""), "cluster_checked": cluster_checked, "slink_data": ( - slink(cluster_traits(traits_data_list)) + slink(cluster_traits(exported_traits_data_list)) if cluster_checked else False), "sessionfile": formd.formdata.getvalue("session"), "genotype": genotype, @@ -173,7 +173,8 @@ def heatmap_data(formd, search_result, conn: Any): "ppolar": formd.ppolar, "mpolar":formd.mpolar, "traits_list": traits_list, - "traits_data_list": traits_data_list + "traits_data_list": traits_data_list, + "exported_traits_data_list": exported_traits_data_list } def compute_heatmap_order( -- cgit v1.2.3 From 983acfdfc523677b4d7501287a000b7fd52a2c39 Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Mon, 30 Aug 2021 07:00:38 +0300 Subject: Implement module for interfacing with rust-qtlreaper Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * gn3/computations/heatmap.py: move `generate_traits_file` function to new module * gn3/computations/qtlreaper.py: new module to interface with the `rust-qtlreaper` utility. * gn3/settings.py: Provide setting for the path to the `rust-qtlreaper` utility * qtlfilesexport.py: Move `random_string` function to new module. Update to use functions in new module. Provide a module with functions to be used to interface with `rust-qtlreaper`. This module essentially contains all the functions that are needed to build the files needed for, and to run the qtlreaper utility. --- gn3/computations/heatmap.py | 8 ---- gn3/computations/qtlreaper.py | 88 +++++++++++++++++++++++++++++++++++++++++++ gn3/settings.py | 3 ++ qtlfilesexport.py | 10 +---- 4 files changed, 92 insertions(+), 17 deletions(-) create mode 100644 gn3/computations/qtlreaper.py (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index 3e96ed2..dcd64b1 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -230,11 +230,3 @@ def retrieve_strains_and_values(orders, strainlist, traits_data_list): values = [] return rets - -def generate_traits_file(strains, trait_values, traits_filename): - header = "Traits\t{}\n".format("\t".join(strains)) - data = [header] + [ - "T{}\t{}\n".format(i+1, "\t".join([str(i) for i in t])) - for i,t in enumerate(trait_values)] - with open(traits_filename, "w") as outfile: - outfile.writelines(data) diff --git a/gn3/computations/qtlreaper.py b/gn3/computations/qtlreaper.py new file mode 100644 index 0000000..49d363b --- /dev/null +++ b/gn3/computations/qtlreaper.py @@ -0,0 +1,88 @@ +""" +This module contains functions to interact with the `qtlreaper` utility for +computation of QTLs. +""" +import os +import random +import string +import subprocess +from gn3.settings import TMPDIR, REAPER_COMMAND + +def random_string(length): + """Generate a random string of length `length`.""" + return "".join( + random.choices( + string.ascii_letters + string.digits, k=length)) + +def generate_traits_file(strains, trait_values, traits_filename): + """ + Generate a traits file for use with `qtlreaper`. + + PARAMETERS: + strains: A list of strains to use as the headers for the various columns. + trait_values: A list of lists of values for each trait and strain. + traits_filename: The tab-separated value to put the values in for + computation of QTLs. + """ + header = "Traits\t{}\n".format("\t".join(strains)) + data = [header] + [ + "T{}\t{}\n".format(i+1, "\t".join([str(i) for i in t])) + for i, t in enumerate(trait_values)] + with open(traits_filename, "w") as outfile: + outfile.writelines(data) + +def create_output_directory(path: str): + """Create the output directory at `path` if it does not exist.""" + try: + os.mkdir(path) + except OSError: + pass + +def run_reaper( + genotype_filename: str, traits_filename: str, + other_options: tuple = ("--n_permutations", 1000), + separate_nperm_output: bool = False, + output_dir: str = TMPDIR): + """ + Run the QTLReaper command to compute the QTLs. + + PARAMETERS: + genotype_filename: The complete path to a genotype file to use in the QTL + computation. + traits_filename: A path to a file previously generated with the + `generate_traits_file` function in this module, to be used in the QTL + computation. + other_options: Other options to pass to the `qtlreaper` command to modify + the QTL computations. + separate_nperm_output: A flag indicating whether or not to provide a + separate output for the permutations computation. The default is False, + which means by default, no separate output file is created. + output_dir: A path to the directory where the outputs are put + + RETURNS: + The function returns a tuple of the main output file, and the output file + for the permutation computations. If the `separate_nperm_output` is `False`, + the second value in the tuple returned is `None`. + + RAISES: + The function will raise a `subprocess.CalledProcessError` exception in case + of any errors running the `qtlreaper` command. + """ + create_output_directory(output_dir) + output_filename = "{}/qtlreaper/main_output_{}.txt".format( + output_dir, random_string(10)) + output_list = ["--main_output", output_filename] + if separate_nperm_output: + permu_output_filename = "{}/qtlreaper/permu_output_{}.txt".format( + output_dir, random_string(10)) + output_list = output_list + ["--permu_output", permu_output_filename] + else: + permu_output_filename = None + + command_list = [ + REAPER_COMMAND, "--geno", genotype_filename, + *other_options, # this splices the `other_options` list here + "--traits", traits_filename, "--main_output", output_filename] + + subprocess.run(command_list, check=True) + return (output_filename, permu_output_filename) diff --git a/gn3/settings.py b/gn3/settings.py index f4866d5..d137370 100644 --- a/gn3/settings.py +++ b/gn3/settings.py @@ -24,3 +24,6 @@ GN2_BASE_URL = "http://www.genenetwork.org/" # biweight script BIWEIGHT_RSCRIPT = "~/genenetwork3/scripts/calculate_biweight.R" + +# qtlreaper command +REAPER_COMMAND = "{}/bin/qtlreaper".format(os.environ.get("GUIX_ENVIRONMENT")) diff --git a/qtlfilesexport.py b/qtlfilesexport.py index 2e7c9c2..0543dc9 100644 --- a/qtlfilesexport.py +++ b/qtlfilesexport.py @@ -7,16 +7,14 @@ Run with: replacing the variables in the angled brackets with the appropriate values """ -import random -import string from gn3.computations.slink import slink from gn3.db_utils import database_connector from gn3.computations.heatmap import export_trait_data from gn3.db.traits import retrieve_trait_data, retrieve_trait_info +from gn3.computations.qtlreaper import random_string, generate_traits_file from gn3.computations.heatmap import ( cluster_traits, compute_heatmap_order, - generate_traits_file, retrieve_strains_and_values) TMPDIR = "tmp/qtltests" @@ -35,11 +33,6 @@ def trait_fullnames(): "UCLA_BXDBXH_CARTILAGE_V2::ILM4200064", "UCLA_BXDBXH_CARTILAGE_V2::ILM3140463"] -def random_string(length): - return "".join( - random.choices( - string.ascii_letters + string.digits, k=length)) - def main(): """entrypoint function""" conn = database_connector()[0] @@ -56,7 +49,6 @@ def main(): strains_and_values = retrieve_strains_and_values( orders, strains, exported_traits_data_list) strains_values = strains_and_values[0][1] - strains_values2 = strains_and_values[1][1] trait_values = [t[2] for t in strains_and_values] traits_filename = "{}/traits_test_file_{}.txt".format( TMPDIR, random_string(10)) -- cgit v1.2.3 From b95ad3bd2ce8bc22d1dcadefdf76c43f28309984 Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Mon, 30 Aug 2021 07:05:49 +0300 Subject: Fix some linting errors and minor bugs. --- gn3/computations/heatmap.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index dcd64b1..e0ff05b 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -157,7 +157,7 @@ def heatmap_data(formd, search_result, conn: Any): traits_list = tuple(x[0] for x in traits_details) traits_data_list = [x[1] for x in traits_details] exported_traits_data_list = tuple( - export_trait_data(td, strainlist) for x in traits_data_list) + export_trait_data(td, strainlist) for td in traits_data_list) return { "target_description_checked": formd.formdata.getvalue( @@ -190,7 +190,6 @@ def compute_heatmap_order( d_1 = (0, 0, 0) # returned from self.draw in lines 391 and 399. This is just a placeholder def __order_maker(norder, slnk_dt): - print("norder:{}, slnk_dt:{}".format(norder, slnk_dt)) if isinstance(slnk_dt[0], int) and isinstance(slnk_dt[1], int): return norder + ( (xoffset+20, slnk_dt[0]), (xoffset + 40, slnk_dt[1])) @@ -221,9 +220,9 @@ def retrieve_strains_and_values(orders, strainlist, traits_data_list): rets = [] for order in orders: temp_val = traits_data_list[order[1]] - for i in range(len(strainlist)): - if temp_val[i] != None: - strains.append(strainlist[i]) + for i, strain in enumerate(strainlist): + if temp_val[i] is not None: + strains.append(strain) values.append(temp_val[i]) rets.append([order, strains[:], values[:]]) strains = [] -- cgit v1.2.3 From e441509a59c20a051fd5ab94710513f1968a5e02 Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Tue, 31 Aug 2021 10:50:56 +0300 Subject: Update `heatmap_data` function: remove extraneous data Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * gn3/computations/heatmap.py: update function * gn3/db/traits.py: new function Remove extraneous data and arguments from the function. - Load the genotype file - Generate traits file - Provide both raw traits data, and exported traits data in return --- gn3/computations/heatmap.py | 42 ++++++++++++++++++++++-------------------- gn3/db/traits.py | 5 +++++ 2 files changed, 27 insertions(+), 20 deletions(-) (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index e0ff05b..92014cf 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -6,8 +6,12 @@ generate various kinds of heatmaps. from functools import reduce from typing import Any, Dict, Sequence from gn3.computations.slink import slink -from gn3.db.traits import retrieve_trait_data, retrieve_trait_info from gn3.computations.correlations2 import compute_correlation +from gn3.db.genotypes import build_genotype_file, load_genotype_samples +from gn3.db.traits import ( + retrieve_trait_data, + retrieve_trait_info, + generate_traits_filename) def export_trait_data( trait_data: dict, strainlist: Sequence[str], dtype: str = "val", @@ -125,7 +129,7 @@ def cluster_traits(traits_data_list: Sequence[Dict]): return tuple(__cluster(tdata_i) for tdata_i in enumerate(traits_data_list)) -def heatmap_data(formd, search_result, conn: Any): +def heatmap_data(traits_names, conn: Any): """ heatmap function @@ -142,39 +146,37 @@ def heatmap_data(formd, search_result, conn: Any): TODO: Elaborate on the parameters here... """ threshold = 0 # webqtlConfig.PUBLICTHRESH - cluster_checked = formd.formdata.getvalue("clusterCheck", "") - strainlist = [ - strain for strain in formd.strainlist if strain not in formd.parlist] - genotype = formd.genotype - def __retrieve_traitlist_and_datalist(threshold, fullname): trait = retrieve_trait_info(threshold, fullname, conn) return (trait, retrieve_trait_data(trait, conn)) traits_details = [ __retrieve_traitlist_and_datalist(threshold, fullname) - for fullname in search_result] + for fullname in traits_names] traits_list = tuple(x[0] for x in traits_details) traits_data_list = [x[1] for x in traits_details] exported_traits_data_list = tuple( export_trait_data(td, strainlist) for td in traits_data_list) + genotype_filename = build_genotype_file(traits_list[0]["riset"]) + strainlist = load_genotype_samples(genotype_filename) + slink_data = slink(cluster_traits(exported_traits_data_list)) + ordering_data = compute_heatmap_order(slink_data) + strains_and_values = retrieve_strains_and_values( + orders, strainlist, exported_traits_data_list) + strains_values = strains_and_values[0][1] + trait_values = [t[2] for t in strains_and_values] + traits_filename = generate_traits_filename() + generate_traits_file(strains_values, trait_values, traits_filename) return { - "target_description_checked": formd.formdata.getvalue( - "targetDescriptionCheck", ""), - "cluster_checked": cluster_checked, - "slink_data": ( - slink(cluster_traits(exported_traits_data_list)) - if cluster_checked else False), - "sessionfile": formd.formdata.getvalue("session"), - "genotype": genotype, - "nLoci": sum(map(len, genotype)), + "slink_data": slink_data, + "ordering_data": ordering_data, "strainlist": strainlist, - "ppolar": formd.ppolar, - "mpolar":formd.mpolar, + "genotype_filename": genotype_filename, "traits_list": traits_list, "traits_data_list": traits_data_list, - "exported_traits_data_list": exported_traits_data_list + "exported_traits_data_list": exported_traits_data_list, + "traits_filename": traits_filename } def compute_heatmap_order( diff --git a/gn3/db/traits.py b/gn3/db/traits.py index 1031e44..ccb101a 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -1,4 +1,5 @@ """This class contains functions relating to trait data manipulation""" +from gn3.settings import TMPDIR from typing import Any, Dict, Union, Sequence from gn3.function_helpers import compose from gn3.db.datasets import retrieve_trait_dataset @@ -666,3 +667,7 @@ def retrieve_trait_data(trait: dict, conn: Any, strainlist: Sequence[str] = tupl {k:v for k, v in x.items() if x != "strain_name"}), data))} return {} + +def generate_traits_filename(base_path: str = TMPDIR): + return "{}/traits_test_file_{}.txt".format( + os.path.abspath(base_path), random_string(10)) -- cgit v1.2.3 From b5e1d1176f1bf4f7c0b68b27beb15e99418f1650 Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Tue, 31 Aug 2021 11:16:29 +0300 Subject: Fix linting errors, minor bugs and reorganise code * Fix some linting errors and some minor bugs caught by the linter. Move the `random_string` function to separate module for use in multiple places in the code. --- gn3/computations/heatmap.py | 7 ++++--- gn3/computations/qtlreaper.py | 27 ++++++++++++++------------- gn3/db/traits.py | 5 ++++- gn3/heatmaps/heatmaps.py | 25 +++++++++++++++++++------ gn3/random.py | 11 +++++++++++ tests/unit/computations/test_qtlreaper.py | 5 +++-- 6 files changed, 55 insertions(+), 25 deletions(-) create mode 100644 gn3/random.py (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index 92014cf..1143450 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -6,6 +6,7 @@ generate various kinds of heatmaps. from functools import reduce from typing import Any, Dict, Sequence from gn3.computations.slink import slink +from gn3.computations.qtlreaper import generate_traits_file from gn3.computations.correlations2 import compute_correlation from gn3.db.genotypes import build_genotype_file, load_genotype_samples from gn3.db.traits import ( @@ -155,14 +156,14 @@ def heatmap_data(traits_names, conn: Any): for fullname in traits_names] traits_list = tuple(x[0] for x in traits_details) traits_data_list = [x[1] for x in traits_details] - exported_traits_data_list = tuple( - export_trait_data(td, strainlist) for td in traits_data_list) genotype_filename = build_genotype_file(traits_list[0]["riset"]) strainlist = load_genotype_samples(genotype_filename) + exported_traits_data_list = tuple( + export_trait_data(td, strainlist) for td in traits_data_list) slink_data = slink(cluster_traits(exported_traits_data_list)) ordering_data = compute_heatmap_order(slink_data) strains_and_values = retrieve_strains_and_values( - orders, strainlist, exported_traits_data_list) + ordering_data, strainlist, exported_traits_data_list) strains_values = strains_and_values[0][1] trait_values = [t[2] for t in strains_and_values] traits_filename = generate_traits_filename() diff --git a/gn3/computations/qtlreaper.py b/gn3/computations/qtlreaper.py index 3b8e4db..30c7051 100644 --- a/gn3/computations/qtlreaper.py +++ b/gn3/computations/qtlreaper.py @@ -3,17 +3,10 @@ This module contains functions to interact with the `qtlreaper` utility for computation of QTLs. """ import os -import random -import string import subprocess +from gn3.random import random_string from gn3.settings import TMPDIR, REAPER_COMMAND -def random_string(length): - """Generate a random string of length `length`.""" - return "".join( - random.choices( - string.ascii_letters + string.digits, k=length)) - def generate_traits_file(strains, trait_values, traits_filename): """ Generate a traits file for use with `qtlreaper`. @@ -25,11 +18,13 @@ def generate_traits_file(strains, trait_values, traits_filename): computation of QTLs. """ header = "Trait\t{}\n".format("\t".join(strains)) - data = [header] + [ - "T{}\t{}\n".format(i+1, "\t".join([str(i) for i in t])) - for i, t in enumerate(trait_values[:-1])] + [ - "T{}\t{}".format(len(trait_values), "\t".join([str(i) for i in t])) - for t in trait_values[-1:]] + data = ( + [header] + + ["T{}\t{}\n".format(i+1, "\t".join([str(i) for i in t])) + for i, t in enumerate(trait_values[:-1])] + + ["T{}\t{}".format( + len(trait_values), "\t".join([str(i) for i in t])) + for t in trait_values[-1:]]) with open(traits_filename, "w") as outfile: outfile.writelines(data) @@ -93,6 +88,9 @@ def run_reaper( def parse_reaper_main_results(results_file): + """ + Parse the results file of running QTLReaper into a list of dicts. + """ with open(results_file, "r") as infile: lines = infile.readlines() @@ -104,6 +102,9 @@ def parse_reaper_main_results(results_file): return [dict(zip(header, __parse_line(line))) for line in lines[1:]] def parse_reaper_permutation_results(results_file): + """ + Parse the results QTLReaper permutations into a list of values. + """ with open(results_file, "r") as infile: lines = infile.readlines() diff --git a/gn3/db/traits.py b/gn3/db/traits.py index ccb101a..bfe887e 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -1,6 +1,8 @@ """This class contains functions relating to trait data manipulation""" -from gn3.settings import TMPDIR +import os from typing import Any, Dict, Union, Sequence +from gn3.settings import TMPDIR +from gn3.random import random_string from gn3.function_helpers import compose from gn3.db.datasets import retrieve_trait_dataset @@ -669,5 +671,6 @@ def retrieve_trait_data(trait: dict, conn: Any, strainlist: Sequence[str] = tupl return {} def generate_traits_filename(base_path: str = TMPDIR): + """Generate a unique filename for use with generated traits files.""" return "{}/traits_test_file_{}.txt".format( os.path.abspath(base_path), random_string(10)) diff --git a/gn3/heatmaps/heatmaps.py b/gn3/heatmaps/heatmaps.py index 3bf7917..88f546d 100644 --- a/gn3/heatmaps/heatmaps.py +++ b/gn3/heatmaps/heatmaps.py @@ -14,6 +14,19 @@ def generate_random_data(data_stop: float = 2, width: int = 10, height: int = 30 return [[random.uniform(0,data_stop) for i in range(0, width)] for j in range(0, height)] +def generate_random_data2(data_stop: float = 2, width: int = 10, height: int = 30): + """ + This is mostly a utility function to be used to generate random data, useful + for development of the heatmap generation code, without access to the actual + database data. + """ + return [ + [{ + "value": item, + "category": random.choice(["C57BL/6J +", "DBA/2J +"])} + for item in axis] + for axis in generate_random_data(data_stop, width, height)] + def heatmap_x_axis_names(): return [ "UCLA_BXDBXH_CARTILAGE_V2::ILM103710672", @@ -30,13 +43,14 @@ def heatmap_x_axis_names(): # Grey + Blue + Red def generate_heatmap(): - rows = 20 - data = generate_random_data(height=rows) - y = (["%s"%x for x in range(1, rows+1)][:-1] + ["X"]) #replace last item with x for now + cols = 20 + y_axis = (["%s"%x for x in range(1, cols+1)][:-1] + ["X"]) #replace last item with x for now + x_axis = heatmap_x_axis_names() + data = generate_random_data(height=cols, width=len(x_axis)) fig = px.imshow( data, - x=heatmap_x_axis_names(), - y=y, + x=x_axis, + y=y_axis, width=500) fig.update_traces(xtype="array") fig.update_traces(ytype="array") @@ -49,6 +63,5 @@ def generate_heatmap(): coloraxis_colorscale=[ [0.0, '#3B3B3B'], [0.4999999999999999, '#ABABAB'], [0.5, '#F5DE11'], [1.0, '#FF0D00']]) - fig.write_html("%s/%s"%(heatmap_dir, "test_image.html")) return fig diff --git a/gn3/random.py b/gn3/random.py new file mode 100644 index 0000000..f0ba574 --- /dev/null +++ b/gn3/random.py @@ -0,0 +1,11 @@ +""" +Functions to generate complex random data. +""" +import random +import string + +def random_string(length): + """Generate a random string of length `length`.""" + return "".join( + random.choices( + string.ascii_letters + string.digits, k=length)) diff --git a/tests/unit/computations/test_qtlreaper.py b/tests/unit/computations/test_qtlreaper.py index ec23664..6c3b64d 100644 --- a/tests/unit/computations/test_qtlreaper.py +++ b/tests/unit/computations/test_qtlreaper.py @@ -1,5 +1,4 @@ """Module contains tests for gn3.computations.qtlreaper""" -import os from unittest import TestCase from gn3.computations.qtlreaper import ( parse_reaper_main_results, parse_reaper_permutation_results) @@ -8,6 +7,7 @@ class TestQTLReaper(TestCase): """Class for testing qtlreaper interface functions.""" def test_parse_reaper_main_results(self): + """Test that the main results file is parsed correctly.""" self.assertEqual( parse_reaper_main_results( "tests/unit/computations/data/qtlreaper/main_output_sample.txt"), @@ -65,9 +65,10 @@ class TestQTLReaper(TestCase): ]) def test_parse_reaper_permutation_results(self): + """Test that the permutations results file is parsed correctly.""" self.assertEqual( parse_reaper_permutation_results( - "tests/unit/computations/data/qtlreaper/permu_output_sample.txt"), + "tests/unit/computations/data/qtlreaper/permu_output_sample.txt"), [4.44174, 5.03825, 5.08167, 5.18119, 5.18578, 5.24563, 5.24619, 5.24619, 5.27961, 5.28228, 5.43903, 5.50188, 5.51694, 5.56830, 5.63874, 5.71346, 5.71936, 5.74275, 5.76764, 5.79815, 5.81671, -- cgit v1.2.3 From 608ff9c6ff668d18f0c42aebf658ef80b517a6de Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Mon, 6 Sep 2021 06:45:18 +0300 Subject: Find nearest marker Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Migrate the `web.webqtl.heatmap.Heatmap.getNearestMarker` function in GN1 to GN3. --- gn3/computations/heatmap.py | 49 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index 1143450..ccce385 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -30,7 +30,7 @@ def export_trait_data( The dictionary of key-value pairs representing a trait strainlist: (list) A list of strain names - type: (str) + dtype: (str) ... verify what this is ... var_exists: (bool) A flag indicating existence of variance @@ -232,3 +232,50 @@ def retrieve_strains_and_values(orders, strainlist, traits_data_list): values = [] return rets + +def nearest_marker_finder(genotype): + """ + Returns a function to be used with `genotype` to compute the nearest marker + to the trait passed to the returned function. + + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L425-434 + """ + def __compute_distances(chromo, trait): + loci = chromo.get("loci", None) + if not loci: + return None + return tuple( + { + "name": locus["name"], + "distance": abs(locus["Mb"] - trait["mb"]) + } for locus in loci) + + def __finder(trait): + _chrs = tuple( + _chr for _chr in genotype["chromosomes"] + if str(_chr["name"]) == str(trait["chr"])) + if len(_chrs) == 0: + return None + distances = tuple( + distance for dists in + filter( + lambda x: x is not None, + (__compute_distances(_chr, trait) for _chr in _chrs)) + for distance in dists) + nearest = min(distances, key=lambda d: d["distance"]) + return nearest["name"] + return __finder + +def get_nearest_marker(traits_list, genotype): + """ + Retrieves the nearest marker for each of the traits in the list. + + DESCRIPTION: + This migrates the code in + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L419-L438 + """ + if not genotype["Mbmap"]: + return [None] * len(trait_list) + + marker_finder = nearest_marker_finder(genotype) + return [marker_finder(trait) for trait in traits_list] -- cgit v1.2.3 From 31ca02d1f095c2cc667e5b7d49131d702982f321 Mon Sep 17 00:00:00 2001 From: Muriithi Frederick Muriuki Date: Wed, 8 Sep 2021 06:52:01 +0300 Subject: Fix the traits order computations for clustering Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * gn3/computations/heatmap.py: Fix ordering function * tests/unit/computations/test_heatmap.py: update test The order of the traits is important for the clustering algorithm, since the clustering seems to use the distance of one trait from another to determine how to order them. This commit also gets rid of the xoffset argument that is not important to the ordering, and was used in the older GN1 to determine how to draw the clustering lines. --- gn3/computations/heatmap.py | 16 ++++++---------- tests/unit/computations/test_heatmap.py | 11 +++-------- 2 files changed, 9 insertions(+), 18 deletions(-) (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py index ccce385..8727c92 100644 --- a/gn3/computations/heatmap.py +++ b/gn3/computations/heatmap.py @@ -180,28 +180,24 @@ def heatmap_data(traits_names, conn: Any): "traits_filename": traits_filename } -def compute_heatmap_order( - slink_data, xoffset: int = 40, neworder: tuple = tuple()): +def compute_traits_order(slink_data, neworder: tuple = tuple()): """ - Compute the data used for drawing the heatmap proper from `slink_data`. + Compute the order of the traits for clustering from `slink_data`. This function tries to reproduce the creation and update of the `neworder` variable in https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L120 and in the `web.webqtl.heatmap.Heatmap.draw` function in GN1 """ - d_1 = (0, 0, 0) # returned from self.draw in lines 391 and 399. This is just a placeholder - def __order_maker(norder, slnk_dt): if isinstance(slnk_dt[0], int) and isinstance(slnk_dt[1], int): - return norder + ( - (xoffset+20, slnk_dt[0]), (xoffset + 40, slnk_dt[1])) + return norder + (slnk_dt[0], slnk_dt[1]) if isinstance(slnk_dt[0], int): - return norder + ((xoffset + 20, slnk_dt[0]), ) + return __order_maker((norder + (slnk_dt[0], )), slnk_dt[1]) if isinstance(slnk_dt[1], int): - return norder + ((xoffset + d_1[0] + 20, slnk_dt[1]), ) + return __order_maker(norder, slnk_dt[0]) + (slnk_dt[1], ) return __order_maker(__order_maker(norder, slnk_dt[0]), slnk_dt[1]) @@ -222,7 +218,7 @@ def retrieve_strains_and_values(orders, strainlist, traits_data_list): values = [] rets = [] for order in orders: - temp_val = traits_data_list[order[1]] + temp_val = traits_data_list[order] for i, strain in enumerate(strainlist): if temp_val[i] is not None: strains.append(strain) diff --git a/tests/unit/computations/test_heatmap.py b/tests/unit/computations/test_heatmap.py index 87f8e45..f1bbefc 100644 --- a/tests/unit/computations/test_heatmap.py +++ b/tests/unit/computations/test_heatmap.py @@ -3,7 +3,7 @@ from unittest import TestCase from gn3.computations.heatmap import ( cluster_traits, export_trait_data, - compute_heatmap_order, + compute_traits_order, retrieve_strains_and_values) strainlist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"] @@ -158,13 +158,8 @@ class TestHeatmap(TestCase): def test_compute_heatmap_order(self): """Test the orders.""" - for xoff, expected in [ - (40, ((60, 9), (60, 4))), - (30, ((50, 9), (50, 4))), - (20, ((40, 9), (40, 4)))]: - with self.subTest(xoffset=xoff): - self.assertEqual( - compute_heatmap_order(slinked, xoffset=xoff), expected) + self.assertEqual( + compute_traits_order(slinked), (0, 2, 1, 7, 5, 9, 3, 6, 8, 4)) def test_retrieve_strains_and_values(self): """Test retrieval of strains and values.""" -- cgit v1.2.3 From e3e18950cfcdec918429dcbb5d5ed2e9616b7a20 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Wed, 15 Sep 2021 11:19:56 +0300 Subject: Reorganise modules Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * The heatmap generation does not fall cleanly within the computations or db modules. This commit moves it to the higher level gn3 module. --- gn3/computations/heatmap.py | 277 ----------------------------- gn3/heatmaps.py | 302 ++++++++++++++++++++++++++++++++ gn3/heatmaps/heatmaps.py | 67 ------- tests/unit/computations/test_heatmap.py | 187 -------------------- tests/unit/test_heatmaps.py | 187 ++++++++++++++++++++ 5 files changed, 489 insertions(+), 531 deletions(-) delete mode 100644 gn3/computations/heatmap.py create mode 100644 gn3/heatmaps.py delete mode 100644 gn3/heatmaps/heatmaps.py delete mode 100644 tests/unit/computations/test_heatmap.py create mode 100644 tests/unit/test_heatmaps.py (limited to 'gn3/computations/heatmap.py') diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py deleted file mode 100644 index 8727c92..0000000 --- a/gn3/computations/heatmap.py +++ /dev/null @@ -1,277 +0,0 @@ -""" -This module will contain functions to be used in computation of the data used to -generate various kinds of heatmaps. -""" - -from functools import reduce -from typing import Any, Dict, Sequence -from gn3.computations.slink import slink -from gn3.computations.qtlreaper import generate_traits_file -from gn3.computations.correlations2 import compute_correlation -from gn3.db.genotypes import build_genotype_file, load_genotype_samples -from gn3.db.traits import ( - retrieve_trait_data, - retrieve_trait_info, - generate_traits_filename) - -def export_trait_data( - trait_data: dict, strainlist: Sequence[str], dtype: str = "val", - var_exists: bool = False, n_exists: bool = False): - """ - Export data according to `strainlist`. Mostly used in calculating - correlations. - - DESCRIPTION: - Migrated from - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211 - - PARAMETERS - trait: (dict) - The dictionary of key-value pairs representing a trait - strainlist: (list) - A list of strain names - dtype: (str) - ... verify what this is ... - var_exists: (bool) - A flag indicating existence of variance - n_exists: (bool) - A flag indicating existence of ndata - """ - def __export_all_types(tdata, strain): - sample_data = [] - if tdata[strain]["value"]: - sample_data.append(tdata[strain]["value"]) - if var_exists: - if tdata[strain]["variance"]: - sample_data.append(tdata[strain]["variance"]) - else: - sample_data.append(None) - if n_exists: - if tdata[strain]["ndata"]: - sample_data.append(tdata[strain]["ndata"]) - else: - sample_data.append(None) - else: - if var_exists and n_exists: - sample_data += [None, None, None] - elif var_exists or n_exists: - sample_data += [None, None] - else: - sample_data.append(None) - - return tuple(sample_data) - - def __exporter(accumulator, strain): - # pylint: disable=[R0911] - if strain in trait_data["data"]: - if dtype == "val": - return accumulator + (trait_data["data"][strain]["value"], ) - if dtype == "var": - return accumulator + (trait_data["data"][strain]["variance"], ) - if dtype == "N": - return accumulator + (trait_data["data"][strain]["ndata"], ) - if dtype == "all": - return accumulator + __export_all_types(trait_data["data"], strain) - raise KeyError("Type `%s` is incorrect" % dtype) - if var_exists and n_exists: - return accumulator + (None, None, None) - if var_exists or n_exists: - return accumulator + (None, None) - return accumulator + (None,) - - return reduce(__exporter, strainlist, tuple()) - -def trait_display_name(trait: Dict): - """ - Given a trait, return a name to use to display the trait on a heatmap. - - DESCRIPTION - Migrated from - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L141-L157 - """ - if trait.get("db", None) and trait.get("trait_name", None): - if trait["db"]["dataset_type"] == "Temp": - desc = trait["description"] - if desc.find("PCA") >= 0: - return "%s::%s" % ( - trait["db"]["displayname"], - desc[desc.rindex(':')+1:].strip()) - return "%s::%s" % ( - trait["db"]["displayname"], - desc[:desc.index('entered')].strip()) - prefix = "%s::%s" % ( - trait["db"]["dataset_name"], trait["trait_name"]) - if trait["cellid"]: - return "%s::%s" % (prefix, trait["cellid"]) - return prefix - return trait["description"] - -def cluster_traits(traits_data_list: Sequence[Dict]): - """ - Clusters the trait values. - - DESCRIPTION - Attempts to replicate the clustering of the traits, as done at - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L162 - """ - def __compute_corr(tdata_i, tdata_j): - if tdata_i[0] == tdata_j[0]: - return 0.0 - corr_vals = compute_correlation(tdata_i[1], tdata_j[1]) - corr = corr_vals[0] - if (1 - corr) < 0: - return 0.0 - return 1 - corr - - def __cluster(tdata_i): - return tuple( - __compute_corr(tdata_i, tdata_j) - for tdata_j in enumerate(traits_data_list)) - - return tuple(__cluster(tdata_i) for tdata_i in enumerate(traits_data_list)) - -def heatmap_data(traits_names, conn: Any): - """ - heatmap function - - DESCRIPTION - This function is an attempt to reproduce the initialisation at - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L46-L64 - and also the clustering and slink computations at - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L165 - with the help of the `gn3.computations.heatmap.cluster_traits` function. - - It does not try to actually draw the heatmap image. - - PARAMETERS: - TODO: Elaborate on the parameters here... - """ - threshold = 0 # webqtlConfig.PUBLICTHRESH - def __retrieve_traitlist_and_datalist(threshold, fullname): - trait = retrieve_trait_info(threshold, fullname, conn) - return (trait, retrieve_trait_data(trait, conn)) - - traits_details = [ - __retrieve_traitlist_and_datalist(threshold, fullname) - for fullname in traits_names] - traits_list = tuple(x[0] for x in traits_details) - traits_data_list = [x[1] for x in traits_details] - genotype_filename = build_genotype_file(traits_list[0]["riset"]) - strainlist = load_genotype_samples(genotype_filename) - exported_traits_data_list = tuple( - export_trait_data(td, strainlist) for td in traits_data_list) - slink_data = slink(cluster_traits(exported_traits_data_list)) - ordering_data = compute_heatmap_order(slink_data) - strains_and_values = retrieve_strains_and_values( - ordering_data, strainlist, exported_traits_data_list) - strains_values = strains_and_values[0][1] - trait_values = [t[2] for t in strains_and_values] - traits_filename = generate_traits_filename() - generate_traits_file(strains_values, trait_values, traits_filename) - - return { - "slink_data": slink_data, - "ordering_data": ordering_data, - "strainlist": strainlist, - "genotype_filename": genotype_filename, - "traits_list": traits_list, - "traits_data_list": traits_data_list, - "exported_traits_data_list": exported_traits_data_list, - "traits_filename": traits_filename - } - -def compute_traits_order(slink_data, neworder: tuple = tuple()): - """ - Compute the order of the traits for clustering from `slink_data`. - - This function tries to reproduce the creation and update of the `neworder` - variable in - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L120 - and in the `web.webqtl.heatmap.Heatmap.draw` function in GN1 - """ - def __order_maker(norder, slnk_dt): - if isinstance(slnk_dt[0], int) and isinstance(slnk_dt[1], int): - return norder + (slnk_dt[0], slnk_dt[1]) - - if isinstance(slnk_dt[0], int): - return __order_maker((norder + (slnk_dt[0], )), slnk_dt[1]) - - if isinstance(slnk_dt[1], int): - return __order_maker(norder, slnk_dt[0]) + (slnk_dt[1], ) - - return __order_maker(__order_maker(norder, slnk_dt[0]), slnk_dt[1]) - - return __order_maker(neworder, slink_data) - -def retrieve_strains_and_values(orders, strainlist, traits_data_list): - """ - Get the strains and their corresponding values from `strainlist` and - `traits_data_list`. - - This migrates the code in - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L215-221 - """ - # This feels nasty! There's a lot of mutation of values here, that might - # indicate something untoward in the design of this function and its - # dependents ==> Review - strains = [] - values = [] - rets = [] - for order in orders: - temp_val = traits_data_list[order] - for i, strain in enumerate(strainlist): - if temp_val[i] is not None: - strains.append(strain) - values.append(temp_val[i]) - rets.append([order, strains[:], values[:]]) - strains = [] - values = [] - - return rets - -def nearest_marker_finder(genotype): - """ - Returns a function to be used with `genotype` to compute the nearest marker - to the trait passed to the returned function. - - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L425-434 - """ - def __compute_distances(chromo, trait): - loci = chromo.get("loci", None) - if not loci: - return None - return tuple( - { - "name": locus["name"], - "distance": abs(locus["Mb"] - trait["mb"]) - } for locus in loci) - - def __finder(trait): - _chrs = tuple( - _chr for _chr in genotype["chromosomes"] - if str(_chr["name"]) == str(trait["chr"])) - if len(_chrs) == 0: - return None - distances = tuple( - distance for dists in - filter( - lambda x: x is not None, - (__compute_distances(_chr, trait) for _chr in _chrs)) - for distance in dists) - nearest = min(distances, key=lambda d: d["distance"]) - return nearest["name"] - return __finder - -def get_nearest_marker(traits_list, genotype): - """ - Retrieves the nearest marker for each of the traits in the list. - - DESCRIPTION: - This migrates the code in - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L419-L438 - """ - if not genotype["Mbmap"]: - return [None] * len(trait_list) - - marker_finder = nearest_marker_finder(genotype) - return [marker_finder(trait) for trait in traits_list] diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py new file mode 100644 index 0000000..198fb45 --- /dev/null +++ b/gn3/heatmaps.py @@ -0,0 +1,302 @@ +""" +This module will contain functions to be used in computation of the data used to +generate various kinds of heatmaps. +""" + +from functools import reduce +from typing import Any, Dict, Sequence +from gn3.computations.slink import slink +from gn3.computations.qtlreaper import generate_traits_file +from gn3.computations.correlations2 import compute_correlation +from gn3.db.genotypes import build_genotype_file, load_genotype_samples +from gn3.db.traits import ( + retrieve_trait_data, + retrieve_trait_info, + generate_traits_filename) + +def export_trait_data( + trait_data: dict, strainlist: Sequence[str], dtype: str = "val", + var_exists: bool = False, n_exists: bool = False): + """ + Export data according to `strainlist`. Mostly used in calculating + correlations. + + DESCRIPTION: + Migrated from + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211 + + PARAMETERS + trait: (dict) + The dictionary of key-value pairs representing a trait + strainlist: (list) + A list of strain names + dtype: (str) + ... verify what this is ... + var_exists: (bool) + A flag indicating existence of variance + n_exists: (bool) + A flag indicating existence of ndata + """ + def __export_all_types(tdata, strain): + sample_data = [] + if tdata[strain]["value"]: + sample_data.append(tdata[strain]["value"]) + if var_exists: + if tdata[strain]["variance"]: + sample_data.append(tdata[strain]["variance"]) + else: + sample_data.append(None) + if n_exists: + if tdata[strain]["ndata"]: + sample_data.append(tdata[strain]["ndata"]) + else: + sample_data.append(None) + else: + if var_exists and n_exists: + sample_data += [None, None, None] + elif var_exists or n_exists: + sample_data += [None, None] + else: + sample_data.append(None) + + return tuple(sample_data) + + def __exporter(accumulator, strain): + # pylint: disable=[R0911] + if strain in trait_data["data"]: + if dtype == "val": + return accumulator + (trait_data["data"][strain]["value"], ) + if dtype == "var": + return accumulator + (trait_data["data"][strain]["variance"], ) + if dtype == "N": + return accumulator + (trait_data["data"][strain]["ndata"], ) + if dtype == "all": + return accumulator + __export_all_types(trait_data["data"], strain) + raise KeyError("Type `%s` is incorrect" % dtype) + if var_exists and n_exists: + return accumulator + (None, None, None) + if var_exists or n_exists: + return accumulator + (None, None) + return accumulator + (None,) + + return reduce(__exporter, strainlist, tuple()) + +def trait_display_name(trait: Dict): + """ + Given a trait, return a name to use to display the trait on a heatmap. + + DESCRIPTION + Migrated from + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L141-L157 + """ + if trait.get("db", None) and trait.get("trait_name", None): + if trait["db"]["dataset_type"] == "Temp": + desc = trait["description"] + if desc.find("PCA") >= 0: + return "%s::%s" % ( + trait["db"]["displayname"], + desc[desc.rindex(':')+1:].strip()) + return "%s::%s" % ( + trait["db"]["displayname"], + desc[:desc.index('entered')].strip()) + prefix = "%s::%s" % ( + trait["db"]["dataset_name"], trait["trait_name"]) + if trait["cellid"]: + return "%s::%s" % (prefix, trait["cellid"]) + return prefix + return trait["description"] + +def cluster_traits(traits_data_list: Sequence[Dict]): + """ + Clusters the trait values. + + DESCRIPTION + Attempts to replicate the clustering of the traits, as done at + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L162 + """ + def __compute_corr(tdata_i, tdata_j): + if tdata_i[0] == tdata_j[0]: + return 0.0 + corr_vals = compute_correlation(tdata_i[1], tdata_j[1]) + corr = corr_vals[0] + if (1 - corr) < 0: + return 0.0 + return 1 - corr + + def __cluster(tdata_i): + return tuple( + __compute_corr(tdata_i, tdata_j) + for tdata_j in enumerate(traits_data_list)) + + return tuple(__cluster(tdata_i) for tdata_i in enumerate(traits_data_list)) + +def heatmap_data(traits_names, conn: Any): + """ + heatmap function + + DESCRIPTION + This function is an attempt to reproduce the initialisation at + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L46-L64 + and also the clustering and slink computations at + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L165 + with the help of the `gn3.computations.heatmap.cluster_traits` function. + + It does not try to actually draw the heatmap image. + + PARAMETERS: + TODO: Elaborate on the parameters here... + """ + threshold = 0 # webqtlConfig.PUBLICTHRESH + def __retrieve_traitlist_and_datalist(threshold, fullname): + trait = retrieve_trait_info(threshold, fullname, conn) + return (trait, retrieve_trait_data(trait, conn)) + + traits_details = [ + __retrieve_traitlist_and_datalist(threshold, fullname) + for fullname in traits_names] + traits_list = tuple(x[0] for x in traits_details) + traits_data_list = [x[1] for x in traits_details] + genotype_filename = build_genotype_file(traits_list[0]["riset"]) + strainlist = load_genotype_samples(genotype_filename) + exported_traits_data_list = tuple( + export_trait_data(td, strainlist) for td in traits_data_list) + slink_data = slink(cluster_traits(exported_traits_data_list)) + ordering_data = compute_heatmap_order(slink_data) + strains_and_values = retrieve_strains_and_values( + ordering_data, strainlist, exported_traits_data_list) + strains_values = strains_and_values[0][1] + trait_values = [t[2] for t in strains_and_values] + traits_filename = generate_traits_filename() + generate_traits_file(strains_values, trait_values, traits_filename) + + return { + "slink_data": slink_data, + "ordering_data": ordering_data, + "strainlist": strainlist, + "genotype_filename": genotype_filename, + "traits_list": traits_list, + "traits_data_list": traits_data_list, + "exported_traits_data_list": exported_traits_data_list, + "traits_filename": traits_filename + } + +def compute_traits_order(slink_data, neworder: tuple = tuple()): + """ + Compute the order of the traits for clustering from `slink_data`. + + This function tries to reproduce the creation and update of the `neworder` + variable in + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L120 + and in the `web.webqtl.heatmap.Heatmap.draw` function in GN1 + """ + def __order_maker(norder, slnk_dt): + if isinstance(slnk_dt[0], int) and isinstance(slnk_dt[1], int): + return norder + (slnk_dt[0], slnk_dt[1]) + + if isinstance(slnk_dt[0], int): + return __order_maker((norder + (slnk_dt[0], )), slnk_dt[1]) + + if isinstance(slnk_dt[1], int): + return __order_maker(norder, slnk_dt[0]) + (slnk_dt[1], ) + + return __order_maker(__order_maker(norder, slnk_dt[0]), slnk_dt[1]) + + return __order_maker(neworder, slink_data) + +def retrieve_strains_and_values(orders, strainlist, traits_data_list): + """ + Get the strains and their corresponding values from `strainlist` and + `traits_data_list`. + + This migrates the code in + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L215-221 + """ + # This feels nasty! There's a lot of mutation of values here, that might + # indicate something untoward in the design of this function and its + # dependents ==> Review + strains = [] + values = [] + rets = [] + for order in orders: + temp_val = traits_data_list[order] + for i, strain in enumerate(strainlist): + if temp_val[i] is not None: + strains.append(strain) + values.append(temp_val[i]) + rets.append([order, strains[:], values[:]]) + strains = [] + values = [] + + return rets + +def nearest_marker_finder(genotype): + """ + Returns a function to be used with `genotype` to compute the nearest marker + to the trait passed to the returned function. + + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L425-434 + """ + def __compute_distances(chromo, trait): + loci = chromo.get("loci", None) + if not loci: + return None + return tuple( + { + "name": locus["name"], + "distance": abs(locus["Mb"] - trait["mb"]) + } for locus in loci) + + def __finder(trait): + _chrs = tuple( + _chr for _chr in genotype["chromosomes"] + if str(_chr["name"]) == str(trait["chr"])) + if len(_chrs) == 0: + return None + distances = tuple( + distance for dists in + filter( + lambda x: x is not None, + (__compute_distances(_chr, trait) for _chr in _chrs)) + for distance in dists) + nearest = min(distances, key=lambda d: d["distance"]) + return nearest["name"] + return __finder + +def get_nearest_marker(traits_list, genotype): + """ + Retrieves the nearest marker for each of the traits in the list. + + DESCRIPTION: + This migrates the code in + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L419-L438 + """ + if not genotype["Mbmap"]: + return [None] * len(trait_list) + + marker_finder = nearest_marker_finder(genotype) + return [marker_finder(trait) for trait in traits_list] + +# # Grey + Blue + Red +# def generate_heatmap(): +# cols = 20 +# y_axis = (["%s"%x for x in range(1, cols+1)][:-1] + ["X"]) #replace last item with x for now +# x_axis = heatmap_x_axis_names() +# data = generate_random_data(height=cols, width=len(x_axis)) +# fig = px.imshow( +# data, +# x=x_axis, +# y=y_axis, +# width=500) +# fig.update_traces(xtype="array") +# fig.update_traces(ytype="array") +# # fig.update_traces(xgap=10) +# fig.update_xaxes( +# visible=True, +# title_text="Traits", +# title_font_size=16) +# fig.update_layout( +# coloraxis_colorscale=[ +# [0.0, '#3B3B3B'], [0.4999999999999999, '#ABABAB'], +# [0.5, '#F5DE11'], [1.0, '#FF0D00']]) +# fig.write_html("%s/%s"%(heatmap_dir, "test_image.html")) +# return fig diff --git a/gn3/heatmaps/heatmaps.py b/gn3/heatmaps/heatmaps.py deleted file mode 100644 index 88f546d..0000000 --- a/gn3/heatmaps/heatmaps.py +++ /dev/null @@ -1,67 +0,0 @@ -import random -import plotly.express as px - -#### Remove these #### - -heatmap_dir = "heatmap_images" - -def generate_random_data(data_stop: float = 2, width: int = 10, height: int = 30): - """ - This is mostly a utility function to be used to generate random data, useful - for development of the heatmap generation code, without access to the actual - database data. - """ - return [[random.uniform(0,data_stop) for i in range(0, width)] - for j in range(0, height)] - -def generate_random_data2(data_stop: float = 2, width: int = 10, height: int = 30): - """ - This is mostly a utility function to be used to generate random data, useful - for development of the heatmap generation code, without access to the actual - database data. - """ - return [ - [{ - "value": item, - "category": random.choice(["C57BL/6J +", "DBA/2J +"])} - for item in axis] - for axis in generate_random_data(data_stop, width, height)] - -def heatmap_x_axis_names(): - return [ - "UCLA_BXDBXH_CARTILAGE_V2::ILM103710672", - "UCLA_BXDBXH_CARTILAGE_V2::ILM2260338", - "UCLA_BXDBXH_CARTILAGE_V2::ILM3140576", - "UCLA_BXDBXH_CARTILAGE_V2::ILM5670577", - "UCLA_BXDBXH_CARTILAGE_V2::ILM2070121", - "UCLA_BXDBXH_CARTILAGE_V2::ILM103990541", - "UCLA_BXDBXH_CARTILAGE_V2::ILM1190722", - "UCLA_BXDBXH_CARTILAGE_V2::ILM6590722", - "UCLA_BXDBXH_CARTILAGE_V2::ILM4200064", - "UCLA_BXDBXH_CARTILAGE_V2::ILM3140463"] -#### END: Remove these #### - -# Grey + Blue + Red -def generate_heatmap(): - cols = 20 - y_axis = (["%s"%x for x in range(1, cols+1)][:-1] + ["X"]) #replace last item with x for now - x_axis = heatmap_x_axis_names() - data = generate_random_data(height=cols, width=len(x_axis)) - fig = px.imshow( - data, - x=x_axis, - y=y_axis, - width=500) - fig.update_traces(xtype="array") - fig.update_traces(ytype="array") - # fig.update_traces(xgap=10) - fig.update_xaxes( - visible=True, - title_text="Traits", - title_font_size=16) - fig.update_layout( - coloraxis_colorscale=[ - [0.0, '#3B3B3B'], [0.4999999999999999, '#ABABAB'], - [0.5, '#F5DE11'], [1.0, '#FF0D00']]) - fig.write_html("%s/%s"%(heatmap_dir, "test_image.html")) - return fig diff --git a/tests/unit/computations/test_heatmap.py b/tests/unit/computations/test_heatmap.py deleted file mode 100644 index 156af45..0000000 --- a/tests/unit/computations/test_heatmap.py +++ /dev/null @@ -1,187 +0,0 @@ -"""Module contains tests for gn3.computations.heatmap""" -from unittest import TestCase -from gn3.computations.heatmap import ( - cluster_traits, - export_trait_data, - compute_traits_order, - retrieve_strains_and_values) - -strainlist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"] -trait_data = { - "mysqlid": 36688172, - "data": { - "B6cC3-1": {"strain_name": "B6cC3-1", "value": 7.51879, "variance": None, "ndata": None}, - "BXD1": {"strain_name": "BXD1", "value": 7.77141, "variance": None, "ndata": None}, - "BXD12": {"strain_name": "BXD12", "value": 8.39265, "variance": None, "ndata": None}, - "BXD16": {"strain_name": "BXD16", "value": 8.17443, "variance": None, "ndata": None}, - "BXD19": {"strain_name": "BXD19", "value": 8.30401, "variance": None, "ndata": None}, - "BXD2": {"strain_name": "BXD2", "value": 7.80944, "variance": None, "ndata": None}, - "BXD21": {"strain_name": "BXD21", "value": 8.93809, "variance": None, "ndata": None}, - "BXD24": {"strain_name": "BXD24", "value": 7.99415, "variance": None, "ndata": None}, - "BXD27": {"strain_name": "BXD27", "value": 8.12177, "variance": None, "ndata": None}, - "BXD28": {"strain_name": "BXD28", "value": 7.67688, "variance": None, "ndata": None}, - "BXD32": {"strain_name": "BXD32", "value": 7.79062, "variance": None, "ndata": None}, - "BXD39": {"strain_name": "BXD39", "value": 8.27641, "variance": None, "ndata": None}, - "BXD40": {"strain_name": "BXD40", "value": 8.18012, "variance": None, "ndata": None}, - "BXD42": {"strain_name": "BXD42", "value": 7.82433, "variance": None, "ndata": None}, - "BXD6": {"strain_name": "BXD6", "value": 8.09718, "variance": None, "ndata": None}, - "BXH14": {"strain_name": "BXH14", "value": 7.97475, "variance": None, "ndata": None}, - "BXH19": {"strain_name": "BXH19", "value": 7.67223, "variance": None, "ndata": None}, - "BXH2": {"strain_name": "BXH2", "value": 7.93622, "variance": None, "ndata": None}, - "BXH22": {"strain_name": "BXH22", "value": 7.43692, "variance": None, "ndata": None}, - "BXH4": {"strain_name": "BXH4", "value": 7.96336, "variance": None, "ndata": None}, - "BXH6": {"strain_name": "BXH6", "value": 7.75132, "variance": None, "ndata": None}, - "BXH7": {"strain_name": "BXH7", "value": 8.12927, "variance": None, "ndata": None}, - "BXH8": {"strain_name": "BXH8", "value": 6.77338, "variance": None, "ndata": None}, - "BXH9": {"strain_name": "BXH9", "value": 8.03836, "variance": None, "ndata": None}, - "C3H/HeJ": {"strain_name": "C3H/HeJ", "value": 7.42795, "variance": None, "ndata": None}, - "C57BL/6J": {"strain_name": "C57BL/6J", "value": 7.50606, "variance": None, "ndata": None}, - "DBA/2J": {"strain_name": "DBA/2J", "value": 7.72588, "variance": None, "ndata": None}}} - -slinked = ( - (((0, 2, 0.16381088984330505), - ((1, 7, 0.06024619831474998), 5, 0.19179284676938602), - 0.20337048635536847), - 9, - 0.23451785425383564), - ((3, (6, 8, 0.2140799896286565), 0.25879514152086425), - 4, 0.8968250491499363), - 0.9313185954797953) - -class TestHeatmap(TestCase): - """Class for testing heatmap computation functions""" - - def test_export_trait_data_dtype(self): - """ - Test `export_trait_data` with different values for the `dtype` keyword - argument - """ - for dtype, expected in [ - ["val", (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], - ["var", (None, None, None, None, None, None)], - ["N", (None, None, None, None, None, None)], - ["all", (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)]]: - with self.subTest(dtype=dtype): - self.assertEqual( - export_trait_data(trait_data, strainlist, dtype=dtype), - expected) - - def test_export_trait_data_dtype_all_flags(self): - """ - Test `export_trait_data` with different values for the `dtype` keyword - argument and the different flags set up - """ - for dtype, vflag, nflag, expected in [ - ["val", False, False, - (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], - ["val", False, True, - (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], - ["val", True, False, - (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], - ["val", True, True, - (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], - ["var", False, False, (None, None, None, None, None, None)], - ["var", False, True, (None, None, None, None, None, None)], - ["var", True, False, (None, None, None, None, None, None)], - ["var", True, True, (None, None, None, None, None, None)], - ["N", False, False, (None, None, None, None, None, None)], - ["N", False, True, (None, None, None, None, None, None)], - ["N", True, False, (None, None, None, None, None, None)], - ["N", True, True, (None, None, None, None, None, None)], - ["all", False, False, - (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], - ["all", False, True, - (7.51879, None, 7.77141, None, 8.39265, None, 8.17443, None, - 8.30401, None, 7.80944, None)], - ["all", True, False, - (7.51879, None, 7.77141, None, 8.39265, None, 8.17443, None, - 8.30401, None, 7.80944, None)], - ["all", True, True, - (7.51879, None, None, 7.77141, None, None, 8.39265, None, None, - 8.17443, None, None, 8.30401, None, None, 7.80944, None, None)] - ]: - with self.subTest(dtype=dtype, vflag=vflag, nflag=nflag): - self.assertEqual( - export_trait_data( - trait_data, strainlist, dtype=dtype, var_exists=vflag, - n_exists=nflag), - expected) - - def test_cluster_traits(self): - """ - Test that the clustering is working as expected. - """ - traits_data_list = [ - (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944), - (6.1427, 6.50588, 7.73705, 6.68328, 7.49293, 7.27398), - (8.4211, 8.30581, 9.24076, 8.51173, 9.18455, 8.36077), - (10.0904, 10.6509, 9.36716, 9.91202, 8.57444, 10.5731), - (10.188, 9.76652, 9.54813, 9.05074, 9.52319, 9.10505), - (6.74676, 7.01029, 7.54169, 6.48574, 7.01427, 7.26815), - (6.39359, 6.85321, 5.78337, 7.11141, 6.22101, 6.16544), - (6.84118, 7.08432, 7.59844, 7.08229, 7.26774, 7.24991), - (9.45215, 10.6943, 8.64719, 10.1592, 7.75044, 8.78615), - (7.04737, 6.87185, 7.58586, 6.92456, 6.84243, 7.36913)] - self.assertEqual( - cluster_traits(traits_data_list), - ((0.0, 0.20337048635536847, 0.16381088984330505, 1.7388553629398245, - 1.5025235756329178, 0.6952839500255574, 1.271661230252733, - 0.2100487290977544, 1.4699690641062024, 0.7934461515867415), - (0.20337048635536847, 0.0, 0.2198321044997198, 1.5753041735592204, - 1.4815755944537086, 0.26087293140686374, 1.6939790104301427, - 0.06024619831474998, 1.7430082449189215, 0.4497104244247795), - (0.16381088984330505, 0.2198321044997198, 0.0, 1.9073926868549234, - 1.0396738891139845, 0.5278328671176757, 1.6275069061182947, - 0.2636503792482082, 1.739617877037615, 0.7127042590637039), - (1.7388553629398245, 1.5753041735592204, 1.9073926868549234, 0.0, - 0.9936846292920328, 1.1169999189889366, 0.6007483980555253, - 1.430209221053372, 0.25879514152086425, 0.9313185954797953), - (1.5025235756329178, 1.4815755944537086, 1.0396738891139845, - 0.9936846292920328, 0.0, 1.027827186339337, 1.1441743109173244, - 1.4122477962364253, 0.8968250491499363, 1.1683723389247052), - (0.6952839500255574, 0.26087293140686374, 0.5278328671176757, - 1.1169999189889366, 1.027827186339337, 0.0, 1.8420471110023269, - 0.19179284676938602, 1.4875072385631605, 0.23451785425383564), - (1.271661230252733, 1.6939790104301427, 1.6275069061182947, - 0.6007483980555253, 1.1441743109173244, 1.8420471110023269, 0.0, - 1.6540234785929928, 0.2140799896286565, 1.7413442197913358), - (0.2100487290977544, 0.06024619831474998, 0.2636503792482082, - 1.430209221053372, 1.4122477962364253, 0.19179284676938602, - 1.6540234785929928, 0.0, 1.5225640692832796, 0.33370067057028485), - (1.4699690641062024, 1.7430082449189215, 1.739617877037615, - 0.25879514152086425, 0.8968250491499363, 1.4875072385631605, - 0.2140799896286565, 1.5225640692832796, 0.0, 1.3256191648260216), - (0.7934461515867415, 0.4497104244247795, 0.7127042590637039, - 0.9313185954797953, 1.1683723389247052, 0.23451785425383564, - 1.7413442197913358, 0.33370067057028485, 1.3256191648260216, - 0.0))) - - def test_compute_heatmap_order(self): - """Test the orders.""" - self.assertEqual( - compute_traits_order(slinked), (0, 2, 1, 7, 5, 9, 3, 6, 8, 4)) - - def test_retrieve_strains_and_values(self): - """Test retrieval of strains and values.""" - for orders, slist, tdata, expected in [ - [ - [2], - ["s1", "s2", "s3", "s4"], - [[2, 9, 6, None, 4], - [7, 5, None, None, 4], - [9, None, 5, 4, 7], - [6, None, None, 4, None]], - [[2, ["s1", "s3", "s4"], [9, 5, 4]]] - ], - [ - [3], - ["s1", "s2", "s3", "s4", "s5"], - [[2, 9, 6, None, 4], - [7, 5, None, None, 4], - [9, None, 5, 4, 7], - [6, None, None, 4, None]], - [[3, ["s1", "s4"], [6, 4]]] - ]]: - with self.subTest(strainlist=slist, traitdata=tdata): - self.assertEqual( - retrieve_strains_and_values(orders, slist, tdata), expected) diff --git a/tests/unit/test_heatmaps.py b/tests/unit/test_heatmaps.py new file mode 100644 index 0000000..265d5a8 --- /dev/null +++ b/tests/unit/test_heatmaps.py @@ -0,0 +1,187 @@ +"""Module contains tests for gn3.heatmaps.heatmaps""" +from unittest import TestCase +from gn3.heatmaps import ( + cluster_traits, + export_trait_data, + compute_traits_order, + retrieve_strains_and_values) + +strainlist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"] +trait_data = { + "mysqlid": 36688172, + "data": { + "B6cC3-1": {"strain_name": "B6cC3-1", "value": 7.51879, "variance": None, "ndata": None}, + "BXD1": {"strain_name": "BXD1", "value": 7.77141, "variance": None, "ndata": None}, + "BXD12": {"strain_name": "BXD12", "value": 8.39265, "variance": None, "ndata": None}, + "BXD16": {"strain_name": "BXD16", "value": 8.17443, "variance": None, "ndata": None}, + "BXD19": {"strain_name": "BXD19", "value": 8.30401, "variance": None, "ndata": None}, + "BXD2": {"strain_name": "BXD2", "value": 7.80944, "variance": None, "ndata": None}, + "BXD21": {"strain_name": "BXD21", "value": 8.93809, "variance": None, "ndata": None}, + "BXD24": {"strain_name": "BXD24", "value": 7.99415, "variance": None, "ndata": None}, + "BXD27": {"strain_name": "BXD27", "value": 8.12177, "variance": None, "ndata": None}, + "BXD28": {"strain_name": "BXD28", "value": 7.67688, "variance": None, "ndata": None}, + "BXD32": {"strain_name": "BXD32", "value": 7.79062, "variance": None, "ndata": None}, + "BXD39": {"strain_name": "BXD39", "value": 8.27641, "variance": None, "ndata": None}, + "BXD40": {"strain_name": "BXD40", "value": 8.18012, "variance": None, "ndata": None}, + "BXD42": {"strain_name": "BXD42", "value": 7.82433, "variance": None, "ndata": None}, + "BXD6": {"strain_name": "BXD6", "value": 8.09718, "variance": None, "ndata": None}, + "BXH14": {"strain_name": "BXH14", "value": 7.97475, "variance": None, "ndata": None}, + "BXH19": {"strain_name": "BXH19", "value": 7.67223, "variance": None, "ndata": None}, + "BXH2": {"strain_name": "BXH2", "value": 7.93622, "variance": None, "ndata": None}, + "BXH22": {"strain_name": "BXH22", "value": 7.43692, "variance": None, "ndata": None}, + "BXH4": {"strain_name": "BXH4", "value": 7.96336, "variance": None, "ndata": None}, + "BXH6": {"strain_name": "BXH6", "value": 7.75132, "variance": None, "ndata": None}, + "BXH7": {"strain_name": "BXH7", "value": 8.12927, "variance": None, "ndata": None}, + "BXH8": {"strain_name": "BXH8", "value": 6.77338, "variance": None, "ndata": None}, + "BXH9": {"strain_name": "BXH9", "value": 8.03836, "variance": None, "ndata": None}, + "C3H/HeJ": {"strain_name": "C3H/HeJ", "value": 7.42795, "variance": None, "ndata": None}, + "C57BL/6J": {"strain_name": "C57BL/6J", "value": 7.50606, "variance": None, "ndata": None}, + "DBA/2J": {"strain_name": "DBA/2J", "value": 7.72588, "variance": None, "ndata": None}}} + +slinked = ( + (((0, 2, 0.16381088984330505), + ((1, 7, 0.06024619831474998), 5, 0.19179284676938602), + 0.20337048635536847), + 9, + 0.23451785425383564), + ((3, (6, 8, 0.2140799896286565), 0.25879514152086425), + 4, 0.8968250491499363), + 0.9313185954797953) + +class TestHeatmap(TestCase): + """Class for testing heatmap computation functions""" + + def test_export_trait_data_dtype(self): + """ + Test `export_trait_data` with different values for the `dtype` keyword + argument + """ + for dtype, expected in [ + ["val", (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["var", (None, None, None, None, None, None)], + ["N", (None, None, None, None, None, None)], + ["all", (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)]]: + with self.subTest(dtype=dtype): + self.assertEqual( + export_trait_data(trait_data, strainlist, dtype=dtype), + expected) + + def test_export_trait_data_dtype_all_flags(self): + """ + Test `export_trait_data` with different values for the `dtype` keyword + argument and the different flags set up + """ + for dtype, vflag, nflag, expected in [ + ["val", False, False, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["val", False, True, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["val", True, False, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["val", True, True, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["var", False, False, (None, None, None, None, None, None)], + ["var", False, True, (None, None, None, None, None, None)], + ["var", True, False, (None, None, None, None, None, None)], + ["var", True, True, (None, None, None, None, None, None)], + ["N", False, False, (None, None, None, None, None, None)], + ["N", False, True, (None, None, None, None, None, None)], + ["N", True, False, (None, None, None, None, None, None)], + ["N", True, True, (None, None, None, None, None, None)], + ["all", False, False, + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)], + ["all", False, True, + (7.51879, None, 7.77141, None, 8.39265, None, 8.17443, None, + 8.30401, None, 7.80944, None)], + ["all", True, False, + (7.51879, None, 7.77141, None, 8.39265, None, 8.17443, None, + 8.30401, None, 7.80944, None)], + ["all", True, True, + (7.51879, None, None, 7.77141, None, None, 8.39265, None, None, + 8.17443, None, None, 8.30401, None, None, 7.80944, None, None)] + ]: + with self.subTest(dtype=dtype, vflag=vflag, nflag=nflag): + self.assertEqual( + export_trait_data( + trait_data, strainlist, dtype=dtype, var_exists=vflag, + n_exists=nflag), + expected) + + def test_cluster_traits(self): + """ + Test that the clustering is working as expected. + """ + traits_data_list = [ + (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944), + (6.1427, 6.50588, 7.73705, 6.68328, 7.49293, 7.27398), + (8.4211, 8.30581, 9.24076, 8.51173, 9.18455, 8.36077), + (10.0904, 10.6509, 9.36716, 9.91202, 8.57444, 10.5731), + (10.188, 9.76652, 9.54813, 9.05074, 9.52319, 9.10505), + (6.74676, 7.01029, 7.54169, 6.48574, 7.01427, 7.26815), + (6.39359, 6.85321, 5.78337, 7.11141, 6.22101, 6.16544), + (6.84118, 7.08432, 7.59844, 7.08229, 7.26774, 7.24991), + (9.45215, 10.6943, 8.64719, 10.1592, 7.75044, 8.78615), + (7.04737, 6.87185, 7.58586, 6.92456, 6.84243, 7.36913)] + self.assertEqual( + cluster_traits(traits_data_list), + ((0.0, 0.20337048635536847, 0.16381088984330505, 1.7388553629398245, + 1.5025235756329178, 0.6952839500255574, 1.271661230252733, + 0.2100487290977544, 1.4699690641062024, 0.7934461515867415), + (0.20337048635536847, 0.0, 0.2198321044997198, 1.5753041735592204, + 1.4815755944537086, 0.26087293140686374, 1.6939790104301427, + 0.06024619831474998, 1.7430082449189215, 0.4497104244247795), + (0.16381088984330505, 0.2198321044997198, 0.0, 1.9073926868549234, + 1.0396738891139845, 0.5278328671176757, 1.6275069061182947, + 0.2636503792482082, 1.739617877037615, 0.7127042590637039), + (1.7388553629398245, 1.5753041735592204, 1.9073926868549234, 0.0, + 0.9936846292920328, 1.1169999189889366, 0.6007483980555253, + 1.430209221053372, 0.25879514152086425, 0.9313185954797953), + (1.5025235756329178, 1.4815755944537086, 1.0396738891139845, + 0.9936846292920328, 0.0, 1.027827186339337, 1.1441743109173244, + 1.4122477962364253, 0.8968250491499363, 1.1683723389247052), + (0.6952839500255574, 0.26087293140686374, 0.5278328671176757, + 1.1169999189889366, 1.027827186339337, 0.0, 1.8420471110023269, + 0.19179284676938602, 1.4875072385631605, 0.23451785425383564), + (1.271661230252733, 1.6939790104301427, 1.6275069061182947, + 0.6007483980555253, 1.1441743109173244, 1.8420471110023269, 0.0, + 1.6540234785929928, 0.2140799896286565, 1.7413442197913358), + (0.2100487290977544, 0.06024619831474998, 0.2636503792482082, + 1.430209221053372, 1.4122477962364253, 0.19179284676938602, + 1.6540234785929928, 0.0, 1.5225640692832796, 0.33370067057028485), + (1.4699690641062024, 1.7430082449189215, 1.739617877037615, + 0.25879514152086425, 0.8968250491499363, 1.4875072385631605, + 0.2140799896286565, 1.5225640692832796, 0.0, 1.3256191648260216), + (0.7934461515867415, 0.4497104244247795, 0.7127042590637039, + 0.9313185954797953, 1.1683723389247052, 0.23451785425383564, + 1.7413442197913358, 0.33370067057028485, 1.3256191648260216, + 0.0))) + + def test_compute_heatmap_order(self): + """Test the orders.""" + self.assertEqual( + compute_traits_order(slinked), (0, 2, 1, 7, 5, 9, 3, 6, 8, 4)) + + def test_retrieve_strains_and_values(self): + """Test retrieval of strains and values.""" + for orders, slist, tdata, expected in [ + [ + [2], + ["s1", "s2", "s3", "s4"], + [[2, 9, 6, None, 4], + [7, 5, None, None, 4], + [9, None, 5, 4, 7], + [6, None, None, 4, None]], + [[2, ["s1", "s3", "s4"], [9, 5, 4]]] + ], + [ + [3], + ["s1", "s2", "s3", "s4", "s5"], + [[2, 9, 6, None, 4], + [7, 5, None, None, 4], + [9, None, 5, 4, 7], + [6, None, None, 4, None]], + [[3, ["s1", "s4"], [6, 4]]] + ]]: + with self.subTest(strainlist=slist, traitdata=tdata): + self.assertEqual( + retrieve_strains_and_values(orders, slist, tdata), expected) -- cgit v1.2.3