aboutsummaryrefslogtreecommitdiff
path: root/gn3/heatmaps.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/heatmaps.py')
-rw-r--r--gn3/heatmaps.py196
1 files changed, 74 insertions, 122 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index adbfbc6..bf9dfd1 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -14,6 +14,7 @@ from plotly.subplots import make_subplots # type: ignore
from gn3.settings import TMPDIR
from gn3.random import random_string
from gn3.computations.slink import slink
+from gn3.db.traits import export_trait_data
from gn3.computations.correlations2 import compute_correlation
from gn3.db.genotypes import (
build_genotype_file, load_genotype_samples)
@@ -26,72 +27,6 @@ from gn3.computations.qtlreaper import (
parse_reaper_main_results,
organise_reaper_main_results)
-def export_trait_data(
- trait_data: dict, samplelist: Sequence[str], dtype: str = "val",
- var_exists: bool = False, n_exists: bool = False):
- """
- Export data according to `samplelist`. Mostly used in calculating
- correlations.
-
- DESCRIPTION:
- Migrated from
- https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211
-
- PARAMETERS
- trait: (dict)
- The dictionary of key-value pairs representing a trait
- samplelist: (list)
- A list of sample names
- dtype: (str)
- ... verify what this is ...
- var_exists: (bool)
- A flag indicating existence of variance
- n_exists: (bool)
- A flag indicating existence of ndata
- """
- def __export_all_types(tdata, sample):
- sample_data = []
- if tdata[sample]["value"]:
- sample_data.append(tdata[sample]["value"])
- if var_exists:
- if tdata[sample]["variance"]:
- sample_data.append(tdata[sample]["variance"])
- else:
- sample_data.append(None)
- if n_exists:
- if tdata[sample]["ndata"]:
- sample_data.append(tdata[sample]["ndata"])
- else:
- sample_data.append(None)
- else:
- if var_exists and n_exists:
- sample_data += [None, None, None]
- elif var_exists or n_exists:
- sample_data += [None, None]
- else:
- sample_data.append(None)
-
- return tuple(sample_data)
-
- def __exporter(accumulator, sample):
- # pylint: disable=[R0911]
- if sample in trait_data["data"]:
- if dtype == "val":
- return accumulator + (trait_data["data"][sample]["value"], )
- if dtype == "var":
- return accumulator + (trait_data["data"][sample]["variance"], )
- if dtype == "N":
- return accumulator + (trait_data["data"][sample]["ndata"], )
- if dtype == "all":
- return accumulator + __export_all_types(trait_data["data"], sample)
- raise KeyError("Type `%s` is incorrect" % dtype)
- if var_exists and n_exists:
- return accumulator + (None, None, None)
- if var_exists or n_exists:
- return accumulator + (None, None)
- return accumulator + (None,)
-
- return reduce(__exporter, samplelist, tuple())
def trait_display_name(trait: Dict):
"""
@@ -168,7 +103,9 @@ def get_loci_names(
__get_trait_loci, [v[1] for v in organised.items()], {})
return tuple(loci_dict[_chr] for _chr in chromosome_names)
-def build_heatmap(traits_names, conn: Any):
+def build_heatmap(
+ traits_names: Sequence[str], conn: Any,
+ vertical: bool = False) -> go.Figure:
"""
heatmap function
@@ -220,17 +157,21 @@ def build_heatmap(traits_names, conn: Any):
zip(traits_ids,
[traits[idx]["trait_fullname"] for idx in traits_order]))
- return generate_clustered_heatmap(
+ return clustered_heatmap(
process_traits_data_for_heatmap(
organised, traits_ids, chromosome_names),
clustered,
- "single_heatmap_{}".format(random_string(10)),
- y_axis=tuple(
- ordered_traits_names[traits_ids[order]]
- for order in traits_order),
- y_label="Traits",
- x_axis=chromosome_names,
- x_label="Chromosomes",
+ x_axis={
+ "label": "Chromosomes",
+ "data": chromosome_names
+ },
+ y_axis={
+ "label": "Traits",
+ "data": tuple(
+ ordered_traits_names[traits_ids[order]]
+ for order in traits_order)
+ },
+ vertical=vertical,
loci_names=get_loci_names(organised, chromosome_names))
def compute_traits_order(slink_data, neworder: tuple = tuple()):
@@ -349,68 +290,81 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names):
for chr_name in chromosome_names]
return hdata
-def generate_clustered_heatmap(
- data, clustering_data, image_filename_prefix, x_axis=None,
- x_label: str = "", y_axis=None, y_label: str = "",
+def clustered_heatmap(
+ data: Sequence[Sequence[float]], clustering_data: Sequence[float],
+ x_axis,#: Dict[Union[str, int], Union[str, Sequence[str]]],
+ y_axis: Dict[str, Union[str, Sequence[str]]],
loci_names: Sequence[Sequence[str]] = tuple(),
- output_dir: str = TMPDIR,
- colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))):
+ vertical: bool = False,
+ colorscale: Sequence[Sequence[Union[float, str]]] = (
+ (0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))) -> go.Figure:
"""
Generate a dendrogram, and heatmaps for each chromosome, and put them all
into one plot.
"""
# pylint: disable=[R0913, R0914]
- num_cols = 1 + len(x_axis)
+ x_axis_data = x_axis["data"]
+ y_axis_data = y_axis["data"]
+ num_plots = 1 + len(x_axis_data)
fig = make_subplots(
- rows=1,
- cols=num_cols,
- shared_yaxes="rows",
+ rows=num_plots if vertical else 1,
+ cols=1 if vertical else num_plots,
+ shared_xaxes="columns" if vertical else False,
+ shared_yaxes=False if vertical else "rows",
+ vertical_spacing=0.010,
horizontal_spacing=0.001,
- subplot_titles=["distance"] + x_axis,
+ subplot_titles=["" if vertical else x_axis["label"]] + [
+ "Chromosome: {}".format(chromo) if vertical else chromo
+ for chromo in x_axis_data],#+ x_axis_data,
figure=ff.create_dendrogram(
- np.array(clustering_data), orientation="right", labels=y_axis))
+ np.array(clustering_data),
+ orientation="bottom" if vertical else "right",
+ labels=y_axis_data))
hms = [go.Heatmap(
name=chromo,
- x=loci,
- y=y_axis,
+ x=y_axis_data if vertical else loci,
+ y=loci if vertical else y_axis_data,
z=data_array,
+ transpose=vertical,
showscale=False)
for chromo, data_array, loci
- in zip(x_axis, data, loci_names)]
+ in zip(x_axis_data, data, loci_names)]
for i, heatmap in enumerate(hms):
- fig.add_trace(heatmap, row=1, col=(i + 2))
-
- fig.update_layout(
- {
- "width": 1500,
- "height": 800,
- "xaxis": {
+ fig.add_trace(
+ heatmap,
+ row=((i + 2) if vertical else 1),
+ col=(1 if vertical else (i + 2)))
+
+ axes_layouts = {
+ "{axis}axis{count}".format(
+ axis=("y" if vertical else "x"),
+ count=(i+1 if i > 0 else "")): {
"mirror": False,
- "showgrid": True,
- "title": x_label
- },
- "yaxis": {
- "title": y_label
+ "showticklabels": i == 0,
+ "ticks": "outside" if i == 0 else ""
}
- })
+ for i in range(num_plots)}
- x_axes_layouts = {
- "xaxis{}".format(i+1 if i > 0 else ""): {
- "mirror": False,
- "showticklabels": i == 0,
- "ticks": "outside" if i == 0 else ""
- }
- for i in range(num_cols)}
+ print("vertical?: {} ==> {}".format("T" if vertical else "F", axes_layouts))
- fig.update_layout(
- {
- "width": 4000,
- "height": 800,
- "yaxis": {
- "mirror": False,
- "ticks": ""
- },
- **x_axes_layouts})
+ fig.update_layout({
+ "width": 800 if vertical else 4000,
+ "height": 4000 if vertical else 800,
+ "{}axis".format("x" if vertical else "y"): {
+ "mirror": False,
+ "ticks": "",
+ "side": "top" if vertical else "left",
+ "title": y_axis["label"],
+ "tickangle": 90 if vertical else 0,
+ "ticklabelposition": "outside top" if vertical else "outside left"
+ },
+ "{}axis".format("y" if vertical else "x"): {
+ "mirror": False,
+ "showgrid": True,
+ "title": "Distance",
+ "side": "right" if vertical else "top"
+ },
+ **axes_layouts})
fig.update_traces(
showlegend=False,
colorscale=colorscale,
@@ -418,7 +372,5 @@ def generate_clustered_heatmap(
fig.update_traces(
showlegend=True,
showscale=True,
- selector={"name": x_axis[-1]})
- image_filename = "{}/{}.html".format(output_dir, image_filename_prefix)
- fig.write_html(image_filename)
- return image_filename, fig
+ selector={"name": x_axis_data[-1]})
+ return fig