aboutsummaryrefslogtreecommitdiff
path: root/gn3
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2021-10-19 10:35:32 +0300
committerFrederick Muriuki Muriithi2021-10-19 10:35:32 +0300
commitd603df2b7a50167319c8e26e101e29cef55b3a7a (patch)
treee1b3c6a22a47834a50d6eaaf70544b4dce4a335e /gn3
parentefb9896464f969de4fe8fcaee21a19ac1d881fa2 (diff)
parent546b37e77c11c5268aa9510b9756f2ed4d60241d (diff)
downloadgenenetwork3-d603df2b7a50167319c8e26e101e29cef55b3a7a.tar.gz
Merge branch 'main' of github.com:genenetwork/genenetwork3 into partial-correlations
Diffstat (limited to 'gn3')
-rw-r--r--gn3/api/heatmaps.py6
-rw-r--r--gn3/app.py13
-rw-r--r--gn3/heatmaps.py129
-rw-r--r--gn3/settings.py12
4 files changed, 94 insertions, 66 deletions
diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py
index 62ca2ad..633a061 100644
--- a/gn3/api/heatmaps.py
+++ b/gn3/api/heatmaps.py
@@ -17,7 +17,9 @@ def clustered_heatmaps():
Parses the incoming data and responds with the JSON-serialized plotly figure
representing the clustered heatmap.
"""
- traits_names = request.get_json().get("traits_names", tuple())
+ heatmap_request = request.get_json()
+ traits_names = heatmap_request.get("traits_names", tuple())
+ vertical = heatmap_request.get("vertical", False)
if len(traits_names) < 2:
return jsonify({
"message": "You need to provide at least two trait names."
@@ -30,7 +32,7 @@ def clustered_heatmaps():
traits_fullnames = [parse_trait_fullname(trait) for trait in traits_names]
with io.StringIO() as io_str:
- _filename, figure = build_heatmap(traits_fullnames, conn)
+ figure = build_heatmap(traits_fullnames, conn, vertical=vertical)
figure.write_json(io_str)
fig_json = io_str.getvalue()
return fig_json, 200
diff --git a/gn3/app.py b/gn3/app.py
index a25332c..3d68b3f 100644
--- a/gn3/app.py
+++ b/gn3/app.py
@@ -21,12 +21,6 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask:
# Load default configuration
app.config.from_object("gn3.settings")
- CORS(
- app,
- origins=app.config["CORS_ORIGINS"],
- allow_headers=app.config["CORS_HEADERS"],
- supports_credentials=True, intercept_exceptions=False)
-
# Load environment configuration
if "GN3_CONF" in os.environ:
app.config.from_envvar('GN3_CONF')
@@ -37,6 +31,13 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask:
app.config.update(config)
elif config.endswith(".py"):
app.config.from_pyfile(config)
+
+ CORS(
+ app,
+ origins=app.config["CORS_ORIGINS"],
+ allow_headers=app.config["CORS_HEADERS"],
+ supports_credentials=True, intercept_exceptions=False)
+
app.register_blueprint(general, url_prefix="/api/")
app.register_blueprint(gemma, url_prefix="/api/gemma")
app.register_blueprint(rqtl, url_prefix="/api/rqtl")
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index 3b94e88..bf9dfd1 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -103,7 +103,9 @@ def get_loci_names(
__get_trait_loci, [v[1] for v in organised.items()], {})
return tuple(loci_dict[_chr] for _chr in chromosome_names)
-def build_heatmap(traits_names, conn: Any):
+def build_heatmap(
+ traits_names: Sequence[str], conn: Any,
+ vertical: bool = False) -> go.Figure:
"""
heatmap function
@@ -155,17 +157,21 @@ def build_heatmap(traits_names, conn: Any):
zip(traits_ids,
[traits[idx]["trait_fullname"] for idx in traits_order]))
- return generate_clustered_heatmap(
+ return clustered_heatmap(
process_traits_data_for_heatmap(
organised, traits_ids, chromosome_names),
clustered,
- "single_heatmap_{}".format(random_string(10)),
- y_axis=tuple(
- ordered_traits_names[traits_ids[order]]
- for order in traits_order),
- y_label="Traits",
- x_axis=chromosome_names,
- x_label="Chromosomes",
+ x_axis={
+ "label": "Chromosomes",
+ "data": chromosome_names
+ },
+ y_axis={
+ "label": "Traits",
+ "data": tuple(
+ ordered_traits_names[traits_ids[order]]
+ for order in traits_order)
+ },
+ vertical=vertical,
loci_names=get_loci_names(organised, chromosome_names))
def compute_traits_order(slink_data, neworder: tuple = tuple()):
@@ -284,68 +290,81 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names):
for chr_name in chromosome_names]
return hdata
-def generate_clustered_heatmap(
- data, clustering_data, image_filename_prefix, x_axis=None,
- x_label: str = "", y_axis=None, y_label: str = "",
+def clustered_heatmap(
+ data: Sequence[Sequence[float]], clustering_data: Sequence[float],
+ x_axis,#: Dict[Union[str, int], Union[str, Sequence[str]]],
+ y_axis: Dict[str, Union[str, Sequence[str]]],
loci_names: Sequence[Sequence[str]] = tuple(),
- output_dir: str = TMPDIR,
- colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))):
+ vertical: bool = False,
+ colorscale: Sequence[Sequence[Union[float, str]]] = (
+ (0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))) -> go.Figure:
"""
Generate a dendrogram, and heatmaps for each chromosome, and put them all
into one plot.
"""
# pylint: disable=[R0913, R0914]
- num_cols = 1 + len(x_axis)
+ x_axis_data = x_axis["data"]
+ y_axis_data = y_axis["data"]
+ num_plots = 1 + len(x_axis_data)
fig = make_subplots(
- rows=1,
- cols=num_cols,
- shared_yaxes="rows",
+ rows=num_plots if vertical else 1,
+ cols=1 if vertical else num_plots,
+ shared_xaxes="columns" if vertical else False,
+ shared_yaxes=False if vertical else "rows",
+ vertical_spacing=0.010,
horizontal_spacing=0.001,
- subplot_titles=["distance"] + x_axis,
+ subplot_titles=["" if vertical else x_axis["label"]] + [
+ "Chromosome: {}".format(chromo) if vertical else chromo
+ for chromo in x_axis_data],#+ x_axis_data,
figure=ff.create_dendrogram(
- np.array(clustering_data), orientation="right", labels=y_axis))
+ np.array(clustering_data),
+ orientation="bottom" if vertical else "right",
+ labels=y_axis_data))
hms = [go.Heatmap(
name=chromo,
- x=loci,
- y=y_axis,
+ x=y_axis_data if vertical else loci,
+ y=loci if vertical else y_axis_data,
z=data_array,
+ transpose=vertical,
showscale=False)
for chromo, data_array, loci
- in zip(x_axis, data, loci_names)]
+ in zip(x_axis_data, data, loci_names)]
for i, heatmap in enumerate(hms):
- fig.add_trace(heatmap, row=1, col=(i + 2))
-
- fig.update_layout(
- {
- "width": 1500,
- "height": 800,
- "xaxis": {
+ fig.add_trace(
+ heatmap,
+ row=((i + 2) if vertical else 1),
+ col=(1 if vertical else (i + 2)))
+
+ axes_layouts = {
+ "{axis}axis{count}".format(
+ axis=("y" if vertical else "x"),
+ count=(i+1 if i > 0 else "")): {
"mirror": False,
- "showgrid": True,
- "title": x_label
- },
- "yaxis": {
- "title": y_label
+ "showticklabels": i == 0,
+ "ticks": "outside" if i == 0 else ""
}
- })
+ for i in range(num_plots)}
- x_axes_layouts = {
- "xaxis{}".format(i+1 if i > 0 else ""): {
- "mirror": False,
- "showticklabels": i == 0,
- "ticks": "outside" if i == 0 else ""
- }
- for i in range(num_cols)}
+ print("vertical?: {} ==> {}".format("T" if vertical else "F", axes_layouts))
- fig.update_layout(
- {
- "width": 4000,
- "height": 800,
- "yaxis": {
- "mirror": False,
- "ticks": ""
- },
- **x_axes_layouts})
+ fig.update_layout({
+ "width": 800 if vertical else 4000,
+ "height": 4000 if vertical else 800,
+ "{}axis".format("x" if vertical else "y"): {
+ "mirror": False,
+ "ticks": "",
+ "side": "top" if vertical else "left",
+ "title": y_axis["label"],
+ "tickangle": 90 if vertical else 0,
+ "ticklabelposition": "outside top" if vertical else "outside left"
+ },
+ "{}axis".format("y" if vertical else "x"): {
+ "mirror": False,
+ "showgrid": True,
+ "title": "Distance",
+ "side": "right" if vertical else "top"
+ },
+ **axes_layouts})
fig.update_traces(
showlegend=False,
colorscale=colorscale,
@@ -353,7 +372,5 @@ def generate_clustered_heatmap(
fig.update_traces(
showlegend=True,
showscale=True,
- selector={"name": x_axis[-1]})
- image_filename = "{}/{}.html".format(output_dir, image_filename_prefix)
- fig.write_html(image_filename)
- return image_filename, fig
+ selector={"name": x_axis_data[-1]})
+ return fig
diff --git a/gn3/settings.py b/gn3/settings.py
index 150d96d..d5f1d3c 100644
--- a/gn3/settings.py
+++ b/gn3/settings.py
@@ -35,10 +35,18 @@ GENOTYPE_FILES = os.environ.get(
"GENOTYPE_FILES", "{}/genotype_files/genotype".format(os.environ.get("HOME")))
# CROSS-ORIGIN SETUP
-CORS_ORIGINS = [
+def parse_env_cors(default):
+ """Parse comma-separated configuration into list of strings."""
+ origins_str = os.environ.get("CORS_ORIGINS", None)
+ if origins_str:
+ return [
+ origin.strip() for origin in origins_str.split(",") if origin != ""]
+ return default
+
+CORS_ORIGINS = parse_env_cors([
"http://localhost:*",
"http://127.0.0.1:*"
-]
+])
CORS_HEADERS = [
"Content-Type",