diff options
-rw-r--r-- | gn3/heatmaps.py | 28 | ||||
-rw-r--r-- | tests/unit/test_heatmaps.py | 15 |
2 files changed, 42 insertions, 1 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index 2ef2d16..9c10ba3 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -4,7 +4,7 @@ generate various kinds of heatmaps. """ from functools import reduce -from typing import Any, Dict, Sequence +from typing import Any, Dict, Union, Sequence import numpy as np import plotly.graph_objects as go # type: ignore @@ -142,6 +142,32 @@ def cluster_traits(traits_data_list: Sequence[Dict]): return tuple(__cluster(tdata_i) for tdata_i in enumerate(traits_data_list)) +def get_loci_names( + organised: dict, + chromosome_names: Sequence[str]) -> Sequence[Sequence[str]]: + """ + Get the loci names organised by the same order as the `chromosome_names`. + """ + def __get_trait_loci(accumulator, trait): + chrs = tuple(trait["chromosomes"].keys()) + trait_loci = { + _chr: tuple( + locus["Locus"] + for locus in trait["chromosomes"][_chr]["loci"] + ) for _chr in chrs + } + return { + **accumulator, + **{ + _chr: tuple(sorted(set( + trait_loci[_chr] + accumulator.get(_chr, tuple())))) + for _chr in trait_loci.keys() + } + } + loci_dict: Dict[Union[str, int], Sequence[str]] = reduce( + __get_trait_loci, [v[1] for v in organised.items()], {}) + return tuple(loci_dict[_chr] for _chr in chromosome_names) + def build_heatmap(traits_names, conn: Any): """ heatmap function diff --git a/tests/unit/test_heatmaps.py b/tests/unit/test_heatmaps.py index b54e2f3..7b66688 100644 --- a/tests/unit/test_heatmaps.py +++ b/tests/unit/test_heatmaps.py @@ -2,6 +2,7 @@ from unittest import TestCase from gn3.heatmaps import ( cluster_traits, + get_loci_names, get_lrs_from_chr, export_trait_data, compute_traits_order, @@ -214,3 +215,17 @@ class TestHeatmap(TestCase): [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]], [[0.5, 0.579, 0.5], [0.5, 0.5, 0.5]]]) + + def test_get_loci_names(self): + """Check that loci names are retrieved correctly.""" + for organised, expected in ( + (organised_trait_1, + (("rs258367496", "rs30658298", "rs31443144", "rs32285189", + "rs32430919", "rs36251697", "rs6269442"), + ("rs31879829", "rs36742481", "rs51852623"))), + ({**organised_trait_1, **organised_trait_2}, + (("rs258367496", "rs30658298", "rs31443144", "rs32285189", + "rs32430919", "rs36251697", "rs6269442"), + ("rs31879829", "rs36742481", "rs51852623")))): + with self.subTest(organised=organised): + self.assertEqual(get_loci_names(organised, (1, 2)), expected) |