about summary refs log tree commit diff
path: root/gn3
diff options
context:
space:
mode:
Diffstat (limited to 'gn3')
-rw-r--r--gn3/heatmaps.py52
1 files changed, 28 insertions, 24 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index 42231bf..00f4353 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -168,7 +168,7 @@ def get_loci_names(
         __get_trait_loci, [v[1] for v in organised.items()], {})
     return tuple(loci_dict[_chr] for _chr in chromosome_names)
 
-def build_heatmap(traits_names, conn: Any):
+def build_heatmap(traits_names: Sequence[str], conn: Any) -> go.Figure:
     """
     heatmap function
 
@@ -220,16 +220,20 @@ def build_heatmap(traits_names, conn: Any):
         zip(traits_ids,
             [traits[idx]["trait_fullname"] for idx in traits_order]))
 
-    return generate_clustered_heatmap(
+    return clustered_heatmap(
         process_traits_data_for_heatmap(
             organised, traits_ids, chromosome_names),
         clustered,
-        y_axis=tuple(
-            ordered_traits_names[traits_ids[order]]
-            for order in traits_order),
-        y_label="Traits",
-        x_axis=chromosome_names,
-        x_label="Chromosomes",
+        x_axis={
+            "label": "Chromosomes",
+            "data": chromosome_names
+        },
+        y_axis={
+            "label": "Traits",
+            "data": tuple(
+                ordered_traits_names[traits_ids[order]]
+                for order in traits_order)
+        },
         loci_names=get_loci_names(organised, chromosome_names))
 
 def compute_traits_order(slink_data, neworder: tuple = tuple()):
@@ -348,37 +352,37 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names):
         for chr_name in chromosome_names]
     return hdata
 
-def generate_clustered_heatmap(
-        data, clustering_data, image_filename_prefix, x_axis=None,
-        x_label: str = "", y_axis=None, y_label: str = "",
+def clustered_heatmap(
+        data: Sequence[Sequence[float]], clustering_data: Sequence[float],
+        x_axis,#: Dict[Union[str, int], Union[str, Sequence[str]]],
+        y_axis: Dict[str, Union[str, Sequence[str]]],
         loci_names: Sequence[Sequence[str]] = tuple(),
-        output_dir: str = TMPDIR,
-        colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))):
-        data, clustering_data, x_axis=None, x_label: str = "", y_axis=None,
-        y_label: str = "", loci_names: Sequence[Sequence[str]] = tuple(),
-        colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))):
+        colorscale: Sequence[Sequence[Union[float, str]]] = (
+            (0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))) -> go.Figure:
     """
     Generate a dendrogram, and heatmaps for each chromosome, and put them all
     into one plot.
     """
     # pylint: disable=[R0913, R0914]
-    num_cols = 1 + len(x_axis)
+    x_axis_data = x_axis["data"]
+    y_axis_data = y_axis["data"]
+    num_cols = 1 + len(x_axis_data)
     fig = make_subplots(
         rows=1,
         cols=num_cols,
         shared_yaxes="rows",
         horizontal_spacing=0.001,
-        subplot_titles=["distance"] + x_axis,
+        subplot_titles=["distance"] + x_axis_data,
         figure=ff.create_dendrogram(
-            np.array(clustering_data), orientation="right", labels=y_axis))
+            np.array(clustering_data), orientation="right", labels=y_axis_data))
     hms = [go.Heatmap(
         name=chromo,
         x=loci,
-        y=y_axis,
+        y=y_axis_data,
         z=data_array,
         showscale=False)
            for chromo, data_array, loci
-           in zip(x_axis, data, loci_names)]
+           in zip(x_axis_data, data, loci_names)]
     for i, heatmap in enumerate(hms):
         fig.add_trace(heatmap, row=1, col=(i + 2))
 
@@ -389,10 +393,10 @@ def generate_clustered_heatmap(
             "xaxis": {
                 "mirror": False,
                 "showgrid": True,
-                "title": x_label
+                "title": x_axis["label"]
             },
             "yaxis": {
-                "title": y_label
+                "title": y_axis["label"]
             }
         })
 
@@ -420,5 +424,5 @@ def generate_clustered_heatmap(
     fig.update_traces(
         showlegend=True,
         showscale=True,
-        selector={"name": x_axis[-1]})
+        selector={"name": x_axis_data[-1]})
     return fig