about summary refs log tree commit diff
path: root/gn3/heatmaps.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/heatmaps.py')
-rw-r--r--gn3/heatmaps.py196
1 files changed, 74 insertions, 122 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index adbfbc6..bf9dfd1 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -14,6 +14,7 @@ from plotly.subplots import make_subplots # type: ignore
 from gn3.settings import TMPDIR
 from gn3.random import random_string
 from gn3.computations.slink import slink
+from gn3.db.traits import export_trait_data
 from gn3.computations.correlations2 import compute_correlation
 from gn3.db.genotypes import (
     build_genotype_file, load_genotype_samples)
@@ -26,72 +27,6 @@ from gn3.computations.qtlreaper import (
     parse_reaper_main_results,
     organise_reaper_main_results)
 
-def export_trait_data(
-        trait_data: dict, samplelist: Sequence[str], dtype: str = "val",
-        var_exists: bool = False, n_exists: bool = False):
-    """
-    Export data according to `samplelist`. Mostly used in calculating
-    correlations.
-
-    DESCRIPTION:
-    Migrated from
-    https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211
-
-    PARAMETERS
-    trait: (dict)
-      The dictionary of key-value pairs representing a trait
-    samplelist: (list)
-      A list of sample names
-    dtype: (str)
-      ... verify what this is ...
-    var_exists: (bool)
-      A flag indicating existence of variance
-    n_exists: (bool)
-      A flag indicating existence of ndata
-    """
-    def __export_all_types(tdata, sample):
-        sample_data = []
-        if tdata[sample]["value"]:
-            sample_data.append(tdata[sample]["value"])
-            if var_exists:
-                if tdata[sample]["variance"]:
-                    sample_data.append(tdata[sample]["variance"])
-                else:
-                    sample_data.append(None)
-            if n_exists:
-                if tdata[sample]["ndata"]:
-                    sample_data.append(tdata[sample]["ndata"])
-                else:
-                    sample_data.append(None)
-        else:
-            if var_exists and n_exists:
-                sample_data += [None, None, None]
-            elif var_exists or n_exists:
-                sample_data += [None, None]
-            else:
-                sample_data.append(None)
-
-        return tuple(sample_data)
-
-    def __exporter(accumulator, sample):
-        # pylint: disable=[R0911]
-        if sample in trait_data["data"]:
-            if dtype == "val":
-                return accumulator + (trait_data["data"][sample]["value"], )
-            if dtype == "var":
-                return accumulator + (trait_data["data"][sample]["variance"], )
-            if dtype == "N":
-                return accumulator + (trait_data["data"][sample]["ndata"], )
-            if dtype == "all":
-                return accumulator + __export_all_types(trait_data["data"], sample)
-            raise KeyError("Type `%s` is incorrect" % dtype)
-        if var_exists and n_exists:
-            return accumulator + (None, None, None)
-        if var_exists or n_exists:
-            return accumulator + (None, None)
-        return accumulator + (None,)
-
-    return reduce(__exporter, samplelist, tuple())
 
 def trait_display_name(trait: Dict):
     """
@@ -168,7 +103,9 @@ def get_loci_names(
         __get_trait_loci, [v[1] for v in organised.items()], {})
     return tuple(loci_dict[_chr] for _chr in chromosome_names)
 
-def build_heatmap(traits_names, conn: Any):
+def build_heatmap(
+        traits_names: Sequence[str], conn: Any,
+        vertical: bool = False) -> go.Figure:
     """
     heatmap function
 
@@ -220,17 +157,21 @@ def build_heatmap(traits_names, conn: Any):
         zip(traits_ids,
             [traits[idx]["trait_fullname"] for idx in traits_order]))
 
-    return generate_clustered_heatmap(
+    return clustered_heatmap(
         process_traits_data_for_heatmap(
             organised, traits_ids, chromosome_names),
         clustered,
-        "single_heatmap_{}".format(random_string(10)),
-        y_axis=tuple(
-            ordered_traits_names[traits_ids[order]]
-            for order in traits_order),
-        y_label="Traits",
-        x_axis=chromosome_names,
-        x_label="Chromosomes",
+        x_axis={
+            "label": "Chromosomes",
+            "data": chromosome_names
+        },
+        y_axis={
+            "label": "Traits",
+            "data": tuple(
+                ordered_traits_names[traits_ids[order]]
+                for order in traits_order)
+        },
+        vertical=vertical,
         loci_names=get_loci_names(organised, chromosome_names))
 
 def compute_traits_order(slink_data, neworder: tuple = tuple()):
@@ -349,68 +290,81 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names):
         for chr_name in chromosome_names]
     return hdata
 
