diff options
author | Frederick Muriuki Muriithi | 2021-10-04 06:00:11 +0300 |
---|---|---|
committer | BonfaceKilz | 2021-10-19 10:12:51 +0300 |
commit | 0797e53220046d8f36a45d8b09b395b156d8fde7 (patch) | |
tree | 1a83439494860bc1235c285dec676e4752b7fe04 /gn3/heatmaps.py | |
parent | fe39bccc23186d0a0f0b51a792d4577aaca88bd1 (diff) | |
download | genenetwork3-0797e53220046d8f36a45d8b09b395b156d8fde7.tar.gz |
Add typing. Simplify arguments.
Issue:
https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/non-clustered-heatmaps-and-flipping.gmi
* Add type-hints to the functions
* Merge the axis data and labels to simpler dict arguments to reduce number of
parameters to the function.
Diffstat (limited to 'gn3/heatmaps.py')
-rw-r--r-- | gn3/heatmaps.py | 52 |
1 files changed, 28 insertions, 24 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index 42231bf..00f4353 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -168,7 +168,7 @@ def get_loci_names( __get_trait_loci, [v[1] for v in organised.items()], {}) return tuple(loci_dict[_chr] for _chr in chromosome_names) -def build_heatmap(traits_names, conn: Any): +def build_heatmap(traits_names: Sequence[str], conn: Any) -> go.Figure: """ heatmap function @@ -220,16 +220,20 @@ def build_heatmap(traits_names, conn: Any): zip(traits_ids, [traits[idx]["trait_fullname"] for idx in traits_order])) - return generate_clustered_heatmap( + return clustered_heatmap( process_traits_data_for_heatmap( organised, traits_ids, chromosome_names), clustered, - y_axis=tuple( - ordered_traits_names[traits_ids[order]] - for order in traits_order), - y_label="Traits", - x_axis=chromosome_names, - x_label="Chromosomes", + x_axis={ + "label": "Chromosomes", + "data": chromosome_names + }, + y_axis={ + "label": "Traits", + "data": tuple( + ordered_traits_names[traits_ids[order]] + for order in traits_order) + }, loci_names=get_loci_names(organised, chromosome_names)) def compute_traits_order(slink_data, neworder: tuple = tuple()): @@ -348,37 +352,37 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names): for chr_name in chromosome_names] return hdata -def generate_clustered_heatmap( - data, clustering_data, image_filename_prefix, x_axis=None, - x_label: str = "", y_axis=None, y_label: str = "", +def clustered_heatmap( + data: Sequence[Sequence[float]], clustering_data: Sequence[float], + x_axis,#: Dict[Union[str, int], Union[str, Sequence[str]]], + y_axis: Dict[str, Union[str, Sequence[str]]], loci_names: Sequence[Sequence[str]] = tuple(), - output_dir: str = TMPDIR, - colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))): - data, clustering_data, x_axis=None, x_label: str = "", y_axis=None, - y_label: str = "", loci_names: Sequence[Sequence[str]] = tuple(), - colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))): + colorscale: Sequence[Sequence[Union[float, str]]] = ( + (0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))) -> go.Figure: """ Generate a dendrogram, and heatmaps for each chromosome, and put them all into one plot. """ # pylint: disable=[R0913, R0914] - num_cols = 1 + len(x_axis) + x_axis_data = x_axis["data"] + y_axis_data = y_axis["data"] + num_cols = 1 + len(x_axis_data) fig = make_subplots( rows=1, cols=num_cols, shared_yaxes="rows", horizontal_spacing=0.001, - subplot_titles=["distance"] + x_axis, + subplot_titles=["distance"] + x_axis_data, figure=ff.create_dendrogram( - np.array(clustering_data), orientation="right", labels=y_axis)) + np.array(clustering_data), orientation="right", labels=y_axis_data)) hms = [go.Heatmap( name=chromo, x=loci, - y=y_axis, + y=y_axis_data, z=data_array, showscale=False) for chromo, data_array, loci - in zip(x_axis, data, loci_names)] + in zip(x_axis_data, data, loci_names)] for i, heatmap in enumerate(hms): fig.add_trace(heatmap, row=1, col=(i + 2)) @@ -389,10 +393,10 @@ def generate_clustered_heatmap( "xaxis": { "mirror": False, "showgrid": True, - "title": x_label + "title": x_axis["label"] }, "yaxis": { - "title": y_label + "title": y_axis["label"] } }) @@ -420,5 +424,5 @@ def generate_clustered_heatmap( fig.update_traces( showlegend=True, showscale=True, - selector={"name": x_axis[-1]}) + selector={"name": x_axis_data[-1]}) return fig |