aboutsummaryrefslogtreecommitdiff
path: root/gn3/heatmaps.py
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2021-10-04 06:00:11 +0300
committerBonfaceKilz2021-10-19 10:12:51 +0300
commit0797e53220046d8f36a45d8b09b395b156d8fde7 (patch)
tree1a83439494860bc1235c285dec676e4752b7fe04 /gn3/heatmaps.py
parentfe39bccc23186d0a0f0b51a792d4577aaca88bd1 (diff)
downloadgenenetwork3-0797e53220046d8f36a45d8b09b395b156d8fde7.tar.gz
Add typing. Simplify arguments.
Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/non-clustered-heatmaps-and-flipping.gmi * Add type-hints to the functions * Merge the axis data and labels to simpler dict arguments to reduce number of parameters to the function.
Diffstat (limited to 'gn3/heatmaps.py')
-rw-r--r--gn3/heatmaps.py52
1 files changed, 28 insertions, 24 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index 42231bf..00f4353 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -168,7 +168,7 @@ def get_loci_names(
__get_trait_loci, [v[1] for v in organised.items()], {})
return tuple(loci_dict[_chr] for _chr in chromosome_names)
-def build_heatmap(traits_names, conn: Any):
+def build_heatmap(traits_names: Sequence[str], conn: Any) -> go.Figure:
"""
heatmap function
@@ -220,16 +220,20 @@ def build_heatmap(traits_names, conn: Any):
zip(traits_ids,
[traits[idx]["trait_fullname"] for idx in traits_order]))
- return generate_clustered_heatmap(
+ return clustered_heatmap(
process_traits_data_for_heatmap(
organised, traits_ids, chromosome_names),
clustered,
- y_axis=tuple(
- ordered_traits_names[traits_ids[order]]
- for order in traits_order),
- y_label="Traits",
- x_axis=chromosome_names,
- x_label="Chromosomes",
+ x_axis={
+ "label": "Chromosomes",
+ "data": chromosome_names
+ },
+ y_axis={
+ "label": "Traits",
+ "data": tuple(
+ ordered_traits_names[traits_ids[order]]
+ for order in traits_order)
+ },
loci_names=get_loci_names(organised, chromosome_names))
def compute_traits_order(slink_data, neworder: tuple = tuple()):
@@ -348,37 +352,37 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names):
for chr_name in chromosome_names]
return hdata
-def generate_clustered_heatmap(
- data, clustering_data, image_filename_prefix, x_axis=None,
- x_label: str = "", y_axis=None, y_label: str = "",
+def clustered_heatmap(
+ data: Sequence[Sequence[float]], clustering_data: Sequence[float],
+ x_axis,#: Dict[Union[str, int], Union[str, Sequence[str]]],
+ y_axis: Dict[str, Union[str, Sequence[str]]],
loci_names: Sequence[Sequence[str]] = tuple(),
- output_dir: str = TMPDIR,
- colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))):
- data, clustering_data, x_axis=None, x_label: str = "", y_axis=None,
- y_label: str = "", loci_names: Sequence[Sequence[str]] = tuple(),
- colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))):
+ colorscale: Sequence[Sequence[Union[float, str]]] = (
+ (0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))) -> go.Figure:
"""
Generate a dendrogram, and heatmaps for each chromosome, and put them all
into one plot.
"""
# pylint: disable=[R0913, R0914]
- num_cols = 1 + len(x_axis)
+ x_axis_data = x_axis["data"]
+ y_axis_data = y_axis["data"]
+ num_cols = 1 + len(x_axis_data)
fig = make_subplots(
rows=1,
cols=num_cols,
shared_yaxes="rows",
horizontal_spacing=0.001,
- subplot_titles=["distance"] + x_axis,
+ subplot_titles=["distance"] + x_axis_data,
figure=ff.create_dendrogram(
- np.array(clustering_data), orientation="right", labels=y_axis))
+ np.array(clustering_data), orientation="right", labels=y_axis_data))
hms = [go.Heatmap(
name=chromo,
x=loci,
- y=y_axis,
+ y=y_axis_data,
z=data_array,
showscale=False)
for chromo, data_array, loci
- in zip(x_axis, data, loci_names)]
+ in zip(x_axis_data, data, loci_names)]
for i, heatmap in enumerate(hms):
fig.add_trace(heatmap, row=1, col=(i + 2))
@@ -389,10 +393,10 @@ def generate_clustered_heatmap(
"xaxis": {
"mirror": False,
"showgrid": True,
- "title": x_label
+ "title": x_axis["label"]
},
"yaxis": {
- "title": y_label
+ "title": y_axis["label"]
}
})
@@ -420,5 +424,5 @@ def generate_clustered_heatmap(
fig.update_traces(
showlegend=True,
showscale=True,
- selector={"name": x_axis[-1]})
+ selector={"name": x_axis_data[-1]})
return fig