-def generate_clustered_heatmap(
-        data, clustering_data, image_filename_prefix, x_axis=None,
-        x_label: str = "", y_axis=None, y_label: str = "",
+def clustered_heatmap(
+        data: Sequence[Sequence[float]], clustering_data: Sequence[float],
+        x_axis,#: Dict[Union[str, int], Union[str, Sequence[str]]],
+        y_axis: Dict[str, Union[str, Sequence[str]]],
         loci_names: Sequence[Sequence[str]] = tuple(),
-        output_dir: str = TMPDIR,
-        colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))):
+        vertical: bool = False,
+        colorscale: Sequence[Sequence[Union[float, str]]] = (
+            (0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))) -> go.Figure:
     """
     Generate a dendrogram, and heatmaps for each chromosome, and put them all
     into one plot.
     """
     # pylint: disable=[R0913, R0914]
-    num_cols = 1 + len(x_axis)
+    x_axis_data = x_axis["data"]
+    y_axis_data = y_axis["data"]
+    num_plots = 1 + len(x_axis_data)
     fig = make_subplots(
-        rows=1,
-        cols=num_cols,
-        shared_yaxes="rows",
+        rows=num_plots if vertical else 1,
+        cols=1 if vertical else num_plots,
+        shared_xaxes="columns" if vertical else False,
+        shared_yaxes=False if vertical else "rows",
+        vertical_spacing=0.010,
         horizontal_spacing=0.001,
-        subplot_titles=["distance"] + x_axis,
+        subplot_titles=["" if vertical else x_axis["label"]] + [
+            "Chromosome: {}".format(chromo) if vertical else chromo
+            for chromo in x_axis_data],#+ x_axis_data,
         figure=ff.create_dendrogram(
-            np.array(clustering_data), orientation="right", labels=y_axis))
+            np.array(clustering_data),
+            orientation="bottom" if vertical else "right",
+            labels=y_axis_data))
     hms = [go.Heatmap(
         name=chromo,
-        x=loci,
-        y=y_axis,
+        x=y_axis_data if vertical else loci,
+        y=loci if vertical else y_axis_data,
         z=data_array,
+        transpose=vertical,
         showscale=False)
            for chromo, data_array, loci
-           in zip(x_axis, data, loci_names)]
+           in zip(x_axis_data, data, loci_names)]
     for i, heatmap in enumerate(hms):
-        fig.add_trace(heatmap, row=1, col=(i + 2))
-
-    fig.update_layout(
-        {
-            "width": 1500,
-            "height": 800,
-            "xaxis": {
+        fig.add_trace(
+            heatmap,
+            row=((i + 2) if vertical else 1),
+            col=(1 if vertical else (i + 2)))
+
+    axes_layouts = {
+        "{axis}axis{count}".format(
+            axis=("y" if vertical else "x"),
+            count=(i+1 if i > 0 else "")): {
                 "mirror": False,
-                "showgrid": True,
-                "title": x_label
-            },
-            "yaxis": {
-                "title": y_label
+                "showticklabels": i == 0,
+                "ticks": "outside" if i == 0 else ""
             }
-        })
+        for i in range(num_plots)}
 
-    x_axes_layouts = {
-        "xaxis{}".format(i+1 if i > 0 else ""): {
-            "mirror": False,
-            "showticklabels": i == 0,
-            "ticks": "outside" if i == 0 else ""
-        }
-        for i in range(num_cols)}
+    print("vertical?: {} ==> {}".format("T" if vertical else "F", axes_layouts))
 
-    fig.update_layout(
-        {
-            "width": 4000,
-            "height": 800,
-            "yaxis": {
-                "mirror": False,
-                "ticks": ""
-            },
-            **x_axes_layouts})
+    fig.update_layout({
+        "width": 800 if vertical else 4000,
+        "height": 4000 if vertical else 800,
+        "{}axis".format("x" if vertical else "y"): {
+            "mirror": False,
+            "ticks": "",
+            "side": "top" if vertical else "left",
+            "title": y_axis["label"],
+            "tickangle": 90 if vertical else 0,
+            "ticklabelposition": "outside top" if vertical else "outside left"
+        },
+        "{}axis".format("y" if vertical else "x"): {
+            "mirror": False,
+            "showgrid": True,
+            "title": "Distance",
+            "side": "right" if vertical else "top"
+        },
+        **axes_layouts})
     fig.update_traces(
         showlegend=False,
         colorscale=colorscale,
@@ -418,7 +372,5 @@ def generate_clustered_heatmap(
     fig.update_traces(
         showlegend=True,
         showscale=True,
-        selector={"name": x_axis[-1]})
-    image_filename = "{}/{}.html".format(output_dir, image_filename_prefix)
-    fig.write_html(image_filename)
-    return image_filename, fig
+        selector={"name": x_axis_data[-1]})
+    return fig