about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/heatmaps.py28
-rw-r--r--tests/unit/test_heatmaps.py15
2 files changed, 42 insertions, 1 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index 2ef2d16..9c10ba3 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -4,7 +4,7 @@ generate various kinds of heatmaps.
 """
 
 from functools import reduce
-from typing import Any, Dict, Sequence
+from typing import Any, Dict, Union, Sequence
 
 import numpy as np
 import plotly.graph_objects as go # type: ignore
@@ -142,6 +142,32 @@ def cluster_traits(traits_data_list: Sequence[Dict]):
 
     return tuple(__cluster(tdata_i) for tdata_i in enumerate(traits_data_list))
 
+def get_loci_names(
+        organised: dict,
+        chromosome_names: Sequence[str]) -> Sequence[Sequence[str]]:
+    """
+    Get the loci names organised by the same order as the `chromosome_names`.
+    """
+    def __get_trait_loci(accumulator, trait):
+        chrs = tuple(trait["chromosomes"].keys())
+        trait_loci = {
+            _chr: tuple(
+                locus["Locus"]
+                for locus in trait["chromosomes"][_chr]["loci"]
+            ) for _chr in chrs
+        }
+        return {
+            **accumulator,
+            **{
+                _chr: tuple(sorted(set(
+                    trait_loci[_chr] + accumulator.get(_chr, tuple()))))
+                for _chr in trait_loci.keys()
+            }
+        }
+    loci_dict: Dict[Union[str, int], Sequence[str]] = reduce(
+        __get_trait_loci, [v[1] for v in organised.items()], {})
+    return tuple(loci_dict[_chr] for _chr in chromosome_names)
+
 def build_heatmap(traits_names, conn: Any):
     """
     heatmap function
diff --git a/tests/unit/test_heatmaps.py b/tests/unit/test_heatmaps.py
index b54e2f3..7b66688 100644
--- a/tests/unit/test_heatmaps.py
+++ b/tests/unit/test_heatmaps.py
@@ -2,6 +2,7 @@
 from unittest import TestCase
 from gn3.heatmaps import (
     cluster_traits,
+    get_loci_names,
     get_lrs_from_chr,
     export_trait_data,
     compute_traits_order,
@@ -214,3 +215,17 @@ class TestHeatmap(TestCase):
               [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]],
              [[0.5, 0.579, 0.5],
               [0.5, 0.5, 0.5]]])
+
+    def test_get_loci_names(self):
+        """Check that loci names are retrieved correctly."""
+        for organised, expected in (
+                (organised_trait_1,
+                 (("rs258367496", "rs30658298", "rs31443144", "rs32285189",
+                   "rs32430919", "rs36251697", "rs6269442"),
+                  ("rs31879829", "rs36742481", "rs51852623"))),
+                ({**organised_trait_1, **organised_trait_2},
+                 (("rs258367496", "rs30658298", "rs31443144", "rs32285189",
+                   "rs32430919", "rs36251697", "rs6269442"),
+                  ("rs31879829", "rs36742481", "rs51852623")))):
+            with self.subTest(organised=organised):
+                self.assertEqual(get_loci_names(organised, (1, 2)), expected)