about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/computations/heatmap.py42
-rw-r--r--gn3/db/traits.py5
2 files changed, 27 insertions, 20 deletions
diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py
index e0ff05b..92014cf 100644
--- a/gn3/computations/heatmap.py
+++ b/gn3/computations/heatmap.py
@@ -6,8 +6,12 @@ generate various kinds of heatmaps.
 from functools import reduce
 from typing import Any, Dict, Sequence
 from gn3.computations.slink import slink
-from gn3.db.traits import retrieve_trait_data, retrieve_trait_info
 from gn3.computations.correlations2 import compute_correlation
+from gn3.db.genotypes import build_genotype_file, load_genotype_samples
+from gn3.db.traits import (
+    retrieve_trait_data,
+    retrieve_trait_info,
+    generate_traits_filename)
 
 def export_trait_data(
         trait_data: dict, strainlist: Sequence[str], dtype: str = "val",
@@ -125,7 +129,7 @@ def cluster_traits(traits_data_list: Sequence[Dict]):
 
     return tuple(__cluster(tdata_i) for tdata_i in enumerate(traits_data_list))
 
-def heatmap_data(formd, search_result, conn: Any):
+def heatmap_data(traits_names, conn: Any):
     """
     heatmap function
 
@@ -142,39 +146,37 @@ def heatmap_data(formd, search_result, conn: Any):
     TODO: Elaborate on the parameters here...
     """
     threshold = 0 # webqtlConfig.PUBLICTHRESH
-    cluster_checked = formd.formdata.getvalue("clusterCheck", "")
-    strainlist = [
-        strain for strain in formd.strainlist if strain not in formd.parlist]
-    genotype = formd.genotype
-
     def __retrieve_traitlist_and_datalist(threshold, fullname):
         trait = retrieve_trait_info(threshold, fullname, conn)
         return (trait, retrieve_trait_data(trait, conn))
 
     traits_details = [
         __retrieve_traitlist_and_datalist(threshold, fullname)
-        for fullname in search_result]
+        for fullname in traits_names]
     traits_list = tuple(x[0] for x in traits_details)
     traits_data_list = [x[1] for x in traits_details]
     exported_traits_data_list = tuple(
         export_trait_data(td, strainlist) for td in traits_data_list)
+    genotype_filename = build_genotype_file(traits_list[0]["riset"])
+    strainlist = load_genotype_samples(genotype_filename)
+    slink_data = slink(cluster_traits(exported_traits_data_list))
+    ordering_data = compute_heatmap_order(slink_data)
+    strains_and_values = retrieve_strains_and_values(
+        orders, strainlist, exported_traits_data_list)
+    strains_values = strains_and_values[0][1]
+    trait_values = [t[2] for t in strains_and_values]
+    traits_filename = generate_traits_filename()
+    generate_traits_file(strains_values, trait_values, traits_filename)
 
     return {
-        "target_description_checked": formd.formdata.getvalue(
-            "targetDescriptionCheck", ""),
-        "cluster_checked": cluster_checked,
-        "slink_data": (
-            slink(cluster_traits(exported_traits_data_list))
-            if cluster_checked else False),
-        "sessionfile": formd.formdata.getvalue("session"),
-        "genotype": genotype,
-        "nLoci": sum(map(len, genotype)),
+        "slink_data": slink_data,
+        "ordering_data": ordering_data,
         "strainlist": strainlist,
-        "ppolar": formd.ppolar,
-        "mpolar":formd.mpolar,
+        "genotype_filename": genotype_filename,
         "traits_list": traits_list,
         "traits_data_list": traits_data_list,
-        "exported_traits_data_list": exported_traits_data_list
+        "exported_traits_data_list": exported_traits_data_list,
+        "traits_filename": traits_filename
     }
 
 def compute_heatmap_order(
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index 1031e44..ccb101a 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -1,4 +1,5 @@
 """This class contains functions relating to trait data manipulation"""
+from gn3.settings import TMPDIR
 from typing import Any, Dict, Union, Sequence
 from gn3.function_helpers import compose
 from gn3.db.datasets import retrieve_trait_dataset
@@ -666,3 +667,7 @@ def retrieve_trait_data(trait: dict, conn: Any, strainlist: Sequence[str] = tupl
                     {k:v for k, v in x.items() if x != "strain_name"}),
                 data))}
     return {}
+
+def generate_traits_filename(base_path: str = TMPDIR):
+    return "{}/traits_test_file_{}.txt".format(
+        os.path.abspath(base_path), random_string(10))