From e9fb4e45cfc52c5d86ef534b0e7f42ba8f4c84d3 Mon Sep 17 00:00:00 2001
From: Frederick Muriuki Muriithi
Date: Wed, 15 Sep 2021 12:28:56 +0300
Subject: Generate heatmaps in a single plot

Issue:
https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi

* Add a function to generate the heatmaps for each chromosome into a single
  plot.
---
 gn3/heatmaps.py | 65 ++++++++++++++++++++++++++++++++++++---------------------
 1 file changed, 41 insertions(+), 24 deletions(-)

diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index 0c00d6c..f3d7d25 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -4,6 +4,7 @@ generate various kinds of heatmaps.
 """
 
 from functools import reduce
+from gn3.settings import TMPDIR
 from typing import Any, Dict, Sequence
 from gn3.computations.slink import slink
 from gn3.computations.qtlreaper import generate_traits_file
@@ -296,27 +297,43 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names):
         for chr_name in chromosome_names]
     return hdata
 
-# # Grey + Blue + Red
-# def generate_heatmap():
-#     cols = 20
-#     y_axis = (["%s"%x for x in range(1, cols+1)][:-1] + ["X"]) #replace last item with x for now
-#     x_axis = heatmap_x_axis_names()
-#     data = generate_random_data(height=cols, width=len(x_axis))
-#     fig = px.imshow(
-#         data,
-#         x=x_axis,
-#         y=y_axis,
-#         width=500)
-#     fig.update_traces(xtype="array")
-#     fig.update_traces(ytype="array")
-#     # fig.update_traces(xgap=10)
-#     fig.update_xaxes(
-#         visible=True,
-#         title_text="Traits",
-#         title_font_size=16)
-#     fig.update_layout(
-#         coloraxis_colorscale=[
-#             [0.0, '#3B3B3B'], [0.4999999999999999, '#ABABAB'],
-#             [0.5, '#F5DE11'], [1.0, '#FF0D00']])
-#     fig.write_html("%s/%s"%(heatmap_dir, "test_image.html"))
-#     return fig
+def generate_clustered_heatmap(
+        data, image_filename_prefix, x_axis = None, x_label: str = "",
+        y_axis = None, y_label: str = "", output_dir: str = TMPDIR,
+        colorscale = [
+            [0.0, '#3B3B3B'], [0.4999999999999999, '#ABABAB'],
+            [0.5, '#F5DE11']], [1.0, '#FF0D00']):
+    """
+    Generate a dendrogram, and heatmaps for each chromosome, and put them all
+    into one plot.
+    """
+    num_cols = len(x_axis)
+    fig = make_subplots(
+        rows=1,
+        cols=num_cols,
+        shared_yaxes="rows",
+        # horizontal_spacing=(1 / (num_cols - 1)),
+        subplot_titles=x_axis
+    )
+    hms = [go.Heatmap(
+        name=chromo,
+        y = y_axis,
+        z = data_array,
+        showscale=False) for chromo, data_array in zip(x_axis, data)]
+    for col, hm in enumerate(hms):
+        fig.add_trace(hm, row=1, col=(col + 1))
+
+    fig.update_traces(
+        showlegend=False,
+        colorscale=colorscale,
+        selector={"type": "heatmap"})
+    fig.update_traces(
+        showlegend=True,
+        showscale=True,
+        selector={"name": x_axis[-1]})
+    fig.update_layout(
+        coloraxis_colorscale=colorscale
+    )
+    image_filename = "{}/{}.html".format(output_dir, image_filename_prefix)
+    fig.write_html(image_filename)
+    return image_filename, fig
-- 
cgit v1.2.3