aboutsummaryrefslogtreecommitdiff
path: root/gn3/heatmaps.py
diff options
context:
space:
mode:
authorBonfaceKilz2021-09-27 13:59:17 +0300
committerGitHub2021-09-27 13:59:17 +0300
commit0cbb6ecde0315b7d6f021cb17406f5e5197e8a05 (patch)
tree0647dddf8b1aa4530476807bfa3a5dfd54a8119f /gn3/heatmaps.py
parent2cf220a11936125f059dc9b6a494d0f70eac068d (diff)
parenta9fc9814760d205674904f8feb700eadae480fb1 (diff)
downloadgenenetwork3-0cbb6ecde0315b7d6f021cb17406f5e5197e8a05.tar.gz
Merge pull request #37 from genenetwork/heatmap_generation
Heatmap generation
Diffstat (limited to 'gn3/heatmaps.py')
-rw-r--r--gn3/heatmaps.py395
1 files changed, 395 insertions, 0 deletions
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
new file mode 100644
index 0000000..a36940d
--- /dev/null
+++ b/gn3/heatmaps.py
@@ -0,0 +1,395 @@
+"""
+This module will contain functions to be used in computation of the data used to
+generate various kinds of heatmaps.
+"""
+
+from functools import reduce
+from typing import Any, Dict, Sequence
+
+import numpy as np
+import plotly.graph_objects as go # type: ignore
+import plotly.figure_factory as ff # type: ignore
+from plotly.subplots import make_subplots # type: ignore
+
+from gn3.settings import TMPDIR
+from gn3.random import random_string
+from gn3.computations.slink import slink
+from gn3.computations.correlations2 import compute_correlation
+from gn3.db.genotypes import (
+ build_genotype_file, load_genotype_samples)
+from gn3.db.traits import (
+ retrieve_trait_data, retrieve_trait_info)
+from gn3.computations.qtlreaper import (
+ run_reaper,
+ generate_traits_file,
+ chromosome_sorter_key_fn,
+ parse_reaper_main_results,
+ organise_reaper_main_results)
+
+def export_trait_data(
+ trait_data: dict, samplelist: Sequence[str], dtype: str = "val",
+ var_exists: bool = False, n_exists: bool = False):
+ """
+ Export data according to `samplelist`. Mostly used in calculating
+ correlations.
+
+ DESCRIPTION:
+ Migrated from
+ https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211
+
+ PARAMETERS
+ trait: (dict)
+ The dictionary of key-value pairs representing a trait
+ samplelist: (list)
+ A list of sample names
+ dtype: (str)
+ ... verify what this is ...
+ var_exists: (bool)
+ A flag indicating existence of variance
+ n_exists: (bool)
+ A flag indicating existence of ndata
+ """
+ def __export_all_types(tdata, sample):
+ sample_data = []
+ if tdata[sample]["value"]:
+ sample_data.append(tdata[sample]["value"])
+ if var_exists:
+ if tdata[sample]["variance"]:
+ sample_data.append(tdata[sample]["variance"])
+ else:
+ sample_data.append(None)
+ if n_exists:
+ if tdata[sample]["ndata"]:
+ sample_data.append(tdata[sample]["ndata"])
+ else:
+ sample_data.append(None)
+ else:
+ if var_exists and n_exists:
+ sample_data += [None, None, None]
+ elif var_exists or n_exists:
+ sample_data += [None, None]
+ else:
+ sample_data.append(None)
+
+ return tuple(sample_data)
+
+ def __exporter(accumulator, sample):
+ # pylint: disable=[R0911]
+ if sample in trait_data["data"]:
+ if dtype == "val":
+ return accumulator + (trait_data["data"][sample]["value"], )
+ if dtype == "var":
+ return accumulator + (trait_data["data"][sample]["variance"], )
+ if dtype == "N":
+ return accumulator + (trait_data["data"][sample]["ndata"], )
+ if dtype == "all":
+ return accumulator + __export_all_types(trait_data["data"], sample)
+ raise KeyError("Type `%s` is incorrect" % dtype)
+ if var_exists and n_exists:
+ return accumulator + (None, None, None)
+ if var_exists or n_exists:
+ return accumulator + (None, None)
+ return accumulator + (None,)
+
+ return reduce(__exporter, samplelist, tuple())
+
+def trait_display_name(trait: Dict):
+ """
+ Given a trait, return a name to use to display the trait on a heatmap.
+
+ DESCRIPTION
+ Migrated from
+ https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L141-L157
+ """
+ if trait.get("db", None) and trait.get("trait_name", None):
+ if trait["db"]["dataset_type"] == "Temp":
+ desc = trait["description"]
+ if desc.find("PCA") >= 0:
+ return "%s::%s" % (
+ trait["db"]["displayname"],
+ desc[desc.rindex(':')+1:].strip())
+ return "%s::%s" % (
+ trait["db"]["displayname"],
+ desc[:desc.index('entered')].strip())
+ prefix = "%s::%s" % (
+ trait["db"]["dataset_name"], trait["trait_name"])
+ if trait["cellid"]:
+ return "%s::%s" % (prefix, trait["cellid"])
+ return prefix
+ return trait["description"]
+
+def cluster_traits(traits_data_list: Sequence[Dict]):
+ """
+ Clusters the trait values.
+
+ DESCRIPTION
+ Attempts to replicate the clustering of the traits, as done at
+ https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L162
+ """
+ def __compute_corr(tdata_i, tdata_j):
+ if tdata_i[0] == tdata_j[0]:
+ return 0.0
+ corr_vals = compute_correlation(tdata_i[1], tdata_j[1])
+ corr = corr_vals[0]
+ if (1 - corr) < 0:
+ return 0.0
+ return 1 - corr
+
+ def __cluster(tdata_i):
+ return tuple(
+ __compute_corr(tdata_i, tdata_j)
+ for tdata_j in enumerate(traits_data_list))
+
+ return tuple(__cluster(tdata_i) for tdata_i in enumerate(traits_data_list))
+
+def build_heatmap(traits_names, conn: Any):
+ """
+ heatmap function
+
+ DESCRIPTION
+ This function is an attempt to reproduce the initialisation at
+ https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L46-L64
+ and also the clustering and slink computations at
+ https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L165
+ with the help of the `gn3.computations.heatmap.cluster_traits` function.
+
+ It does not try to actually draw the heatmap image.
+
+ PARAMETERS:
+ TODO: Elaborate on the parameters here...
+ """
+ # pylint: disable=[R0914]
+ threshold = 0 # webqtlConfig.PUBLICTHRESH
+ traits = [
+ retrieve_trait_info(threshold, fullname, conn)
+ for fullname in traits_names]
+ traits_data_list = [retrieve_trait_data(t, conn) for t in traits]
+ genotype_filename = build_genotype_file(traits[0]["group"])
+ samples = load_genotype_samples(genotype_filename)
+ exported_traits_data_list = [
+ export_trait_data(td, samples) for td in traits_data_list]
+ clustered = cluster_traits(exported_traits_data_list)
+ slinked = slink(clustered)
+ traits_order = compute_traits_order(slinked)
+ samples_and_values = retrieve_samples_and_values(
+ traits_order, samples, exported_traits_data_list)
+ traits_filename = "{}/traits_test_file_{}.txt".format(
+ TMPDIR, random_string(10))
+ generate_traits_file(
+ samples_and_values[0][1],
+ [t[2] for t in samples_and_values],
+ traits_filename)
+
+ main_output, _permutations_output = run_reaper(
+ genotype_filename, traits_filename, separate_nperm_output=True)
+
+ qtlresults = parse_reaper_main_results(main_output)
+ organised = organise_reaper_main_results(qtlresults)
+
+ traits_ids = [# sort numerically, but retain the ids as strings
+ str(i) for i in sorted({int(row["ID"]) for row in qtlresults})]
+ chromosome_names = sorted(
+ {row["Chr"] for row in qtlresults}, key=chromosome_sorter_key_fn)
+ ordered_traits_names = dict(
+ zip(traits_ids,
+ [traits[idx]["trait_fullname"] for idx in traits_order]))
+
+ return generate_clustered_heatmap(
+ process_traits_data_for_heatmap(
+ organised, traits_ids, chromosome_names),
+ clustered,
+ "single_heatmap_{}".format(random_string(10)),
+ y_axis=tuple(
+ ordered_traits_names[traits_ids[order]]
+ for order in traits_order),
+ y_label="Traits",
+ x_axis=chromosome_names,
+ x_label="Chromosomes")
+
+def compute_traits_order(slink_data, neworder: tuple = tuple()):
+ """
+ Compute the order of the traits for clustering from `slink_data`.
+
+ This function tries to reproduce the creation and update of the `neworder`
+ variable in
+ https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L120
+ and in the `web.webqtl.heatmap.Heatmap.draw` function in GN1
+ """
+ def __order_maker(norder, slnk_dt):
+ if isinstance(slnk_dt[0], int) and isinstance(slnk_dt[1], int):
+ return norder + (slnk_dt[0], slnk_dt[1])
+
+ if isinstance(slnk_dt[0], int):
+ return __order_maker((norder + (slnk_dt[0], )), slnk_dt[1])
+
+ if isinstance(slnk_dt[1], int):
+ return __order_maker(norder, slnk_dt[0]) + (slnk_dt[1], )
+
+ return __order_maker(__order_maker(norder, slnk_dt[0]), slnk_dt[1])
+
+ return __order_maker(neworder, slink_data)
+
+def retrieve_samples_and_values(orders, samplelist, traits_data_list):
+ """
+ Get the samples and their corresponding values from `samplelist` and
+ `traits_data_list`.
+
+ This migrates the code in
+ https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L215-221
+ """
+ # This feels nasty! There's a lot of mutation of values here, that might
+ # indicate something untoward in the design of this function and its
+ # dependents ==> Review
+ samples = []
+ values = []
+ rets = []
+ for order in orders:
+ temp_val = traits_data_list[order]
+ for i, sample in enumerate(samplelist):
+ if temp_val[i] is not None:
+ samples.append(sample)
+ values.append(temp_val[i])
+ rets.append([order, samples[:], values[:]])
+ samples = []
+ values = []
+
+ return rets
+
+def nearest_marker_finder(genotype):
+ """
+ Returns a function to be used with `genotype` to compute the nearest marker
+ to the trait passed to the returned function.
+
+ https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L425-434
+ """
+ def __compute_distances(chromo, trait):
+ loci = chromo.get("loci", None)
+ if not loci:
+ return None
+ return tuple(
+ {
+ "name": locus["name"],
+ "distance": abs(locus["Mb"] - trait["mb"])
+ } for locus in loci)
+
+ def __finder(trait):
+ _chrs = tuple(
+ _chr for _chr in genotype["chromosomes"]
+ if str(_chr["name"]) == str(trait["chr"]))
+ if len(_chrs) == 0:
+ return None
+ distances = tuple(
+ distance for dists in
+ filter(
+ lambda x: x is not None,
+ (__compute_distances(_chr, trait) for _chr in _chrs))
+ for distance in dists)
+ nearest = min(distances, key=lambda d: d["distance"])
+ return nearest["name"]
+ return __finder
+
+def get_nearest_marker(traits_list, genotype):
+ """
+ Retrieves the nearest marker for each of the traits in the list.
+
+ DESCRIPTION:
+ This migrates the code in
+ https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L419-L438
+ """
+ if not genotype["Mbmap"]:
+ return [None] * len(traits_list)
+
+ marker_finder = nearest_marker_finder(genotype)
+ return [marker_finder(trait) for trait in traits_list]
+
+def get_lrs_from_chr(trait, chr_name):
+ """
+ Retrieve the LRS values for a specific chromosome in the given trait.
+ """
+ chromosome = trait["chromosomes"].get(chr_name)
+ if chromosome:
+ return [
+ locus["LRS"] for locus in
+ sorted(chromosome["loci"], key=lambda loc: loc["Locus"])]
+ return [None]
+
+def process_traits_data_for_heatmap(data, trait_names, chromosome_names):
+ """
+ Process the traits data in a format useful for generating heatmap diagrams.
+ """
+ hdata = [
+ [get_lrs_from_chr(data[trait], chr_name) for trait in trait_names]
+ for chr_name in chromosome_names]
+ return hdata
+
+def generate_clustered_heatmap(
+ data, clustering_data, image_filename_prefix, x_axis=None,
+ x_label: str = "", y_axis=None, y_label: str = "",
+ output_dir: str = TMPDIR,
+ colorscale=(
+ (0.0, '#5D5D5D'), (0.4999999999999999, '#ABABAB'),
+ (0.5, '#F5DE11'), (1.0, '#FF0D00'))):
+ """
+ Generate a dendrogram, and heatmaps for each chromosome, and put them all
+ into one plot.
+ """
+ # pylint: disable=[R0913, R0914]
+ num_cols = 1 + len(x_axis)
+ fig = make_subplots(
+ rows=1,
+ cols=num_cols,
+ shared_yaxes="rows",
+ horizontal_spacing=0.001,
+ subplot_titles=["distance"] + x_axis,
+ figure=ff.create_dendrogram(
+ np.array(clustering_data), orientation="right", labels=y_axis))
+ hms = [go.Heatmap(
+ name=chromo,
+ y=y_axis,
+ z=data_array,
+ showscale=False) for chromo, data_array in zip(x_axis, data)]
+ for i, heatmap in enumerate(hms):
+ fig.add_trace(heatmap, row=1, col=(i + 2))
+
+ fig.update_layout(
+ {
+ "width": 1500,
+ "height": 800,
+ "xaxis": {
+ "mirror": False,
+ "showgrid": True,
+ "title": x_label
+ },
+ "yaxis": {
+ "title": y_label
+ }
+ })
+
+ x_axes_layouts = {
+ "xaxis{}".format(i+1 if i > 0 else ""): {
+ "mirror": False,
+ "showticklabels": i == 0,
+ "ticks": "outside" if i == 0 else ""
+ }
+ for i in range(num_cols)}
+
+ fig.update_layout(
+ {
+ "width": 4000,
+ "height": 800,
+ "yaxis": {
+ "mirror": False,
+ "ticks": ""
+ },
+ **x_axes_layouts})
+ fig.update_traces(
+ showlegend=False,
+ colorscale=colorscale,
+ selector={"type": "heatmap"})
+ fig.update_traces(
+ showlegend=True,
+ showscale=True,
+ selector={"name": x_axis[-1]})
+ image_filename = "{}/{}.html".format(output_dir, image_filename_prefix)
+ fig.write_html(image_filename)
+ return image_filename, fig