aboutsummaryrefslogtreecommitdiff
path: root/gn3/heatmaps.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/heatmaps.py')
-rw-r--r--gn3/heatmaps.py129
1 files changed, 73 insertions, 56 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index 3b94e88..bf9dfd1 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -103,7 +103,9 @@ def get_loci_names(
__get_trait_loci, [v[1] for v in organised.items()], {})
return tuple(loci_dict[_chr] for _chr in chromosome_names)
-def build_heatmap(traits_names, conn: Any):
+def build_heatmap(
+ traits_names: Sequence[str], conn: Any,
+ vertical: bool = False) -> go.Figure:
"""
heatmap function
@@ -155,17 +157,21 @@ def build_heatmap(traits_names, conn: Any):
zip(traits_ids,
[traits[idx]["trait_fullname"] for idx in traits_order]))
- return generate_clustered_heatmap(
+ return clustered_heatmap(
process_traits_data_for_heatmap(
organised, traits_ids, chromosome_names),
clustered,
- "single_heatmap_{}".format(random_string(10)),
- y_axis=tuple(
- ordered_traits_names[traits_ids[order]]
- for order in traits_order),
- y_label="Traits",
- x_axis=chromosome_names,
- x_label="Chromosomes",
+ x_axis={
+ "label": "Chromosomes",
+ "data": chromosome_names
+ },
+ y_axis={
+ "label": "Traits",
+ "data": tuple(
+ ordered_traits_names[traits_ids[order]]
+ for order in traits_order)
+ },
+ vertical=vertical,
loci_names=get_loci_names(organised, chromosome_names))
def compute_traits_order(slink_data, neworder: tuple = tuple()):
@@ -284,68 +290,81 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names):
for chr_name in chromosome_names]
return hdata
-def generate_clustered_heatmap(
- data, clustering_data, image_filename_prefix, x_axis=None,
- x_label: str = "", y_axis=None, y_label: str = "",
+def clustered_heatmap(
+ data: Sequence[Sequence[float]], clustering_data: Sequence[float],
+ x_axis,#: Dict[Union[str, int], Union[str, Sequence[str]]],
+ y_axis: Dict[str, Union[str, Sequence[str]]],
loci_names: Sequence[Sequence[str]] = tuple(),
- output_dir: str = TMPDIR,
- colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))):
+ vertical: bool = False,
+ colorscale: Sequence[Sequence[Union[float, str]]] = (
+ (0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))) -> go.Figure:
"""
Generate a dendrogram, and heatmaps for each chromosome, and put them all
into one plot.
"""
# pylint: disable=[R0913, R0914]
- num_cols = 1 + len(x_axis)
+ x_axis_data = x_axis["data"]
+ y_axis_data = y_axis["data"]
+ num_plots = 1 + len(x_axis_data)
fig = make_subplots(
- rows=1,
- cols=num_cols,
- shared_yaxes="rows",
+ rows=num_plots if vertical else 1,
+ cols=1 if vertical else num_plots,
+ shared_xaxes="columns" if vertical else False,
+ shared_yaxes=False if vertical else "rows",
+ vertical_spacing=0.010,
horizontal_spacing=0.001,
- subplot_titles=["distance"] + x_axis,
+ subplot_titles=["" if vertical else x_axis["label"]] + [
+ "Chromosome: {}".format(chromo) if vertical else chromo
+ for chromo in x_axis_data],#+ x_axis_data,
figure=ff.create_dendrogram(
- np.array(clustering_data), orientation="right", labels=y_axis))
+ np.array(clustering_data),
+ orientation="bottom" if vertical else "right",
+ labels=y_axis_data))
hms = [go.Heatmap(
name=chromo,
- x=loci,
- y=y_axis,
+ x=y_axis_data if vertical else loci,
+ y=loci if vertical else y_axis_data,
z=data_array,
+ transpose=vertical,
showscale=False)
for chromo, data_array, loci
- in zip(x_axis, data, loci_names)]
+ in zip(x_axis_data, data, loci_names)]
for i, heatmap in enumerate(hms):
- fig.add_trace(heatmap, row=1, col=(i + 2))
-
- fig.update_layout(
- {
- "width": 1500,
- "height": 800,
- "xaxis": {
+ fig.add_trace(
+ heatmap,
+ row=((i + 2) if vertical else 1),
+ col=(1 if vertical else (i + 2)))
+
+ axes_layouts = {
+ "{axis}axis{count}".format(
+ axis=("y" if vertical else "x"),
+ count=(i+1 if i > 0 else "")): {
"mirror": False,
- "showgrid": True,
- "title": x_label
- },
- "yaxis": {
- "title": y_label
+ "showticklabels": i == 0,
+ "ticks": "outside" if i == 0 else ""
}
- })
+ for i in range(num_plots)}
- x_axes_layouts = {
- "xaxis{}".format(i+1 if i > 0 else ""): {
- "mirror": False,
- "showticklabels": i == 0,
- "ticks": "outside" if i == 0 else ""
- }
- for i in range(num_cols)}
+ print("vertical?: {} ==> {}".format("T" if vertical else "F", axes_layouts))
- fig.update_layout(
- {
- "width": 4000,
- "height": 800,
- "yaxis": {
- "mirror": False,
- "ticks": ""
- },
- **x_axes_layouts})
+ fig.update_layout({
+ "width": 800 if vertical else 4000,
+ "height": 4000 if vertical else 800,
+ "{}axis".format("x" if vertical else "y"): {
+ "mirror": False,
+ "ticks": "",
+ "side": "top" if vertical else "left",
+ "title": y_axis["label"],
+ "tickangle": 90 if vertical else 0,
+ "ticklabelposition": "outside top" if vertical else "outside left"
+ },
+ "{}axis".format("y" if vertical else "x"): {
+ "mirror": False,
+ "showgrid": True,
+ "title": "Distance",
+ "side": "right" if vertical else "top"
+ },
+ **axes_layouts})
fig.update_traces(
showlegend=False,
colorscale=colorscale,
@@ -353,7 +372,5 @@ def generate_clustered_heatmap(
fig.update_traces(
showlegend=True,
showscale=True,
- selector={"name": x_axis[-1]})
- image_filename = "{}/{}.html".format(output_dir, image_filename_prefix)
- fig.write_html(image_filename)
- return image_filename, fig
+ selector={"name": x_axis_data[-1]})
+ return fig