aboutsummaryrefslogtreecommitdiff
path: root/gn3
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2021-10-04 12:03:42 +0300
committerBonfaceKilz2021-10-19 10:12:51 +0300
commitaeefaad0629ca29e81ac3f0dbe882d7bf09b8711 (patch)
tree95c40ad15b0fc77da90943c514ded76d503f1afd /gn3
parentf71c0d5b04a2bb504acf306be11705ae0515aa14 (diff)
downloadgenenetwork3-aeefaad0629ca29e81ac3f0dbe882d7bf09b8711.tar.gz
Enable vertical orientation of heatmaps
Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/non-clustered-heatmaps-and-flipping.gmi * Update the code to enable the generation of the heatmap in both the horizontal and vertical orientations.
Diffstat (limited to 'gn3')
-rw-r--r--gn3/heatmaps.py91
1 files changed, 52 insertions, 39 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index 7e7113d..ff65652 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -168,7 +168,9 @@ def get_loci_names(
__get_trait_loci, [v[1] for v in organised.items()], {})
return tuple(loci_dict[_chr] for _chr in chromosome_names)
-def build_heatmap(traits_names: Sequence[str], conn: Any) -> go.Figure:
+def build_heatmap(
+ traits_names: Sequence[str], conn: Any,
+ vertical: bool = False) -> go.Figure:
"""
heatmap function
@@ -234,6 +236,7 @@ def build_heatmap(traits_names: Sequence[str], conn: Any) -> go.Figure:
ordered_traits_names[traits_ids[order]]
for order in traits_order)
},
+ vertical=vertical,
loci_names=get_loci_names(organised, chromosome_names))
def compute_traits_order(slink_data, neworder: tuple = tuple()):
@@ -357,6 +360,7 @@ def clustered_heatmap(
x_axis,#: Dict[Union[str, int], Union[str, Sequence[str]]],
y_axis: Dict[str, Union[str, Sequence[str]]],
loci_names: Sequence[Sequence[str]] = tuple(),
+ vertical: bool = False,
colorscale: Sequence[Sequence[Union[float, str]]] = (
(0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))) -> go.Figure:
"""
@@ -366,57 +370,66 @@ def clustered_heatmap(
# pylint: disable=[R0913, R0914]
x_axis_data = x_axis["data"]
y_axis_data = y_axis["data"]
- num_cols = 1 + len(x_axis_data)
+ num_plots = 1 + len(x_axis_data)
fig = make_subplots(
- rows=1,
- cols=num_cols,
- shared_yaxes="rows",
+ rows=num_plots if vertical else 1,
+ cols=1 if vertical else num_plots,
+ shared_xaxes = "columns" if vertical else False,
+ shared_yaxes = False if vertical else "rows",
+ vertical_spacing=0.010,
horizontal_spacing=0.001,
- subplot_titles=[x_axis["label"]] + x_axis_data,
+ subplot_titles=["" if vertical else x_axis["label"]] + [
+ "Chromosome: {}".format(chromo) if vertical else chromo
+ for chromo in x_axis_data],#+ x_axis_data,
figure=ff.create_dendrogram(
- np.array(clustering_data), orientation="right", labels=y_axis_data))
+ np.array(clustering_data),
+ orientation="bottom" if vertical else "right",
+ labels=y_axis_data))
hms = [go.Heatmap(
name=chromo,
- x=loci,
- y=y_axis_data,
+ x=y_axis_data if vertical else loci,
+ y=loci if vertical else y_axis_data,
z=data_array,
+ transpose=vertical,
showscale=False)
for chromo, data_array, loci
in zip(x_axis_data, data, loci_names)]
for i, heatmap in enumerate(hms):
- fig.add_trace(heatmap, row=1, col=(i + 2))
-
- fig.update_layout(
- {
- "width": 1500,
- "height": 800,
- "xaxis": {
+ fig.add_trace(
+ heatmap,
+ row=((i + 2) if vertical else 1),
+ col=(1 if vertical else (i + 2)))
+
+ axes_layouts = {
+ "{axis}axis{count}".format(
+ axis=("y" if vertical else "x"),
+ count=(i+1 if i > 0 else "")): {
"mirror": False,
- "showgrid": True,
- "title": "Distance"
- },
- "yaxis": {
- "title": y_axis["label"]
- }
- })
-
- x_axes_layouts = {
- "xaxis{}".format(i+1 if i > 0 else ""): {
- "mirror": False,
- "showticklabels": i == 0,
- "ticks": "outside" if i == 0 else ""
+ "showticklabels": i == 0,
+ "ticks": "outside" if i == 0 else ""
}
- for i in range(num_cols)}
+ for i in range(num_plots)}
- fig.update_layout(
- {
- "width": 4000,
- "height": 800,
- "yaxis": {
- "mirror": False,
- "ticks": ""
- },
- **x_axes_layouts})
+ print("vertical?: {} ==> {}".format("T" if vertical else "F", axes_layouts))
+
+ fig.update_layout({
+ "width": 800 if vertical else 4000,
+ "height": 4000 if vertical else 800,
+ "{}axis".format("x" if vertical else "y"): {
+ "mirror": False,
+ "ticks": "",
+ "side": "top" if vertical else "left",
+ "title": y_axis["label"],
+ "tickangle": 90 if vertical else 0,
+ "ticklabelposition": "outside top" if vertical else "outside left"
+ },
+ "{}axis".format("y" if vertical else "x"): {
+ "mirror": False,
+ "showgrid": True,
+ "title": "Distance",
+ "side": "right" if vertical else "top"
+ },
+ **axes_layouts})
fig.update_traces(
showlegend=False,
colorscale=colorscale,