diff options
Diffstat (limited to 'gn3/heatmaps.py')
-rw-r--r-- | gn3/heatmaps.py | 196 |
1 files changed, 74 insertions, 122 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index adbfbc6..bf9dfd1 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -14,6 +14,7 @@ from plotly.subplots import make_subplots # type: ignore from gn3.settings import TMPDIR from gn3.random import random_string from gn3.computations.slink import slink +from gn3.db.traits import export_trait_data from gn3.computations.correlations2 import compute_correlation from gn3.db.genotypes import ( build_genotype_file, load_genotype_samples) @@ -26,72 +27,6 @@ from gn3.computations.qtlreaper import ( parse_reaper_main_results, organise_reaper_main_results) -def export_trait_data( - trait_data: dict, samplelist: Sequence[str], dtype: str = "val", - var_exists: bool = False, n_exists: bool = False): - """ - Export data according to `samplelist`. Mostly used in calculating - correlations. - - DESCRIPTION: - Migrated from - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211 - - PARAMETERS - trait: (dict) - The dictionary of key-value pairs representing a trait - samplelist: (list) - A list of sample names - dtype: (str) - ... verify what this is ... - var_exists: (bool) - A flag indicating existence of variance - n_exists: (bool) - A flag indicating existence of ndata - """ - def __export_all_types(tdata, sample): - sample_data = [] - if tdata[sample]["value"]: - sample_data.append(tdata[sample]["value"]) - if var_exists: - if tdata[sample]["variance"]: - sample_data.append(tdata[sample]["variance"]) - else: - sample_data.append(None) - if n_exists: - if tdata[sample]["ndata"]: - sample_data.append(tdata[sample]["ndata"]) - else: - sample_data.append(None) - else: - if var_exists and n_exists: - sample_data += [None, None, None] - elif var_exists or n_exists: - sample_data += [None, None] - else: - sample_data.append(None) - - return tuple(sample_data) - - def __exporter(accumulator, sample): - # pylint: disable=[R0911] - if sample in trait_data["data"]: - if dtype == "val": - return accumulator + (trait_data["data"][sample]["value"], ) - if dtype == "var": - return accumulator + (trait_data["data"][sample]["variance"], ) - if dtype == "N": - return accumulator + (trait_data["data"][sample]["ndata"], ) - if dtype == "all": - return accumulator + __export_all_types(trait_data["data"], sample) - raise KeyError("Type `%s` is incorrect" % dtype) - if var_exists and n_exists: - return accumulator + (None, None, None) - if var_exists or n_exists: - return accumulator + (None, None) - return accumulator + (None,) - - return reduce(__exporter, samplelist, tuple()) def trait_display_name(trait: Dict): """ @@ -168,7 +103,9 @@ def get_loci_names( __get_trait_loci, [v[1] for v in organised.items()], {}) return tuple(loci_dict[_chr] for _chr in chromosome_names) -def build_heatmap(traits_names, conn: Any): +def build_heatmap( + traits_names: Sequence[str], conn: Any, + vertical: bool = False) -> go.Figure: """ heatmap function @@ -220,17 +157,21 @@ def build_heatmap(traits_names, conn: Any): zip(traits_ids, [traits[idx]["trait_fullname"] for idx in traits_order])) - return generate_clustered_heatmap( + return clustered_heatmap( process_traits_data_for_heatmap( organised, traits_ids, chromosome_names), clustered, - "single_heatmap_{}".format(random_string(10)), - y_axis=tuple( - ordered_traits_names[traits_ids[order]] - for order in traits_order), - y_label="Traits", - x_axis=chromosome_names, - x_label="Chromosomes", + x_axis={ + "label": "Chromosomes", + "data": chromosome_names + }, + y_axis={ + "label": "Traits", + "data": tuple( + ordered_traits_names[traits_ids[order]] + for order in traits_order) + }, + vertical=vertical, loci_names=get_loci_names(organised, chromosome_names)) def compute_traits_order(slink_data, neworder: tuple = tuple()): @@ -349,68 +290,81 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names): for chr_name in chromosome_names] return hdata -def generate_clustered_heatmap( - data, clustering_data, image_filename_prefix, x_axis=None, - x_label: str = "", y_axis=None, y_label: str = "", +def clustered_heatmap( + data: Sequence[Sequence[float]], clustering_data: Sequence[float], + x_axis,#: Dict[Union[str, int], Union[str, Sequence[str]]], + y_axis: Dict[str, Union[str, Sequence[str]]], loci_names: Sequence[Sequence[str]] = tuple(), - output_dir: str = TMPDIR, - colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))): + vertical: bool = False, + colorscale: Sequence[Sequence[Union[float, str]]] = ( + (0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))) -> go.Figure: """ Generate a dendrogram, and heatmaps for each chromosome, and put them all into one plot. """ # pylint: disable=[R0913, R0914] - num_cols = 1 + len(x_axis) + x_axis_data = x_axis["data"] + y_axis_data = y_axis["data"] + num_plots = 1 + len(x_axis_data) fig = make_subplots( - rows=1, - cols=num_cols, - shared_yaxes="rows", + rows=num_plots if vertical else 1, + cols=1 if vertical else num_plots, + shared_xaxes="columns" if vertical else False, + shared_yaxes=False if vertical else "rows", + vertical_spacing=0.010, horizontal_spacing=0.001, - subplot_titles=["distance"] + x_axis, + subplot_titles=["" if vertical else x_axis["label"]] + [ + "Chromosome: {}".format(chromo) if vertical else chromo + for chromo in x_axis_data],#+ x_axis_data, figure=ff.create_dendrogram( - np.array(clustering_data), orientation="right", labels=y_axis)) + np.array(clustering_data), + orientation="bottom" if vertical else "right", + labels=y_axis_data)) hms = [go.Heatmap( name=chromo, - x=loci, - y=y_axis, + x=y_axis_data if vertical else loci, + y=loci if vertical else y_axis_data, z=data_array, + transpose=vertical, showscale=False) for chromo, data_array, loci - in zip(x_axis, data, loci_names)] + in zip(x_axis_data, data, loci_names)] for i, heatmap in enumerate(hms): - fig.add_trace(heatmap, row=1, col=(i + 2)) - - fig.update_layout( - { - "width": 1500, - "height": 800, - "xaxis": { + fig.add_trace( + heatmap, + row=((i + 2) if vertical else 1), + col=(1 if vertical else (i + 2))) + + axes_layouts = { + "{axis}axis{count}".format( + axis=("y" if vertical else "x"), + count=(i+1 if i > 0 else "")): { "mirror": False, - "showgrid": True, - "title": x_label - }, - "yaxis": { - "title": y_label + "showticklabels": i == 0, + "ticks": "outside" if i == 0 else "" } - }) + for i in range(num_plots)} - x_axes_layouts = { - "xaxis{}".format(i+1 if i > 0 else ""): { - "mirror": False, - "showticklabels": i == 0, - "ticks": "outside" if i == 0 else "" - } - for i in range(num_cols)} + print("vertical?: {} ==> {}".format("T" if vertical else "F", axes_layouts)) - fig.update_layout( - { - "width": 4000, - "height": 800, - "yaxis": { - "mirror": False, - "ticks": "" - }, - **x_axes_layouts}) + fig.update_layout({ + "width": 800 if vertical else 4000, + "height": 4000 if vertical else 800, + "{}axis".format("x" if vertical else "y"): { + "mirror": False, + "ticks": "", + "side": "top" if vertical else "left", + "title": y_axis["label"], + "tickangle": 90 if vertical else 0, + "ticklabelposition": "outside top" if vertical else "outside left" + }, + "{}axis".format("y" if vertical else "x"): { + "mirror": False, + "showgrid": True, + "title": "Distance", + "side": "right" if vertical else "top" + }, + **axes_layouts}) fig.update_traces( showlegend=False, colorscale=colorscale, @@ -418,7 +372,5 @@ def generate_clustered_heatmap( fig.update_traces( showlegend=True, showscale=True, - selector={"name": x_axis[-1]}) - image_filename = "{}/{}.html".format(output_dir, image_filename_prefix) - fig.write_html(image_filename) - return image_filename, fig + selector={"name": x_axis_data[-1]}) + return fig |