aboutsummaryrefslogtreecommitdiff
path: root/gn3
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2021-09-17 10:30:16 +0300
committerFrederick Muriuki Muriithi2021-09-17 10:30:16 +0300
commit78871ef396f394c54072960e985476e418220fe3 (patch)
tree1a012f21154660c22d81c13e18400712a8176f31 /gn3
parent056171a0a2f127e90ab803b74635495fb0c079a2 (diff)
downloadgenenetwork3-78871ef396f394c54072960e985476e418220fe3.tar.gz
Create dendrogram to show clustering tree
Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Provide the clustering data to be used for the creation of the clustering dendrogram in the final clustered heatmap plot.
Diffstat (limited to 'gn3')
-rw-r--r--gn3/heatmaps.py24
1 files changed, 15 insertions, 9 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index 170b0cd..bf69d9b 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -3,9 +3,11 @@ This module will contain functions to be used in computation of the data used to
generate various kinds of heatmaps.
"""
+import numpy as np
from functools import reduce
from gn3.settings import TMPDIR
import plotly.graph_objects as go
+import plotly.figure_factory as ff
from gn3.random import random_string
from typing import Any, Dict, Sequence
from gn3.computations.slink import slink
@@ -167,7 +169,8 @@ def build_heatmap(traits_names, conn: Any):
strains = load_genotype_samples(genotype_filename)
exported_traits_data_list = [
export_trait_data(td, strains) for td in traits_data_list]
- slinked = slink(cluster_traits(exported_traits_data_list))
+ clustered = cluster_traits(exported_traits_data_list)
+ slinked = slink(clustered)
traits_order = compute_traits_order(slinked)
ordered_traits_names = [
traits[idx]["trait_fullname"] for idx in traits_order]
@@ -200,6 +203,7 @@ def build_heatmap(traits_names, conn: Any):
return generate_clustered_heatmap(
process_traits_data_for_heatmap(
organised, traits_ids, chromosome_names),
+ clustered,
"single_heatmap_{}".format(random_string(10)),
y_axis=tuple(
ordered_traits_names[traits_ids[order]]
@@ -336,8 +340,9 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names):
return hdata
def generate_clustered_heatmap(
- data, image_filename_prefix, x_axis = None, x_label: str = "",
- y_axis = None, y_label: str = "", output_dir: str = TMPDIR,
+ data, clustering_data, image_filename_prefix, x_axis = None,
+ x_label: str = "", y_axis = None, y_label: str = "",
+ output_dir: str = TMPDIR,
colorscale = (
(0.0, '#3B3B3B'), (0.4999999999999999, '#ABABAB'),
(0.5, '#F5DE11'), (1.0, '#FF0D00'))):
@@ -345,21 +350,22 @@ def generate_clustered_heatmap(
Generate a dendrogram, and heatmaps for each chromosome, and put them all
into one plot.
"""
- num_cols = len(x_axis)
+ num_cols = 1 + len(x_axis)
fig = make_subplots(
rows=1,
cols=num_cols,
shared_yaxes="rows",
- # horizontal_spacing=(1 / (num_cols - 1)),
- subplot_titles=x_axis
- )
+ horizontal_spacing=0.001,
+ subplot_titles=["distance"] + x_axis,
+ figure = ff.create_dendrogram(
+ np.array(clustering_data), orientation="right", labels=y_axis))
hms = [go.Heatmap(
name=chromo,
y = y_axis,
z = data_array,
showscale=False) for chromo, data_array in zip(x_axis, data)]
- for col, hm in enumerate(hms):
- fig.add_trace(hm, row=1, col=(col + 1))
+ for i, hm in enumerate(hms):
+ fig.add_trace(hm, row=1, col=(i + 2))
fig.update_traces(
showlegend=False,