diff options
Diffstat (limited to 'gn3')
-rw-r--r-- | gn3/heatmaps.py | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index 170b0cd..bf69d9b 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -3,9 +3,11 @@ This module will contain functions to be used in computation of the data used to generate various kinds of heatmaps. """ +import numpy as np from functools import reduce from gn3.settings import TMPDIR import plotly.graph_objects as go +import plotly.figure_factory as ff from gn3.random import random_string from typing import Any, Dict, Sequence from gn3.computations.slink import slink @@ -167,7 +169,8 @@ def build_heatmap(traits_names, conn: Any): strains = load_genotype_samples(genotype_filename) exported_traits_data_list = [ export_trait_data(td, strains) for td in traits_data_list] - slinked = slink(cluster_traits(exported_traits_data_list)) + clustered = cluster_traits(exported_traits_data_list) + slinked = slink(clustered) traits_order = compute_traits_order(slinked) ordered_traits_names = [ traits[idx]["trait_fullname"] for idx in traits_order] @@ -200,6 +203,7 @@ def build_heatmap(traits_names, conn: Any): return generate_clustered_heatmap( process_traits_data_for_heatmap( organised, traits_ids, chromosome_names), + clustered, "single_heatmap_{}".format(random_string(10)), y_axis=tuple( ordered_traits_names[traits_ids[order]] @@ -336,8 +340,9 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names): return hdata def generate_clustered_heatmap( - data, image_filename_prefix, x_axis = None, x_label: str = "", - y_axis = None, y_label: str = "", output_dir: str = TMPDIR, + data, clustering_data, image_filename_prefix, x_axis = None, + x_label: str = "", y_axis = None, y_label: str = "", + output_dir: str = TMPDIR, colorscale = ( (0.0, '#3B3B3B'), (0.4999999999999999, '#ABABAB'), (0.5, '#F5DE11'), (1.0, '#FF0D00'))): @@ -345,21 +350,22 @@ def generate_clustered_heatmap( Generate a dendrogram, and heatmaps for each chromosome, and put them all into one plot. """ - num_cols = len(x_axis) + num_cols = 1 + len(x_axis) fig = make_subplots( rows=1, cols=num_cols, shared_yaxes="rows", - # horizontal_spacing=(1 / (num_cols - 1)), - subplot_titles=x_axis - ) + horizontal_spacing=0.001, + subplot_titles=["distance"] + x_axis, + figure = ff.create_dendrogram( + np.array(clustering_data), orientation="right", labels=y_axis)) hms = [go.Heatmap( name=chromo, y = y_axis, z = data_array, showscale=False) for chromo, data_array in zip(x_axis, data)] - for col, hm in enumerate(hms): - fig.add_trace(hm, row=1, col=(col + 1)) + for i, hm in enumerate(hms): + fig.add_trace(hm, row=1, col=(i + 2)) fig.update_traces( showlegend=False, |