about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/heatmaps.py24
1 files changed, 15 insertions, 9 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index 170b0cd..bf69d9b 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -3,9 +3,11 @@ This module will contain functions to be used in computation of the data used to
 generate various kinds of heatmaps.
 """
 
+import numpy as np
 from functools import reduce
 from gn3.settings import TMPDIR
 import plotly.graph_objects as go
+import plotly.figure_factory as ff
 from gn3.random import random_string
 from typing import Any, Dict, Sequence
 from gn3.computations.slink import slink
@@ -167,7 +169,8 @@ def build_heatmap(traits_names, conn: Any):
     strains = load_genotype_samples(genotype_filename)
     exported_traits_data_list = [
         export_trait_data(td, strains) for td in traits_data_list]
-    slinked = slink(cluster_traits(exported_traits_data_list))
+    clustered = cluster_traits(exported_traits_data_list)
+    slinked = slink(clustered)
     traits_order = compute_traits_order(slinked)
     ordered_traits_names = [
         traits[idx]["trait_fullname"] for idx in traits_order]
@@ -200,6 +203,7 @@ def build_heatmap(traits_names, conn: Any):
     return generate_clustered_heatmap(
         process_traits_data_for_heatmap(
             organised, traits_ids, chromosome_names),
+        clustered,
         "single_heatmap_{}".format(random_string(10)),
         y_axis=tuple(
             ordered_traits_names[traits_ids[order]]
@@ -336,8 +340,9 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names):
     return hdata
 
 def generate_clustered_heatmap(
-        data, image_filename_prefix, x_axis = None, x_label: str = "",
-        y_axis = None, y_label: str = "", output_dir: str = TMPDIR,
+        data, clustering_data, image_filename_prefix, x_axis = None,
+        x_label: str = "", y_axis = None, y_label: str = "",
+        output_dir: str = TMPDIR,
         colorscale = (
             (0.0, '#3B3B3B'), (0.4999999999999999, '#ABABAB'),
             (0.5, '#F5DE11'), (1.0, '#FF0D00'))):
@@ -345,21 +350,22 @@ def generate_clustered_heatmap(
     Generate a dendrogram, and heatmaps for each chromosome, and put them all
     into one plot.
     """
-    num_cols = len(x_axis)
+    num_cols = 1 + len(x_axis)
     fig = make_subplots(
         rows=1,
         cols=num_cols,
         shared_yaxes="rows",
-        # horizontal_spacing=(1 / (num_cols - 1)),
-        subplot_titles=x_axis
-    )
+        horizontal_spacing=0.001,
+        subplot_titles=["distance"] + x_axis,
+        figure = ff.create_dendrogram(
+            np.array(clustering_data), orientation="right", labels=y_axis))
     hms = [go.Heatmap(
         name=chromo,
         y = y_axis,
         z = data_array,
         showscale=False) for chromo, data_array in zip(x_axis, data)]
-    for col, hm in enumerate(hms):
-        fig.add_trace(hm, row=1, col=(col + 1))
+    for i, hm in enumerate(hms):
+        fig.add_trace(hm, row=1, col=(i + 2))
 
     fig.update_traces(
         showlegend=False,