about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/api/heatmaps.py6
-rw-r--r--gn3/heatmaps.py14
2 files changed, 13 insertions, 7 deletions
diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py
index 26c165f..1b04a95 100644
--- a/gn3/api/heatmaps.py
+++ b/gn3/api/heatmaps.py
@@ -31,7 +31,11 @@ def clustered_heatmaps():
         traits_fullnames = [parse_trait_fullname(trait) for trait in traits_names]
 
         with io.StringIO() as io_str:
-            figure = build_heatmap(traits_fullnames, conn, vertical=vertical)
+            figure = build_heatmap(conn,
+                                   traits_fullnames,
+                                   current_app.config["GENOTYPE_FILES"],
+                                   vertical=vertical,
+                                   current_app.config["TMPDIR"])
             figure.write_json(io_str)
             fig_json = io_str.getvalue()
         return fig_json, 200
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index 882a3c7..79c4082 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -2,7 +2,7 @@
 This module will contain functions to be used in computation of the data used to
 generate various kinds of heatmaps.
 """
-
+from pathlib import Path
 from functools import reduce
 from typing import Any, Dict, Union, Sequence
 
@@ -11,7 +11,6 @@ import plotly.graph_objects as go # type: ignore
 import plotly.figure_factory as ff # type: ignore
 from plotly.subplots import make_subplots # type: ignore
 
-from gn3.settings import TMPDIR
 from gn3.chancy import random_string
 from gn3.computations.slink import slink
 from gn3.db.traits import export_trait_data
@@ -99,8 +98,11 @@ def get_loci_names(
     return tuple(loci_dict[_chr] for _chr in chromosome_names)
 
 def build_heatmap(
-        traits_names: Sequence[str], conn: Any,
-        vertical: bool = False) -> go.Figure:
+        conn: Any,
+        traits_names: Sequence[str],
+        genotype_files: Union[str, Path],
+        vertical: bool = False,
+        tmpdir: Union[str, Path] = "/tmp") -> go.Figure:
     """
     heatmap function
 
@@ -122,7 +124,7 @@ def build_heatmap(
         retrieve_trait_info(threshold, fullname, conn)
         for fullname in traits_names]
     traits_data_list = [retrieve_trait_data(t, conn) for t in traits]
-    genotype_filename = build_genotype_file(traits[0]["group"])
+    genotype_filename = build_genotype_file(traits[0]["group"], genotype_files)
     samples = load_genotype_samples(genotype_filename)
     exported_traits_data_list = [
         export_trait_data(td, samples) for td in traits_data_list]
@@ -131,7 +133,7 @@ def build_heatmap(
     traits_order = compute_traits_order(slinked)
     samples_and_values = retrieve_samples_and_values(
         traits_order, samples, exported_traits_data_list)
-    traits_filename = f"{TMPDIR}/traits_test_file_{random_string(10)}.txt"
+    traits_filename = f"{tmpdir}/traits_test_file_{random_string(10)}.txt"
     generate_traits_file(
         samples_and_values[0][1],
         [t[2] for t in samples_and_values],