From e3e18950cfcdec918429dcbb5d5ed2e9616b7a20 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Wed, 15 Sep 2021 11:19:56 +0300 Subject: Reorganise modules Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * The heatmap generation does not fall cleanly within the computations or db modules. This commit moves it to the higher level gn3 module. --- gn3/heatmaps.py | 302 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 302 insertions(+) create mode 100644 gn3/heatmaps.py (limited to 'gn3/heatmaps.py') diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py new file mode 100644 index 0000000..198fb45 --- /dev/null +++ b/gn3/heatmaps.py @@ -0,0 +1,302 @@ +""" +This module will contain functions to be used in computation of the data used to +generate various kinds of heatmaps. +""" + +from functools import reduce +from typing import Any, Dict, Sequence +from gn3.computations.slink import slink +from gn3.computations.qtlreaper import generate_traits_file +from gn3.computations.correlations2 import compute_correlation +from gn3.db.genotypes import build_genotype_file, load_genotype_samples +from gn3.db.traits import ( + retrieve_trait_data, + retrieve_trait_info, + generate_traits_filename) + +def export_trait_data( + trait_data: dict, strainlist: Sequence[str], dtype: str = "val", + var_exists: bool = False, n_exists: bool = False): + """ + Export data according to `strainlist`. Mostly used in calculating + correlations. + + DESCRIPTION: + Migrated from + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211 + + PARAMETERS + trait: (dict) + The dictionary of key-value pairs representing a trait + strainlist: (list) + A list of strain names + dtype: (str) + ... verify what this is ... + var_exists: (bool) + A flag indicating existence of variance + n_exists: (bool) + A flag indicating existence of ndata + """ + def __export_all_types(tdata, strain): + sample_data = [] + if tdata[strain]["value"]: + sample_data.append(tdata[strain]["value"]) + if var_exists: + if tdata[strain]["variance"]: + sample_data.append(tdata[strain]["variance"]) + else: + sample_data.append(None) + if n_exists: + if tdata[strain]["ndata"]: + sample_data.append(tdata[strain]["ndata"]) + else: + sample_data.append(None) + else: + if var_exists and n_exists: + sample_data += [None, None, None] + elif var_exists or n_exists: + sample_data += [None, None] + else: + sample_data.append(None) + + return tuple(sample_data) + + def __exporter(accumulator, strain): + # pylint: disable=[R0911] + if strain in trait_data["data"]: + if dtype == "val": + return accumulator + (trait_data["data"][strain]["value"], ) + if dtype == "var": + return accumulator + (trait_data["data"][strain]["variance"], ) + if dtype == "N": + return accumulator + (trait_data["data"][strain]["ndata"], ) + if dtype == "all": + return accumulator + __export_all_types(trait_data["data"], strain) + raise KeyError("Type `%s` is incorrect" % dtype) + if var_exists and n_exists: + return accumulator + (None, None, None) + if var_exists or n_exists: + return accumulator + (None, None) + return accumulator + (None,) + + return reduce(__exporter, strainlist, tuple()) + +def trait_display_name(trait: Dict): + """ + Given a trait, return a name to use to display the trait on a heatmap. + + DESCRIPTION + Migrated from + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L141-L157 + """ + if trait.get("db", None) and trait.get("trait_name", None): + if trait["db"]["dataset_type"] == "Temp": + desc = trait["description"] + if desc.find("PCA") >= 0: + return "%s::%s" % ( + trait["db"]["displayname"], + desc[desc.rindex(':')+1:].strip()) + return "%s::%s" % ( + trait["db"]["displayname"], + desc[:desc.index('entered')].strip()) + prefix = "%s::%s" % ( + trait["db"]["dataset_name"], trait["trait_name"]) + if trait["cellid"]: + return "%s::%s" % (prefix, trait["cellid"]) + return prefix + return trait["description"] + +def cluster_traits(traits_data_list: Sequence[Dict]): + """ + Clusters the trait values. + + DESCRIPTION + Attempts to replicate the clustering of the traits, as done at + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L162 + """ + def __compute_corr(tdata_i, tdata_j): + if tdata_i[0] == tdata_j[0]: + return 0.0 + corr_vals = compute_correlation(tdata_i[1], tdata_j[1]) + corr = corr_vals[0] + if (1 - corr) < 0: + return 0.0 + return 1 - corr + + def __cluster(tdata_i): + return tuple( + __compute_corr(tdata_i, tdata_j) + for tdata_j in enumerate(traits_data_list)) + + return tuple(__cluster(tdata_i) for tdata_i in enumerate(traits_data_list)) + +def heatmap_data(traits_names, conn: Any): + """ + heatmap function + + DESCRIPTION + This function is an attempt to reproduce the initialisation at + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L46-L64 + and also the clustering and slink computations at + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L165 + with the help of the `gn3.computations.heatmap.cluster_traits` function. + + It does not try to actually draw the heatmap image. + + PARAMETERS: + TODO: Elaborate on the parameters here... + """ + threshold = 0 # webqtlConfig.PUBLICTHRESH + def __retrieve_traitlist_and_datalist(threshold, fullname): + trait = retrieve_trait_info(threshold, fullname, conn) + return (trait, retrieve_trait_data(trait, conn)) + + traits_details = [ + __retrieve_traitlist_and_datalist(threshold, fullname) + for fullname in traits_names] + traits_list = tuple(x[0] for x in traits_details) + traits_data_list = [x[1] for x in traits_details] + genotype_filename = build_genotype_file(traits_list[0]["riset"]) + strainlist = load_genotype_samples(genotype_filename) + exported_traits_data_list = tuple( + export_trait_data(td, strainlist) for td in traits_data_list) + slink_data = slink(cluster_traits(exported_traits_data_list)) + ordering_data = compute_heatmap_order(slink_data) + strains_and_values = retrieve_strains_and_values( + ordering_data, strainlist, exported_traits_data_list) + strains_values = strains_and_values[0][1] + trait_values = [t[2] for t in strains_and_values] + traits_filename = generate_traits_filename() + generate_traits_file(strains_values, trait_values, traits_filename) + + return { + "slink_data": slink_data, + "ordering_data": ordering_data, + "strainlist": strainlist, + "genotype_filename": genotype_filename, + "traits_list": traits_list, + "traits_data_list": traits_data_list, + "exported_traits_data_list": exported_traits_data_list, + "traits_filename": traits_filename + } + +def compute_traits_order(slink_data, neworder: tuple = tuple()): + """ + Compute the order of the traits for clustering from `slink_data`. + + This function tries to reproduce the creation and update of the `neworder` + variable in + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L120 + and in the `web.webqtl.heatmap.Heatmap.draw` function in GN1 + """ + def __order_maker(norder, slnk_dt): + if isinstance(slnk_dt[0], int) and isinstance(slnk_dt[1], int): + return norder + (slnk_dt[0], slnk_dt[1]) + + if isinstance(slnk_dt[0], int): + return __order_maker((norder + (slnk_dt[0], )), slnk_dt[1]) + + if isinstance(slnk_dt[1], int): + return __order_maker(norder, slnk_dt[0]) + (slnk_dt[1], ) + + return __order_maker(__order_maker(norder, slnk_dt[0]), slnk_dt[1]) + + return __order_maker(neworder, slink_data) + +def retrieve_strains_and_values(orders, strainlist, traits_data_list): + """ + Get the strains and their corresponding values from `strainlist` and + `traits_data_list`. + + This migrates the code in + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L215-221 + """ + # This feels nasty! There's a lot of mutation of values here, that might + # indicate something untoward in the design of this function and its + # dependents ==> Review + strains = [] + values = [] + rets = [] + for order in orders: + temp_val = traits_data_list[order] + for i, strain in enumerate(strainlist): + if temp_val[i] is not None: + strains.append(strain) + values.append(temp_val[i]) + rets.append([order, strains[:], values[:]]) + strains = [] + values = [] + + return rets + +def nearest_marker_finder(genotype): + """ + Returns a function to be used with `genotype` to compute the nearest marker + to the trait passed to the returned function. + + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L425-434 + """ + def __compute_distances(chromo, trait): + loci = chromo.get("loci", None) + if not loci: + return None + return tuple( + { + "name": locus["name"], + "distance": abs(locus["Mb"] - trait["mb"]) + } for locus in loci) + + def __finder(trait): + _chrs = tuple( + _chr for _chr in genotype["chromosomes"] + if str(_chr["name"]) == str(trait["chr"])) + if len(_chrs) == 0: + return None + distances = tuple( + distance for dists in + filter( + lambda x: x is not None, + (__compute_distances(_chr, trait) for _chr in _chrs)) + for distance in dists) + nearest = min(distances, key=lambda d: d["distance"]) + return nearest["name"] + return __finder + +def get_nearest_marker(traits_list, genotype): + """ + Retrieves the nearest marker for each of the traits in the list. + + DESCRIPTION: + This migrates the code in + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L419-L438 + """ + if not genotype["Mbmap"]: + return [None] * len(trait_list) + + marker_finder = nearest_marker_finder(genotype) + return [marker_finder(trait) for trait in traits_list] + +# # Grey + Blue + Red +# def generate_heatmap(): +# cols = 20 +# y_axis = (["%s"%x for x in range(1, cols+1)][:-1] + ["X"]) #replace last item with x for now +# x_axis = heatmap_x_axis_names() +# data = generate_random_data(height=cols, width=len(x_axis)) +# fig = px.imshow( +# data, +# x=x_axis, +# y=y_axis, +# width=500) +# fig.update_traces(xtype="array") +# fig.update_traces(ytype="array") +# # fig.update_traces(xgap=10) +# fig.update_xaxes( +# visible=True, +# title_text="Traits", +# title_font_size=16) +# fig.update_layout( +# coloraxis_colorscale=[ +# [0.0, '#3B3B3B'], [0.4999999999999999, '#ABABAB'], +# [0.5, '#F5DE11'], [1.0, '#FF0D00']]) +# fig.write_html("%s/%s"%(heatmap_dir, "test_image.html")) +# return fig -- cgit v1.2.3 From b1eb0451578c53afabe4f2054ce08665dec4bb82 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Wed, 15 Sep 2021 11:41:36 +0300 Subject: Integrate get_lsr_from_chr function * gn3/heatmaps.py: copy over function * tests/unit/test_heatmaps.py: add tests Copy function over from proof of concept and add some tests to ensure it works as expected. --- gn3/heatmaps.py | 8 ++++++++ tests/unit/test_heatmaps.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+) (limited to 'gn3/heatmaps.py') diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index 198fb45..991ddec 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -276,6 +276,14 @@ def get_nearest_marker(traits_list, genotype): marker_finder = nearest_marker_finder(genotype) return [marker_finder(trait) for trait in traits_list] +def get_lrs_from_chr(trait, chr_name): + chromosome = trait["chromosomes"].get(chr_name) + if chromosome: + return [ + locus["LRS"] for locus in + sorted(chromosome["loci"], key=lambda loc: loc["Locus"])] + return [None] + # # Grey + Blue + Red # def generate_heatmap(): # cols = 20 diff --git a/tests/unit/test_heatmaps.py b/tests/unit/test_heatmaps.py index 265d5a8..cfdde1e 100644 --- a/tests/unit/test_heatmaps.py +++ b/tests/unit/test_heatmaps.py @@ -2,6 +2,7 @@ from unittest import TestCase from gn3.heatmaps import ( cluster_traits, + get_lrs_from_chr, export_trait_data, compute_traits_order, retrieve_strains_and_values) @@ -185,3 +186,16 @@ class TestHeatmap(TestCase): with self.subTest(strainlist=slist, traitdata=tdata): self.assertEqual( retrieve_strains_and_values(orders, slist, tdata), expected) + + def test_get_lrs_from_chr(self): + for trait, chromosome, expected in [ + [{"chromosomes": {}}, 3, [None]], + [{"chromosomes": {3: {"loci": [ + {"Locus": "b", "LRS": 1.9}, + {"Locus": "a", "LRS": 13.2}, + {"Locus": "d", "LRS": 53.21}, + {"Locus": "c", "LRS": 2.22}]}}}, + 3, + [13.2, 1.9, 2.22, 53.21]]]: + with self.subTest(trait=trait, chromosome=chromosome): + self.assertEqual(get_lrs_from_chr(trait, chromosome), expected) -- cgit v1.2.3 From 11632a565a6f901eca852a5a40a6f9fd3170152a Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Wed, 15 Sep 2021 12:08:56 +0300 Subject: Process data into format usable by heatmaps * gn3/heatmaps.py: implement `process_traits_data_for_heatmap` function, that will process the data into a form usable by heatmaps. * tests/unit/test_heatmaps.py: check that the function processes the data into the correct form. --- gn3/heatmaps.py | 12 +++++ tests/unit/test_heatmaps.py | 107 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 118 insertions(+), 1 deletion(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index 991ddec..0c00d6c 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -277,6 +277,9 @@ def get_nearest_marker(traits_list, genotype): return [marker_finder(trait) for trait in traits_list] def get_lrs_from_chr(trait, chr_name): + """ + Retrieve the LRS values for a specific chromosome in the given trait. + """ chromosome = trait["chromosomes"].get(chr_name) if chromosome: return [ @@ -284,6 +287,15 @@ def get_lrs_from_chr(trait, chr_name): sorted(chromosome["loci"], key=lambda loc: loc["Locus"])] return [None] +def process_traits_data_for_heatmap(data, trait_names, chromosome_names): + """ + Process the traits data in a format useful for generating heatmap diagrams. + """ + hdata = [ + [get_lrs_from_chr(data[trait], chr_name) for trait in trait_names] + for chr_name in chromosome_names] + return hdata + # # Grey + Blue + Red # def generate_heatmap(): # cols = 20 diff --git a/tests/unit/test_heatmaps.py b/tests/unit/test_heatmaps.py index cfdde1e..f3a81c5 100644 --- a/tests/unit/test_heatmaps.py +++ b/tests/unit/test_heatmaps.py @@ -5,7 +5,8 @@ from gn3.heatmaps import ( get_lrs_from_chr, export_trait_data, compute_traits_order, - retrieve_strains_and_values) + retrieve_strains_and_values, + process_traits_data_for_heatmap) strainlist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"] trait_data = { @@ -199,3 +200,107 @@ class TestHeatmap(TestCase): [13.2, 1.9, 2.22, 53.21]]]: with self.subTest(trait=trait, chromosome=chromosome): self.assertEqual(get_lrs_from_chr(trait, chromosome), expected) + + def test_process_traits_data_for_heatmap(self): + self.assertEqual( + process_traits_data_for_heatmap( + {"1": { + "ID": "T1", + "chromosomes": { + 1: {"Chr": 1, + "loci": [ + { + "Locus": "rs31443144", "cM": 1.500, "Mb": 3.010, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs6269442", "cM": 1.500, "Mb": 3.492, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs32285189", "cM": 1.630, "Mb": 3.511, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs258367496", "cM": 1.630, "Mb": 3.660, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs32430919", "cM": 1.750, "Mb": 3.777, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs36251697", "cM": 1.880, "Mb": 3.812, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs30658298", "cM": 2.010, "Mb": 4.431, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }]}, + 2: {"Chr": 2, + "loci": [ + { + "Locus": "rs51852623", "cM": 2.010, "Mb": 4.447, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs31879829", "cM": 2.140, "Mb": 4.519, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs36742481", "cM": 2.140, "Mb": 4.776, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }]}}}, + "2": { + "ID": "T1", + "chromosomes": { + 1: {"Chr": 1, + "loci": [ + { + "Locus": "rs31443144", "cM": 1.500, "Mb": 3.010, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs6269442", "cM": 1.500, "Mb": 3.492, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs32285189", "cM": 1.630, "Mb": 3.511, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs258367496", "cM": 1.630, "Mb": 3.660, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs32430919", "cM": 1.750, "Mb": 3.777, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs36251697", "cM": 1.880, "Mb": 3.812, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs30658298", "cM": 2.010, "Mb": 4.431, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }]}, + 2: {"Chr": 2, + "loci": [ + { + "Locus": "rs51852623", "cM": 2.010, "Mb": 4.447, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs31879829", "cM": 2.140, "Mb": 4.519, + "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 + }, + { + "Locus": "rs36742481", "cM": 2.140, "Mb": 4.776, + "LRS": 0.579, "Additive": -0.074, "pValue": 1.000 + }]}}}}, + ["2", "1"], + [1, 2]), + [[[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]], + [[0.5, 0.579, 0.5], + [0.5, 0.5, 0.5]]]) -- cgit v1.2.3 From e9fb4e45cfc52c5d86ef534b0e7f42ba8f4c84d3 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Wed, 15 Sep 2021 12:28:56 +0300 Subject: Generate heatmaps in a single plot Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Add a function to generate the heatmaps for each chromosome into a single plot. --- gn3/heatmaps.py | 65 ++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 41 insertions(+), 24 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index 0c00d6c..f3d7d25 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -4,6 +4,7 @@ generate various kinds of heatmaps. """ from functools import reduce +from gn3.settings import TMPDIR from typing import Any, Dict, Sequence from gn3.computations.slink import slink from gn3.computations.qtlreaper import generate_traits_file @@ -296,27 +297,43 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names): for chr_name in chromosome_names] return hdata -# # Grey + Blue + Red -# def generate_heatmap(): -# cols = 20 -# y_axis = (["%s"%x for x in range(1, cols+1)][:-1] + ["X"]) #replace last item with x for now -# x_axis = heatmap_x_axis_names() -# data = generate_random_data(height=cols, width=len(x_axis)) -# fig = px.imshow( -# data, -# x=x_axis, -# y=y_axis, -# width=500) -# fig.update_traces(xtype="array") -# fig.update_traces(ytype="array") -# # fig.update_traces(xgap=10) -# fig.update_xaxes( -# visible=True, -# title_text="Traits", -# title_font_size=16) -# fig.update_layout( -# coloraxis_colorscale=[ -# [0.0, '#3B3B3B'], [0.4999999999999999, '#ABABAB'], -# [0.5, '#F5DE11'], [1.0, '#FF0D00']]) -# fig.write_html("%s/%s"%(heatmap_dir, "test_image.html")) -# return fig +def generate_clustered_heatmap( + data, image_filename_prefix, x_axis = None, x_label: str = "", + y_axis = None, y_label: str = "", output_dir: str = TMPDIR, + colorscale = [ + [0.0, '#3B3B3B'], [0.4999999999999999, '#ABABAB'], + [0.5, '#F5DE11']], [1.0, '#FF0D00']): + """ + Generate a dendrogram, and heatmaps for each chromosome, and put them all + into one plot. + """ + num_cols = len(x_axis) + fig = make_subplots( + rows=1, + cols=num_cols, + shared_yaxes="rows", + # horizontal_spacing=(1 / (num_cols - 1)), + subplot_titles=x_axis + ) + hms = [go.Heatmap( + name=chromo, + y = y_axis, + z = data_array, + showscale=False) for chromo, data_array in zip(x_axis, data)] + for col, hm in enumerate(hms): + fig.add_trace(hm, row=1, col=(col + 1)) + + fig.update_traces( + showlegend=False, + colorscale=colorscale, + selector={"type": "heatmap"}) + fig.update_traces( + showlegend=True, + showscale=True, + selector={"name": x_axis[-1]}) + fig.update_layout( + coloraxis_colorscale=colorscale + ) + image_filename = "{}/{}.html".format(output_dir, image_filename_prefix) + fig.write_html(image_filename) + return image_filename, fig -- cgit v1.2.3 From 347301c1bc60ba5036625364ee48a5d72eeb1186 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Wed, 15 Sep 2021 12:42:38 +0300 Subject: Update entry-point function for heatmap generation Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Copy over code from the proof-of-concept implementation and clean it up a little for the entry-point function for heatmap generation via the API --- gn3/heatmaps.py | 70 ++++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 21 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index f3d7d25..4349ee0 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -131,7 +131,7 @@ def cluster_traits(traits_data_list: Sequence[Dict]): return tuple(__cluster(tdata_i) for tdata_i in enumerate(traits_data_list)) -def heatmap_data(traits_names, conn: Any): +def build_heatmap(traits_names, conn: Any): """ heatmap function @@ -148,27 +148,55 @@ def heatmap_data(traits_names, conn: Any): TODO: Elaborate on the parameters here... """ threshold = 0 # webqtlConfig.PUBLICTHRESH - def __retrieve_traitlist_and_datalist(threshold, fullname): - trait = retrieve_trait_info(threshold, fullname, conn) - return (trait, retrieve_trait_data(trait, conn)) - - traits_details = [ - __retrieve_traitlist_and_datalist(threshold, fullname) - for fullname in traits_names] - traits_list = tuple(x[0] for x in traits_details) - traits_data_list = [x[1] for x in traits_details] - genotype_filename = build_genotype_file(traits_list[0]["riset"]) - strainlist = load_genotype_samples(genotype_filename) - exported_traits_data_list = tuple( - export_trait_data(td, strainlist) for td in traits_data_list) - slink_data = slink(cluster_traits(exported_traits_data_list)) - ordering_data = compute_heatmap_order(slink_data) + traits = [ + retrieve_trait_info(threshold, fullname, conn) + for fullname in trait_fullnames()] + traits_data_list = [retrieve_trait_data(t, conn) for t in traits] + genotype_filename = build_genotype_file(traits[0]["riset"]) + genotype = parse_genotype_file(genotype_filename) + strains = load_genotype_samples(genotype_filename) + exported_traits_data_list = [ + export_trait_data(td, strains) for td in traits_data_list] + slinked = slink(cluster_traits(exported_traits_data_list)) + traits_order = compute_traits_order(slinked) + ordered_traits_names = [ + traits[idx]["trait_fullname"] for idx in traits_order] strains_and_values = retrieve_strains_and_values( - ordering_data, strainlist, exported_traits_data_list) - strains_values = strains_and_values[0][1] - trait_values = [t[2] for t in strains_and_values] - traits_filename = generate_traits_filename() - generate_traits_file(strains_values, trait_values, traits_filename) + traits_order, strains, exported_traits_data_list) + traits_filename = "{}/traits_test_file_{}.txt".format( + TMPDIR, random_string(10)) + generate_traits_file( + strains_and_values[0][1], + [t[2] for t in strains_and_values], + traits_filename) + + main_output, permutations_output = run_reaper( + genotype_filename, traits_filename, separate_nperm_output=True) + + qtlresults = parse_reaper_main_results(main_output) + permudata = parse_reaper_permutation_results(permutations_output) + organised = organise_reaper_main_results(qtlresults) + + traits_ids = [# sort numerically, but retain the ids as strings + str(i) for i in sorted({int(row["ID"]) for row in qtlresults})] + chromosome_names = sorted( + {row["Chr"] for row in qtlresults}, key = chromosome_sorter_key_fn) + loci_names = sorted({row["Locus"] for row in qtlresults}) + ordered_traits_names = { + res_id: trait for res_id, trait in + zip(traits_ids, + [traits[idx]["trait_fullname"] for idx in traits_order])} + + return generate_clustered_heatmap( + process_traits_data_for_heatmap( + organised, traits_ids, chromosome_names), + "single_heatmap_{}".format(random_string(10)), + y_axis=tuple( + ordered_traits_names[traits_ids[order]] + for order in traits_order), + y_label="Traits", + x_axis=[chromo for chromo in chromosome_names], + x_label="Chromosomes") return { "slink_data": slink_data, -- cgit v1.2.3 From 78e37440844a0397d29135a8ba215e54fd92c86d Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Thu, 16 Sep 2021 11:35:23 +0300 Subject: Add missing imports Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Add missing imports that are needed in the code. --- gn3/heatmaps.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index 4349ee0..c48a2d3 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -5,15 +5,25 @@ generate various kinds of heatmaps. from functools import reduce from gn3.settings import TMPDIR +import plotly.graph_objects as go +from gn3.random import random_string from typing import Any, Dict, Sequence from gn3.computations.slink import slink -from gn3.computations.qtlreaper import generate_traits_file +from plotly.subplots import make_subplots from gn3.computations.correlations2 import compute_correlation -from gn3.db.genotypes import build_genotype_file, load_genotype_samples +from gn3.db.genotypes import ( + build_genotype_file, load_genotype_samples, parse_genotype_file) from gn3.db.traits import ( retrieve_trait_data, retrieve_trait_info, generate_traits_filename) +from gn3.computations.qtlreaper import ( + run_reaper, + generate_traits_file, + chromosome_sorter_key_fn, + parse_reaper_main_results, + organise_reaper_main_results, + parse_reaper_permutation_results) def export_trait_data( trait_data: dict, strainlist: Sequence[str], dtype: str = "val", -- cgit v1.2.3 From 2cc9f382e199dbdbaab98c7e06deabd72e244adb Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Thu, 16 Sep 2021 11:36:42 +0300 Subject: Fix minor bugs Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Fix a few minor bugs left over from the integration of code from the proof-of-concept code. --- gn3/heatmaps.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index c48a2d3..170b0cd 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -160,7 +160,7 @@ def build_heatmap(traits_names, conn: Any): threshold = 0 # webqtlConfig.PUBLICTHRESH traits = [ retrieve_trait_info(threshold, fullname, conn) - for fullname in trait_fullnames()] + for fullname in traits_names] traits_data_list = [retrieve_trait_data(t, conn) for t in traits] genotype_filename = build_genotype_file(traits[0]["riset"]) genotype = parse_genotype_file(genotype_filename) @@ -338,9 +338,9 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names): def generate_clustered_heatmap( data, image_filename_prefix, x_axis = None, x_label: str = "", y_axis = None, y_label: str = "", output_dir: str = TMPDIR, - colorscale = [ - [0.0, '#3B3B3B'], [0.4999999999999999, '#ABABAB'], - [0.5, '#F5DE11']], [1.0, '#FF0D00']): + colorscale = ( + (0.0, '#3B3B3B'), (0.4999999999999999, '#ABABAB'), + (0.5, '#F5DE11'), (1.0, '#FF0D00'))): """ Generate a dendrogram, and heatmaps for each chromosome, and put them all into one plot. -- cgit v1.2.3 From 78871ef396f394c54072960e985476e418220fe3 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Fri, 17 Sep 2021 10:30:16 +0300 Subject: Create dendrogram to show clustering tree Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Provide the clustering data to be used for the creation of the clustering dendrogram in the final clustered heatmap plot. --- gn3/heatmaps.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index 170b0cd..bf69d9b 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -3,9 +3,11 @@ This module will contain functions to be used in computation of the data used to generate various kinds of heatmaps. """ +import numpy as np from functools import reduce from gn3.settings import TMPDIR import plotly.graph_objects as go +import plotly.figure_factory as ff from gn3.random import random_string from typing import Any, Dict, Sequence from gn3.computations.slink import slink @@ -167,7 +169,8 @@ def build_heatmap(traits_names, conn: Any): strains = load_genotype_samples(genotype_filename) exported_traits_data_list = [ export_trait_data(td, strains) for td in traits_data_list] - slinked = slink(cluster_traits(exported_traits_data_list)) + clustered = cluster_traits(exported_traits_data_list) + slinked = slink(clustered) traits_order = compute_traits_order(slinked) ordered_traits_names = [ traits[idx]["trait_fullname"] for idx in traits_order] @@ -200,6 +203,7 @@ def build_heatmap(traits_names, conn: Any): return generate_clustered_heatmap( process_traits_data_for_heatmap( organised, traits_ids, chromosome_names), + clustered, "single_heatmap_{}".format(random_string(10)), y_axis=tuple( ordered_traits_names[traits_ids[order]] @@ -336,8 +340,9 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names): return hdata def generate_clustered_heatmap( - data, image_filename_prefix, x_axis = None, x_label: str = "", - y_axis = None, y_label: str = "", output_dir: str = TMPDIR, + data, clustering_data, image_filename_prefix, x_axis = None, + x_label: str = "", y_axis = None, y_label: str = "", + output_dir: str = TMPDIR, colorscale = ( (0.0, '#3B3B3B'), (0.4999999999999999, '#ABABAB'), (0.5, '#F5DE11'), (1.0, '#FF0D00'))): @@ -345,21 +350,22 @@ def generate_clustered_heatmap( Generate a dendrogram, and heatmaps for each chromosome, and put them all into one plot. """ - num_cols = len(x_axis) + num_cols = 1 + len(x_axis) fig = make_subplots( rows=1, cols=num_cols, shared_yaxes="rows", - # horizontal_spacing=(1 / (num_cols - 1)), - subplot_titles=x_axis - ) + horizontal_spacing=0.001, + subplot_titles=["distance"] + x_axis, + figure = ff.create_dendrogram( + np.array(clustering_data), orientation="right", labels=y_axis)) hms = [go.Heatmap( name=chromo, y = y_axis, z = data_array, showscale=False) for chromo, data_array in zip(x_axis, data)] - for col, hm in enumerate(hms): - fig.add_trace(hm, row=1, col=(col + 1)) + for i, hm in enumerate(hms): + fig.add_trace(hm, row=1, col=(i + 2)) fig.update_traces( showlegend=False, -- cgit v1.2.3 From e17540bf8a57837fcf3241ea0b694250b53294fc Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Fri, 17 Sep 2021 10:32:23 +0300 Subject: Fix some layout issues and update colorscale Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Update the plot layouts and size to display the dendrogram and individual chromosome heatmaps side by side * Update the colour scale to begin with the grays rather than absolute black --- gn3/heatmaps.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index bf69d9b..2859dde 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -344,7 +344,7 @@ def generate_clustered_heatmap( x_label: str = "", y_axis = None, y_label: str = "", output_dir: str = TMPDIR, colorscale = ( - (0.0, '#3B3B3B'), (0.4999999999999999, '#ABABAB'), + (0.0, '#5D5D5D'), (0.4999999999999999, '#ABABAB'), (0.5, '#F5DE11'), (1.0, '#FF0D00'))): """ Generate a dendrogram, and heatmaps for each chromosome, and put them all @@ -367,6 +367,33 @@ def generate_clustered_heatmap( for i, hm in enumerate(hms): fig.add_trace(hm, row=1, col=(i + 2)) + fig.update_layout( + { + "width": 1500, + "height": 800, + "xaxis": { + "mirror": False, + "showgrid": True + } + }) + + x_axes_layouts = { + "xaxis{}".format(i+1 if i > 0 else ""): { + "mirror": False, + "showticklabels": True if i==0 else False, + "ticks": "outside" if i==0 else "" + } + for i in range(num_cols)} + + fig.update_layout( + { + "width": 4000, + "height": 800, + "yaxis": { + "mirror": False, + "ticks": "" + }, + **x_axes_layouts}) fig.update_traces( showlegend=False, colorscale=colorscale, @@ -375,9 +402,6 @@ def generate_clustered_heatmap( showlegend=True, showscale=True, selector={"name": x_axis[-1]}) - fig.update_layout( - coloraxis_colorscale=colorscale - ) image_filename = "{}/{}.html".format(output_dir, image_filename_prefix) fig.write_html(image_filename) return image_filename, fig -- cgit v1.2.3 From 1e2357049adc72808fbf8eaac3da9411d3c78c66 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Fri, 17 Sep 2021 11:20:16 +0300 Subject: Fix a number of linting issues Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi --- gn3/computations/qtlreaper.py | 7 ++-- gn3/db/genotypes.py | 2 +- gn3/heatmaps.py | 54 ++++++++++++------------------- tests/unit/computations/test_qtlreaper.py | 3 +- tests/unit/test_heatmaps.py | 6 ++-- 5 files changed, 32 insertions(+), 40 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/computations/qtlreaper.py b/gn3/computations/qtlreaper.py index 5180853..377db9b 100644 --- a/gn3/computations/qtlreaper.py +++ b/gn3/computations/qtlreaper.py @@ -110,9 +110,10 @@ def organise_reaper_main_results(parsed_results): unique_chromosomes = {item["Chr"] for item in id_items} return { "ID": identifier, - "chromosomes": {_chr["Chr"]: _chr for _chr in [ - __organise_by_chromosome(chromo, id_items) - for chromo in sorted( + "chromosomes": { + _chr["Chr"]: _chr for _chr in [ + __organise_by_chromosome(chromo, id_items) + for chromo in sorted( unique_chromosomes, key=chromosome_sorter_key_fn)]}} unique_ids = {res["ID"] for res in parsed_results} diff --git a/gn3/db/genotypes.py b/gn3/db/genotypes.py index b03d55c..9d052d9 100644 --- a/gn3/db/genotypes.py +++ b/gn3/db/genotypes.py @@ -174,7 +174,7 @@ def parse_genotype_file(filename: str, parlist: tuple = tuple()): geno_obj = dict(labels + header) markers = tuple( [parse_genotype_marker(line, geno_obj, parlist) - for line in data_lines[1:]]) + for line in data_lines[1:]]) chromosomes = tuple( dict(chromosome) for chromosome in build_genotype_chromosomes(geno_obj, markers)) diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index 2859dde..c4fc67d 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -3,13 +3,13 @@ This module will contain functions to be used in computation of the data used to generate various kinds of heatmaps. """ +from typing import Any, Dict, Sequence import numpy as np from functools import reduce from gn3.settings import TMPDIR import plotly.graph_objects as go import plotly.figure_factory as ff from gn3.random import random_string -from typing import Any, Dict, Sequence from gn3.computations.slink import slink from plotly.subplots import make_subplots from gn3.computations.correlations2 import compute_correlation @@ -165,7 +165,7 @@ def build_heatmap(traits_names, conn: Any): for fullname in traits_names] traits_data_list = [retrieve_trait_data(t, conn) for t in traits] genotype_filename = build_genotype_file(traits[0]["riset"]) - genotype = parse_genotype_file(genotype_filename) + # genotype = parse_genotype_file(genotype_filename) strains = load_genotype_samples(genotype_filename) exported_traits_data_list = [ export_trait_data(td, strains) for td in traits_data_list] @@ -183,22 +183,21 @@ def build_heatmap(traits_names, conn: Any): [t[2] for t in strains_and_values], traits_filename) - main_output, permutations_output = run_reaper( + main_output, _permutations_output = run_reaper( genotype_filename, traits_filename, separate_nperm_output=True) qtlresults = parse_reaper_main_results(main_output) - permudata = parse_reaper_permutation_results(permutations_output) + # permudata = parse_reaper_permutation_results(permutations_output) organised = organise_reaper_main_results(qtlresults) traits_ids = [# sort numerically, but retain the ids as strings str(i) for i in sorted({int(row["ID"]) for row in qtlresults})] chromosome_names = sorted( - {row["Chr"] for row in qtlresults}, key = chromosome_sorter_key_fn) - loci_names = sorted({row["Locus"] for row in qtlresults}) - ordered_traits_names = { - res_id: trait for res_id, trait in + {row["Chr"] for row in qtlresults}, key=chromosome_sorter_key_fn) + # loci_names = sorted({row["Locus"] for row in qtlresults}) + ordered_traits_names = dict( zip(traits_ids, - [traits[idx]["trait_fullname"] for idx in traits_order])} + [traits[idx]["trait_fullname"] for idx in traits_order])) return generate_clustered_heatmap( process_traits_data_for_heatmap( @@ -207,22 +206,11 @@ def build_heatmap(traits_names, conn: Any): "single_heatmap_{}".format(random_string(10)), y_axis=tuple( ordered_traits_names[traits_ids[order]] - for order in traits_order), + for order in traits_order), y_label="Traits", - x_axis=[chromo for chromo in chromosome_names], + x_axis=chromosome_names, x_label="Chromosomes") - return { - "slink_data": slink_data, - "ordering_data": ordering_data, - "strainlist": strainlist, - "genotype_filename": genotype_filename, - "traits_list": traits_list, - "traits_data_list": traits_data_list, - "exported_traits_data_list": exported_traits_data_list, - "traits_filename": traits_filename - } - def compute_traits_order(slink_data, neworder: tuple = tuple()): """ Compute the order of the traits for clustering from `slink_data`. @@ -314,7 +302,7 @@ def get_nearest_marker(traits_list, genotype): https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L419-L438 """ if not genotype["Mbmap"]: - return [None] * len(trait_list) + return [None] * len(traits_list) marker_finder = nearest_marker_finder(genotype) return [marker_finder(trait) for trait in traits_list] @@ -340,10 +328,10 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names): return hdata def generate_clustered_heatmap( - data, clustering_data, image_filename_prefix, x_axis = None, - x_label: str = "", y_axis = None, y_label: str = "", + data, clustering_data, image_filename_prefix, x_axis=None, + x_label: str = "", y_axis=None, y_label: str = "", output_dir: str = TMPDIR, - colorscale = ( + colorscale=( (0.0, '#5D5D5D'), (0.4999999999999999, '#ABABAB'), (0.5, '#F5DE11'), (1.0, '#FF0D00'))): """ @@ -357,15 +345,15 @@ def generate_clustered_heatmap( shared_yaxes="rows", horizontal_spacing=0.001, subplot_titles=["distance"] + x_axis, - figure = ff.create_dendrogram( + figure=ff.create_dendrogram( np.array(clustering_data), orientation="right", labels=y_axis)) hms = [go.Heatmap( name=chromo, - y = y_axis, - z = data_array, + y=y_axis, + z=data_array, showscale=False) for chromo, data_array in zip(x_axis, data)] - for i, hm in enumerate(hms): - fig.add_trace(hm, row=1, col=(i + 2)) + for i, heatmap in enumerate(hms): + fig.add_trace(heatmap, row=1, col=(i + 2)) fig.update_layout( { @@ -380,8 +368,8 @@ def generate_clustered_heatmap( x_axes_layouts = { "xaxis{}".format(i+1 if i > 0 else ""): { "mirror": False, - "showticklabels": True if i==0 else False, - "ticks": "outside" if i==0 else "" + "showticklabels": True if i == 0 else False, + "ticks": "outside" if i == 0 else "" } for i in range(num_cols)} diff --git a/tests/unit/computations/test_qtlreaper.py b/tests/unit/computations/test_qtlreaper.py index 1d67827..d420470 100644 --- a/tests/unit/computations/test_qtlreaper.py +++ b/tests/unit/computations/test_qtlreaper.py @@ -77,6 +77,7 @@ class TestQTLReaper(TestCase): 5.82775, 5.89659, 5.92117, 5.93396, 5.93396, 5.94957]) def test_organise_reaper_main_results(self): + """Check that results are organised correctly.""" self.assertEqual( organise_reaper_main_results([ { @@ -135,7 +136,7 @@ class TestQTLReaper(TestCase): 1: {"Chr": 1, "loci": [ { - "Locus": "rs31443144", "cM": 1.500, "Mb": 3.010, + "Locus": "rs31443144", "cM": 1.500, "Mb": 3.010, "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 }, { diff --git a/tests/unit/test_heatmaps.py b/tests/unit/test_heatmaps.py index f3a81c5..c0a496b 100644 --- a/tests/unit/test_heatmaps.py +++ b/tests/unit/test_heatmaps.py @@ -189,6 +189,7 @@ class TestHeatmap(TestCase): retrieve_strains_and_values(orders, slist, tdata), expected) def test_get_lrs_from_chr(self): + """Check that function gets correct LRS values""" for trait, chromosome, expected in [ [{"chromosomes": {}}, 3, [None]], [{"chromosomes": {3: {"loci": [ @@ -202,6 +203,7 @@ class TestHeatmap(TestCase): self.assertEqual(get_lrs_from_chr(trait, chromosome), expected) def test_process_traits_data_for_heatmap(self): + """Check for correct processing of data for heatmap generation.""" self.assertEqual( process_traits_data_for_heatmap( {"1": { @@ -210,7 +212,7 @@ class TestHeatmap(TestCase): 1: {"Chr": 1, "loci": [ { - "Locus": "rs31443144", "cM": 1.500, "Mb": 3.010, + "Locus": "rs31443144", "cM": 1.500, "Mb": 3.010, "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 }, { @@ -257,7 +259,7 @@ class TestHeatmap(TestCase): 1: {"Chr": 1, "loci": [ { - "Locus": "rs31443144", "cM": 1.500, "Mb": 3.010, + "Locus": "rs31443144", "cM": 1.500, "Mb": 3.010, "LRS": 0.500, "Additive": -0.074, "pValue": 1.000 }, { -- cgit v1.2.3 From b7fb10586b956a8b0389e7925e4c0cff28cde82f Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Mon, 20 Sep 2021 06:36:00 +0300 Subject: Return only the data Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * gn3/api/heatmaps.py: Parse incoming data to build up correct trait names and respond with only the computed heatmap data. * gn3/heatmaps.py: Return only the computed data for heatmaps and clustering. Since GN3 is supposed to handle only the data, and db-access, this commit ensures that GN3 responds to the client with only the computed heatmap data, and does not try to generate the heatmaps themselves. The generation of the heatmaps will be delegated to the UI clients, such as GeneNetwork2. --- gn3/api/heatmaps.py | 18 ++++++------------ gn3/heatmaps.py | 25 +++++++++++++++++-------- 2 files changed, 23 insertions(+), 20 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py index f053241..eea3ebe 100644 --- a/gn3/api/heatmaps.py +++ b/gn3/api/heatmaps.py @@ -15,15 +15,9 @@ def clustered_heatmaps(): "message": "You need to provide at least one trait name." }), 400 conn, _cursor = database_connector() - _heatmap_file, heatmap_fig = build_heatmap(traits_names, conn) - - # stream the heatmap data somehow here. - # Can plotly actually stream the figure data in a way that can be used on - # remote end to display the image without necessarily being html? - # return jsonify( - # { - # "query": heatmap_request, - # "output_png": heatmap_fig.to_image(format="png"), - # "output_svg": heatmap_fig.to_image(format="svg") - # }), 200 - return jsonify({"output_filename": _heatmap_file}), 200 + def setup_trait_fullname(trait): + name_parts = trait.split(":") + return "{dataset_name}::{trait_name}".format( + dataset_name=trait[1], trait_name=trait[0]) + traits_fullnames = [parse_trait_fullname(trait) for trait in traits_names] + return jsonify(build_heatmap(traits_fullnames, conn)), 200 diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index c4fc67d..205a3b3 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -199,17 +199,26 @@ def build_heatmap(traits_names, conn: Any): zip(traits_ids, [traits[idx]["trait_fullname"] for idx in traits_order])) - return generate_clustered_heatmap( - process_traits_data_for_heatmap( + # return generate_clustered_heatmap( + # process_traits_data_for_heatmap( + # organised, traits_ids, chromosome_names), + # clustered, + # "single_heatmap_{}".format(random_string(10)), + # y_axis=tuple( + # ordered_traits_names[traits_ids[order]] + # for order in traits_order), + # y_label="Traits", + # x_axis=chromosome_names, + # x_label="Chromosomes") + return { + "clustering_data": clustered, + "heatmap_data": process_traits_data_for_heatmap( organised, traits_ids, chromosome_names), - clustered, - "single_heatmap_{}".format(random_string(10)), - y_axis=tuple( + "traits": tuple( ordered_traits_names[traits_ids[order]] for order in traits_order), - y_label="Traits", - x_axis=chromosome_names, - x_label="Chromosomes") + "chromosomes": chromosome_names + } def compute_traits_order(slink_data, neworder: tuple = tuple()): """ -- cgit v1.2.3 From 920be820e9cefe1dcde86d9a252f098c67a2bb8b Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Wed, 22 Sep 2021 07:03:53 +0300 Subject: Return serialized plotly figure Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * gn3/api/heatmaps.py: Serialize the figure to JSON * gn3/heatmaps.py: Return the figure object Serialize the Plotly figure into JSON, and return that, so that it can be used on the client to display the image. --- gn3/api/heatmaps.py | 8 +++++++- gn3/heatmaps.py | 27 ++++++++------------------- 2 files changed, 15 insertions(+), 20 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py index 43ac580..0493f8a 100644 --- a/gn3/api/heatmaps.py +++ b/gn3/api/heatmaps.py @@ -1,3 +1,4 @@ +import io from flask import jsonify from flask import request from flask import Blueprint @@ -20,4 +21,9 @@ def clustered_heatmaps(): return "{dataset_name}::{trait_name}".format( dataset_name=name_parts[1], trait_name=name_parts[0]) traits_fullnames = [parse_trait_fullname(trait) for trait in traits_names] - return jsonify(build_heatmap(traits_fullnames, conn)), 200 + + with io.StringIO() as io_str: + _filename, figure = build_heatmap(traits_fullnames, conn) + figure.write_json(io_str) + fig_json = io_str.getvalue() + return fig_json, 200 diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index 205a3b3..cd93b3f 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -187,38 +187,27 @@ def build_heatmap(traits_names, conn: Any): genotype_filename, traits_filename, separate_nperm_output=True) qtlresults = parse_reaper_main_results(main_output) - # permudata = parse_reaper_permutation_results(permutations_output) organised = organise_reaper_main_results(qtlresults) traits_ids = [# sort numerically, but retain the ids as strings str(i) for i in sorted({int(row["ID"]) for row in qtlresults})] chromosome_names = sorted( {row["Chr"] for row in qtlresults}, key=chromosome_sorter_key_fn) - # loci_names = sorted({row["Locus"] for row in qtlresults}) ordered_traits_names = dict( zip(traits_ids, [traits[idx]["trait_fullname"] for idx in traits_order])) - # return generate_clustered_heatmap( - # process_traits_data_for_heatmap( - # organised, traits_ids, chromosome_names), - # clustered, - # "single_heatmap_{}".format(random_string(10)), - # y_axis=tuple( - # ordered_traits_names[traits_ids[order]] - # for order in traits_order), - # y_label="Traits", - # x_axis=chromosome_names, - # x_label="Chromosomes") - return { - "clustering_data": clustered, - "heatmap_data": process_traits_data_for_heatmap( + return generate_clustered_heatmap( + process_traits_data_for_heatmap( organised, traits_ids, chromosome_names), - "traits": tuple( + clustered, + "single_heatmap_{}".format(random_string(10)), + y_axis=tuple( ordered_traits_names[traits_ids[order]] for order in traits_order), - "chromosomes": chromosome_names - } + y_label="Traits", + x_axis=chromosome_names, + x_label="Chromosomes") def compute_traits_order(slink_data, neworder: tuple = tuple()): """ -- cgit v1.2.3 From cd7f301688fd9780df1f842f8bd2b7602775ba1f Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Wed, 22 Sep 2021 07:53:53 +0300 Subject: Fix pylint errors * Add missing function and module docstrings * Remove unused imports * Fix import order * Rework some code sections to fix issues * Disable some pylint errors. --- gn3/api/heatmaps.py | 8 ++++++++ gn3/app.py | 5 +++-- gn3/computations/qtlreaper.py | 8 ++++++++ gn3/db/genotypes.py | 1 + gn3/db/traits.py | 2 +- gn3/heatmaps.py | 28 ++++++++++++++++------------ 6 files changed, 37 insertions(+), 15 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py index 1022a35..fe47aee 100644 --- a/gn3/api/heatmaps.py +++ b/gn3/api/heatmaps.py @@ -1,3 +1,7 @@ +""" +Module to hold the entrypoint functions that generate heatmaps +""" + import io from flask import jsonify from flask import request @@ -9,6 +13,10 @@ heatmaps = Blueprint("heatmaps", __name__) @heatmaps.route("/clustered", methods=("POST",)) def clustered_heatmaps(): + """ + Parses the incoming data and responds with the JSON-serialized plotly figure + representing the clustered heatmap. + """ heatmap_request = request.get_json() traits_names = heatmap_request.get("traits_names", tuple()) if len(traits_names) < 2: diff --git a/gn3/app.py b/gn3/app.py index 6b4c57e..8badb65 100644 --- a/gn3/app.py +++ b/gn3/app.py @@ -3,7 +3,10 @@ import os from typing import Dict from typing import Union + from flask import Flask +from flask_cors import CORS + from gn3.api.gemma import gemma from gn3.api.rqtl import rqtl from gn3.api.general import general @@ -11,8 +14,6 @@ from gn3.api.heatmaps import heatmaps from gn3.api.correlation import correlation from gn3.api.data_entry import data_entry -from flask_cors import CORS - def create_app(config: Union[Dict, str, None] = None) -> Flask: """Create a new flask object""" app = Flask(__name__) diff --git a/gn3/computations/qtlreaper.py b/gn3/computations/qtlreaper.py index 377db9b..5d17fed 100644 --- a/gn3/computations/qtlreaper.py +++ b/gn3/computations/qtlreaper.py @@ -87,11 +87,17 @@ def run_reaper( return (output_filename, permu_output_filename) def chromosome_sorter_key_fn(val): + """ + Useful for sorting the chromosomes + """ if isinstance(val, int): return val return ord(val) def organise_reaper_main_results(parsed_results): + """ + Provide the results of running reaper in a format that is easier to use. + """ def __organise_by_chromosome(chr_name, items): chr_items = [item for item in items if item["Chr"] == chr_name] return { @@ -129,12 +135,14 @@ def parse_reaper_main_results(results_file): lines = infile.readlines() def __parse_column_float_value(value): + # pylint: disable=W0702 try: return float(value) except: return value def __parse_column_int_value(value): + # pylint: disable=W0702 try: return int(value) except: diff --git a/gn3/db/genotypes.py b/gn3/db/genotypes.py index 9d052d9..919c539 100644 --- a/gn3/db/genotypes.py +++ b/gn3/db/genotypes.py @@ -115,6 +115,7 @@ def parse_genotype_marker(line: str, geno_obj: dict, parlist: list): Reworks https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/utility/gen_geno_ob.py#L143-L190 """ + # pylint: disable=W0702 marker_row = [item.strip() for item in line.split("\t")] geno_table = { geno_obj["mat"]: -1, geno_obj["pat"]: 1, geno_obj["het"]: 0, diff --git a/gn3/db/traits.py b/gn3/db/traits.py index bfe887e..747ed27 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -46,7 +46,7 @@ def update_sample_data(conn: Any, count: Union[int, str]): """Given the right parameters, update sample-data from the relevant table.""" - # pylint: disable=[R0913, R0914] + # pylint: disable=[R0913, R0914, C0103] STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s" PUBLISH_DATA_SQL: str = ("UPDATE PublishData SET value = %s " "WHERE StrainId = %s AND Id = %s") diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index cd93b3f..9d82fb2 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -3,29 +3,28 @@ This module will contain functions to be used in computation of the data used to generate various kinds of heatmaps. """ +from functools import reduce from typing import Any, Dict, Sequence + import numpy as np -from functools import reduce -from gn3.settings import TMPDIR import plotly.graph_objects as go import plotly.figure_factory as ff +from plotly.subplots import make_subplots + +from gn3.settings import TMPDIR from gn3.random import random_string from gn3.computations.slink import slink -from plotly.subplots import make_subplots from gn3.computations.correlations2 import compute_correlation from gn3.db.genotypes import ( - build_genotype_file, load_genotype_samples, parse_genotype_file) + build_genotype_file, load_genotype_samples) from gn3.db.traits import ( - retrieve_trait_data, - retrieve_trait_info, - generate_traits_filename) + retrieve_trait_data, retrieve_trait_info) from gn3.computations.qtlreaper import ( run_reaper, generate_traits_file, chromosome_sorter_key_fn, parse_reaper_main_results, - organise_reaper_main_results, - parse_reaper_permutation_results) + organise_reaper_main_results) def export_trait_data( trait_data: dict, strainlist: Sequence[str], dtype: str = "val", @@ -159,13 +158,13 @@ def build_heatmap(traits_names, conn: Any): PARAMETERS: TODO: Elaborate on the parameters here... """ + # pylint: disable=[R0914] threshold = 0 # webqtlConfig.PUBLICTHRESH traits = [ retrieve_trait_info(threshold, fullname, conn) for fullname in traits_names] traits_data_list = [retrieve_trait_data(t, conn) for t in traits] genotype_filename = build_genotype_file(traits[0]["riset"]) - # genotype = parse_genotype_file(genotype_filename) strains = load_genotype_samples(genotype_filename) exported_traits_data_list = [ export_trait_data(td, strains) for td in traits_data_list] @@ -336,6 +335,7 @@ def generate_clustered_heatmap( Generate a dendrogram, and heatmaps for each chromosome, and put them all into one plot. """ + # pylint: disable=[R0913, R0914] num_cols = 1 + len(x_axis) fig = make_subplots( rows=1, @@ -359,14 +359,18 @@ def generate_clustered_heatmap( "height": 800, "xaxis": { "mirror": False, - "showgrid": True + "showgrid": True, + "title": x_label + }, + "yaxis": { + "title": y_label } }) x_axes_layouts = { "xaxis{}".format(i+1 if i > 0 else ""): { "mirror": False, - "showticklabels": True if i == 0 else False, + "showticklabels": i == 0, "ticks": "outside" if i == 0 else "" } for i in range(num_cols)} -- cgit v1.2.3 From 71cc35e5178904b512b9007e33be17a36f6656f2 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Wed, 22 Sep 2021 08:36:11 +0300 Subject: Fix typing issues * Ignore some errors * Update typing definitions for some portions of code * Add missing imports --- gn3/app.py | 2 +- gn3/computations/qtlreaper.py | 6 ++++-- gn3/db/genotypes.py | 10 ++++++---- gn3/db/traits.py | 8 ++++---- gn3/heatmaps.py | 8 +++----- 5 files changed, 18 insertions(+), 16 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/app.py b/gn3/app.py index 8badb65..5e852e1 100644 --- a/gn3/app.py +++ b/gn3/app.py @@ -5,7 +5,7 @@ from typing import Dict from typing import Union from flask import Flask -from flask_cors import CORS +from flask_cors import CORS # type: ignore from gn3.api.gemma import gemma from gn3.api.rqtl import rqtl diff --git a/gn3/computations/qtlreaper.py b/gn3/computations/qtlreaper.py index 5d17fed..5ddea76 100644 --- a/gn3/computations/qtlreaper.py +++ b/gn3/computations/qtlreaper.py @@ -4,6 +4,8 @@ computation of QTLs. """ import os import subprocess +from typing import Union + from gn3.random import random_string from gn3.settings import TMPDIR, REAPER_COMMAND @@ -70,9 +72,9 @@ def run_reaper( output_dir, random_string(10)) output_list = ["--main_output", output_filename] if separate_nperm_output: - permu_output_filename = "{}/qtlreaper/permu_output_{}.txt".format( + permu_output_filename: Union[None, str] = "{}/qtlreaper/permu_output_{}.txt".format( output_dir, random_string(10)) - output_list = output_list + ["--permu_output", permu_output_filename] + output_list = output_list + ["--permu_output", permu_output_filename] # type: ignore[list-item] else: permu_output_filename = None diff --git a/gn3/db/genotypes.py b/gn3/db/genotypes.py index 919c539..9ea9f20 100644 --- a/gn3/db/genotypes.py +++ b/gn3/db/genotypes.py @@ -2,6 +2,8 @@ import os import gzip +from typing import Union, TextIO + from gn3.settings import GENOTYPE_FILES def build_genotype_file( @@ -44,17 +46,17 @@ def __load_genotype_samples_from_geno(genotype_filename: str): """ gzipped_filename = "{}.gz".format(genotype_filename) if os.path.isfile(gzipped_filename): - genofile = gzip.open(gzipped_filename) + genofile: Union[TextIO, gzip.GzipFile] = gzip.open(gzipped_filename) else: genofile = open(genotype_filename) for row in genofile: line = row.strip() - if (not line) or (line.startswith(("#", "@"))): + if (not line) or (line.startswith(("#", "@"))): # type: ignore[arg-type] continue break - headers = line.split("\t") + headers = line.split("\t" ) # type: ignore[arg-type] if headers[3] == "Mb": return headers[4:] return headers[3:] @@ -107,7 +109,7 @@ def parse_genotype_header(line: str, parlist: tuple = tuple()): ("prgy", prgy), ("nprgy", len(prgy))) -def parse_genotype_marker(line: str, geno_obj: dict, parlist: list): +def parse_genotype_marker(line: str, geno_obj: dict, parlist: tuple): """ Parse a data line in a genotype file diff --git a/gn3/db/traits.py b/gn3/db/traits.py index 747ed27..4fc47c3 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -63,22 +63,22 @@ def update_sample_data(conn: Any, with conn.cursor() as cursor: # Update the Strains table cursor.execute(STRAIN_ID_SQL, (strain_name, strain_id)) - updated_strains: int = cursor.rowcount + updated_strains = cursor.rowcount # Update the PublishData table cursor.execute(PUBLISH_DATA_SQL, (None if value == "x" else value, strain_id, publish_data_id)) - updated_published_data: int = cursor.rowcount + updated_published_data = cursor.rowcount # Update the PublishSE table cursor.execute(PUBLISH_SE_SQL, (None if error == "x" else error, strain_id, publish_data_id)) - updated_se_data: int = cursor.rowcount + updated_se_data = cursor.rowcount # Update the NStrain table cursor.execute(N_STRAIN_SQL, (None if count == "x" else count, strain_id, publish_data_id)) - updated_n_strains: int = cursor.rowcount + updated_n_strains = cursor.rowcount return (updated_strains, updated_published_data, updated_se_data, updated_n_strains) diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index 9d82fb2..45d0c22 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -7,9 +7,9 @@ from functools import reduce from typing import Any, Dict, Sequence import numpy as np -import plotly.graph_objects as go -import plotly.figure_factory as ff -from plotly.subplots import make_subplots +import plotly.graph_objects as go # type: ignore +import plotly.figure_factory as ff # type: ignore +from plotly.subplots import make_subplots # type: ignore from gn3.settings import TMPDIR from gn3.random import random_string @@ -171,8 +171,6 @@ def build_heatmap(traits_names, conn: Any): clustered = cluster_traits(exported_traits_data_list) slinked = slink(clustered) traits_order = compute_traits_order(slinked) - ordered_traits_names = [ - traits[idx]["trait_fullname"] for idx in traits_order] strains_and_values = retrieve_strains_and_values( traits_order, strains, exported_traits_data_list) traits_filename = "{}/traits_test_file_{}.txt".format( -- cgit v1.2.3 From 19783a18c2bc7941fc5980e593f19fb1d18c3623 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Mon, 27 Sep 2021 04:48:53 +0300 Subject: Update terminology: `strain` to `sample` Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Update the terminology used: use `sample` in place of `strain` according to Zachary's direction at https://github.com/genenetwork/genenetwork3/pull/37#issuecomment-926043306 --- gn3/computations/parsers.py | 10 ++--- gn3/computations/qtlreaper.py | 8 ++-- gn3/db/genotypes.py | 8 ++-- gn3/db/traits.py | 44 ++++++++++----------- gn3/heatmaps.py | 62 ++++++++++++++--------------- tests/unit/computations/test_parsers.py | 4 +- tests/unit/test_heatmaps.py | 70 ++++++++++++++++----------------- 7 files changed, 103 insertions(+), 103 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/computations/parsers.py b/gn3/computations/parsers.py index 94387ff..1af35d6 100644 --- a/gn3/computations/parsers.py +++ b/gn3/computations/parsers.py @@ -14,7 +14,7 @@ def parse_genofile(file_path: str) -> Tuple[List[str], 'h': 0, 'u': None, } - genotypes, strains = [], [] + genotypes, samples = [], [] with open(file_path, "r") as _genofile: for line in _genofile: line = line.strip() @@ -22,8 +22,8 @@ def parse_genofile(file_path: str) -> Tuple[List[str], continue cells = line.split() if line.startswith("Chr"): - strains = cells[4:] - strains = [strain.lower() for strain in strains] + samples = cells[4:] + samples = [sample.lower() for sample in samples] continue values = [__map.get(value.lower(), None) for value in cells[4:]] genotype = { @@ -32,7 +32,7 @@ def parse_genofile(file_path: str) -> Tuple[List[str], "cm": cells[2], "mb": cells[3], "values": values, - "dicvalues": dict(zip(strains, values)), + "dicvalues": dict(zip(samples, values)), } genotypes.append(genotype) - return strains, genotypes + return samples, genotypes diff --git a/gn3/computations/qtlreaper.py b/gn3/computations/qtlreaper.py index 8b2893e..166d2dd 100644 --- a/gn3/computations/qtlreaper.py +++ b/gn3/computations/qtlreaper.py @@ -9,17 +9,17 @@ from typing import Union from gn3.random import random_string from gn3.settings import TMPDIR, REAPER_COMMAND -def generate_traits_file(strains, trait_values, traits_filename): +def generate_traits_file(samples, trait_values, traits_filename): """ Generate a traits file for use with `qtlreaper`. PARAMETERS: - strains: A list of strains to use as the headers for the various columns. - trait_values: A list of lists of values for each trait and strain. + samples: A list of samples to use as the headers for the various columns. + trait_values: A list of lists of values for each trait and sample. traits_filename: The tab-separated value to put the values in for computation of QTLs. """ - header = "Trait\t{}\n".format("\t".join(strains)) + header = "Trait\t{}\n".format("\t".join(samples)) data = ( [header] + ["{}\t{}\n".format(i+1, "\t".join([str(i) for i in t])) diff --git a/gn3/db/genotypes.py b/gn3/db/genotypes.py index 9987320..8f18cac 100644 --- a/gn3/db/genotypes.py +++ b/gn3/db/genotypes.py @@ -14,16 +14,16 @@ def build_genotype_file( def load_genotype_samples(genotype_filename: str, file_type: str = "geno"): """ - Load sample of strains from genotype files. + Load sample of samples from genotype files. DESCRIPTION: - Traits can contain a varied number of strains, some of which do not exist in + Traits can contain a varied number of samples, some of which do not exist in certain genotypes. In order to compute QTLs, GEMMAs, etc, we need to ensure - to pick only those strains that exist in the genotype under consideration + to pick only those samples that exist in the genotype under consideration for the traits used in the computation. This function loads a list of samples from the genotype files for use in - filtering out unusable strains. + filtering out unusable samples. PARAMETERS: diff --git a/gn3/db/traits.py b/gn3/db/traits.py index 4fc47c3..c9d05d7 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -445,7 +445,7 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any): query, {"trait_name": trait_info["trait_name"]}) return [dict(zip( - ["strain_name", "value", "se_error", "nstrain", "id"], row)) + ["sample_name", "value", "se_error", "nstrain", "id"], row)) for row in cursor.fetchall()] return [] @@ -484,7 +484,7 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any): "species_id": retrieve_species_id( trait_info["db"]["riset"], conn)}) return [dict(zip( - ["strain_name", "value", "se_error", "id"], row)) + ["sample_name", "value", "se_error", "id"], row)) for row in cursor.fetchall()] return [] @@ -515,7 +515,7 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any): {"trait_name": trait_info["trait_name"], "dataset_id": trait_info["db"]["dataset_id"]}) return [dict(zip( - ["strain_name", "value", "se_error", "nstrain", "id"], row)) + ["sample_name", "value", "se_error", "nstrain", "id"], row)) for row in cursor.fetchall()] return [] @@ -548,7 +548,7 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any): "trait_name": trait_info["trait_name"], "dataset_id": trait_info["db"]["dataset_id"]}) return [dict(zip( - ["strain_name", "value", "se_error", "id"], row)) + ["sample_name", "value", "se_error", "id"], row)) for row in cursor.fetchall()] return [] @@ -577,29 +577,29 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any): {"trait_name": trait_info["trait_name"], "dataset_name": trait_info["db"]["dataset_name"]}) return [dict(zip( - ["strain_name", "value", "se_error", "id"], row)) + ["sample_name", "value", "se_error", "id"], row)) for row in cursor.fetchall()] return [] -def with_strainlist_data_setup(strainlist: Sequence[str]): +def with_samplelist_data_setup(samplelist: Sequence[str]): """ - Build function that computes the trait data from provided list of strains. + Build function that computes the trait data from provided list of samples. PARAMETERS - strainlist: (list) - A list of strain names + samplelist: (list) + A list of sample names RETURNS: Returns a function that given some data from the database, computes the - strain's value, variance and ndata values, only if the strain is present - in the provided `strainlist` variable. + sample's value, variance and ndata values, only if the sample is present + in the provided `samplelist` variable. """ def setup_fn(tdata): - if tdata["strain_name"] in strainlist: + if tdata["sample_name"] in samplelist: val = tdata["value"] if val is not None: return { - "strain_name": tdata["strain_name"], + "sample_name": tdata["sample_name"], "value": val, "variance": tdata["se_error"], "ndata": tdata.get("nstrain", None) @@ -607,19 +607,19 @@ def with_strainlist_data_setup(strainlist: Sequence[str]): return None return setup_fn -def without_strainlist_data_setup(): +def without_samplelist_data_setup(): """ Build function that computes the trait data. RETURNS: Returns a function that given some data from the database, computes the - strain's value, variance and ndata values. + sample's value, variance and ndata values. """ def setup_fn(tdata): val = tdata["value"] if val is not None: return { - "strain_name": tdata["strain_name"], + "sample_name": tdata["sample_name"], "value": val, "variance": tdata["se_error"], "ndata": tdata.get("nstrain", None) @@ -627,7 +627,7 @@ def without_strainlist_data_setup(): return None return setup_fn -def retrieve_trait_data(trait: dict, conn: Any, strainlist: Sequence[str] = tuple()): +def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tuple()): """ Retrieve trait data @@ -650,23 +650,23 @@ def retrieve_trait_data(trait: dict, conn: Any, strainlist: Sequence[str] = tupl if results: # do something with mysqlid mysqlid = results[0]["id"] - if strainlist: + if samplelist: data = [ item for item in - map(with_strainlist_data_setup(strainlist), results) + map(with_samplelist_data_setup(samplelist), results) if item is not None] else: data = [ item for item in - map(without_strainlist_data_setup(), results) + map(without_samplelist_data_setup(), results) if item is not None] return { "mysqlid": mysqlid, "data": dict(map( lambda x: ( - x["strain_name"], - {k:v for k, v in x.items() if x != "strain_name"}), + x["sample_name"], + {k:v for k, v in x.items() if x != "sample_name"}), data))} return {} diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index 45d0c22..b6fc6d3 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -27,10 +27,10 @@ from gn3.computations.qtlreaper import ( organise_reaper_main_results) def export_trait_data( - trait_data: dict, strainlist: Sequence[str], dtype: str = "val", + trait_data: dict, samplelist: Sequence[str], dtype: str = "val", var_exists: bool = False, n_exists: bool = False): """ - Export data according to `strainlist`. Mostly used in calculating + Export data according to `samplelist`. Mostly used in calculating correlations. DESCRIPTION: @@ -40,8 +40,8 @@ def export_trait_data( PARAMETERS trait: (dict) The dictionary of key-value pairs representing a trait - strainlist: (list) - A list of strain names + samplelist: (list) + A list of sample names dtype: (str) ... verify what this is ... var_exists: (bool) @@ -49,18 +49,18 @@ def export_trait_data( n_exists: (bool) A flag indicating existence of ndata """ - def __export_all_types(tdata, strain): + def __export_all_types(tdata, sample): sample_data = [] - if tdata[strain]["value"]: - sample_data.append(tdata[strain]["value"]) + if tdata[sample]["value"]: + sample_data.append(tdata[sample]["value"]) if var_exists: - if tdata[strain]["variance"]: - sample_data.append(tdata[strain]["variance"]) + if tdata[sample]["variance"]: + sample_data.append(tdata[sample]["variance"]) else: sample_data.append(None) if n_exists: - if tdata[strain]["ndata"]: - sample_data.append(tdata[strain]["ndata"]) + if tdata[sample]["ndata"]: + sample_data.append(tdata[sample]["ndata"]) else: sample_data.append(None) else: @@ -73,17 +73,17 @@ def export_trait_data( return tuple(sample_data) - def __exporter(accumulator, strain): + def __exporter(accumulator, sample): # pylint: disable=[R0911] - if strain in trait_data["data"]: + if sample in trait_data["data"]: if dtype == "val": - return accumulator + (trait_data["data"][strain]["value"], ) + return accumulator + (trait_data["data"][sample]["value"], ) if dtype == "var": - return accumulator + (trait_data["data"][strain]["variance"], ) + return accumulator + (trait_data["data"][sample]["variance"], ) if dtype == "N": - return accumulator + (trait_data["data"][strain]["ndata"], ) + return accumulator + (trait_data["data"][sample]["ndata"], ) if dtype == "all": - return accumulator + __export_all_types(trait_data["data"], strain) + return accumulator + __export_all_types(trait_data["data"], sample) raise KeyError("Type `%s` is incorrect" % dtype) if var_exists and n_exists: return accumulator + (None, None, None) @@ -91,7 +91,7 @@ def export_trait_data( return accumulator + (None, None) return accumulator + (None,) - return reduce(__exporter, strainlist, tuple()) + return reduce(__exporter, samplelist, tuple()) def trait_display_name(trait: Dict): """ @@ -165,19 +165,19 @@ def build_heatmap(traits_names, conn: Any): for fullname in traits_names] traits_data_list = [retrieve_trait_data(t, conn) for t in traits] genotype_filename = build_genotype_file(traits[0]["riset"]) - strains = load_genotype_samples(genotype_filename) + samples = load_genotype_samples(genotype_filename) exported_traits_data_list = [ - export_trait_data(td, strains) for td in traits_data_list] + export_trait_data(td, samples) for td in traits_data_list] clustered = cluster_traits(exported_traits_data_list) slinked = slink(clustered) traits_order = compute_traits_order(slinked) - strains_and_values = retrieve_strains_and_values( - traits_order, strains, exported_traits_data_list) + samples_and_values = retrieve_samples_and_values( + traits_order, samples, exported_traits_data_list) traits_filename = "{}/traits_test_file_{}.txt".format( TMPDIR, random_string(10)) generate_traits_file( - strains_and_values[0][1], - [t[2] for t in strains_and_values], + samples_and_values[0][1], + [t[2] for t in samples_and_values], traits_filename) main_output, _permutations_output = run_reaper( @@ -229,9 +229,9 @@ def compute_traits_order(slink_data, neworder: tuple = tuple()): return __order_maker(neworder, slink_data) -def retrieve_strains_and_values(orders, strainlist, traits_data_list): +def retrieve_samples_and_values(orders, samplelist, traits_data_list): """ - Get the strains and their corresponding values from `strainlist` and + Get the samples and their corresponding values from `samplelist` and `traits_data_list`. This migrates the code in @@ -240,17 +240,17 @@ def retrieve_strains_and_values(orders, strainlist, traits_data_list): # This feels nasty! There's a lot of mutation of values here, that might # indicate something untoward in the design of this function and its # dependents ==> Review - strains = [] + samples = [] values = [] rets = [] for order in orders: temp_val = traits_data_list[order] - for i, strain in enumerate(strainlist): + for i, sample in enumerate(samplelist): if temp_val[i] is not None: - strains.append(strain) + samples.append(sample) values.append(temp_val[i]) - rets.append([order, strains[:], values[:]]) - strains = [] + rets.append([order, samples[:], values[:]]) + samples = [] values = [] return rets diff --git a/tests/unit/computations/test_parsers.py b/tests/unit/computations/test_parsers.py index 19c3067..b51b0bf 100644 --- a/tests/unit/computations/test_parsers.py +++ b/tests/unit/computations/test_parsers.py @@ -15,7 +15,7 @@ class TestParsers(unittest.TestCase): def test_parse_genofile_with_existing_file(self): """Test that a genotype file is parsed correctly""" - strains = ["bxd1", "bxd2"] + samples = ["bxd1", "bxd2"] genotypes = [ {"chr": "1", "locus": "rs31443144", "cm": "1.50", "mb": "3.010274", @@ -51,4 +51,4 @@ class TestParsers(unittest.TestCase): "../test_data/genotype.txt" )) self.assertEqual(parse_genofile( - test_genotype_file), (strains, genotypes)) + test_genotype_file), (samples, genotypes)) diff --git a/tests/unit/test_heatmaps.py b/tests/unit/test_heatmaps.py index fd91cf9..b54e2f3 100644 --- a/tests/unit/test_heatmaps.py +++ b/tests/unit/test_heatmaps.py @@ -5,41 +5,41 @@ from gn3.heatmaps import ( get_lrs_from_chr, export_trait_data, compute_traits_order, - retrieve_strains_and_values, + retrieve_samples_and_values, process_traits_data_for_heatmap) from tests.unit.sample_test_data import organised_trait_1, organised_trait_2 -strainlist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"] +samplelist = ["B6cC3-1", "BXD1", "BXD12", "BXD16", "BXD19", "BXD2"] trait_data = { "mysqlid": 36688172, "data": { - "B6cC3-1": {"strain_name": "B6cC3-1", "value": 7.51879, "variance": None, "ndata": None}, - "BXD1": {"strain_name": "BXD1", "value": 7.77141, "variance": None, "ndata": None}, - "BXD12": {"strain_name": "BXD12", "value": 8.39265, "variance": None, "ndata": None}, - "BXD16": {"strain_name": "BXD16", "value": 8.17443, "variance": None, "ndata": None}, - "BXD19": {"strain_name": "BXD19", "value": 8.30401, "variance": None, "ndata": None}, - "BXD2": {"strain_name": "BXD2", "value": 7.80944, "variance": None, "ndata": None}, - "BXD21": {"strain_name": "BXD21", "value": 8.93809, "variance": None, "ndata": None}, - "BXD24": {"strain_name": "BXD24", "value": 7.99415, "variance": None, "ndata": None}, - "BXD27": {"strain_name": "BXD27", "value": 8.12177, "variance": None, "ndata": None}, - "BXD28": {"strain_name": "BXD28", "value": 7.67688, "variance": None, "ndata": None}, - "BXD32": {"strain_name": "BXD32", "value": 7.79062, "variance": None, "ndata": None}, - "BXD39": {"strain_name": "BXD39", "value": 8.27641, "variance": None, "ndata": None}, - "BXD40": {"strain_name": "BXD40", "value": 8.18012, "variance": None, "ndata": None}, - "BXD42": {"strain_name": "BXD42", "value": 7.82433, "variance": None, "ndata": None}, - "BXD6": {"strain_name": "BXD6", "value": 8.09718, "variance": None, "ndata": None}, - "BXH14": {"strain_name": "BXH14", "value": 7.97475, "variance": None, "ndata": None}, - "BXH19": {"strain_name": "BXH19", "value": 7.67223, "variance": None, "ndata": None}, - "BXH2": {"strain_name": "BXH2", "value": 7.93622, "variance": None, "ndata": None}, - "BXH22": {"strain_name": "BXH22", "value": 7.43692, "variance": None, "ndata": None}, - "BXH4": {"strain_name": "BXH4", "value": 7.96336, "variance": None, "ndata": None}, - "BXH6": {"strain_name": "BXH6", "value": 7.75132, "variance": None, "ndata": None}, - "BXH7": {"strain_name": "BXH7", "value": 8.12927, "variance": None, "ndata": None}, - "BXH8": {"strain_name": "BXH8", "value": 6.77338, "variance": None, "ndata": None}, - "BXH9": {"strain_name": "BXH9", "value": 8.03836, "variance": None, "ndata": None}, - "C3H/HeJ": {"strain_name": "C3H/HeJ", "value": 7.42795, "variance": None, "ndata": None}, - "C57BL/6J": {"strain_name": "C57BL/6J", "value": 7.50606, "variance": None, "ndata": None}, - "DBA/2J": {"strain_name": "DBA/2J", "value": 7.72588, "variance": None, "ndata": None}}} + "B6cC3-1": {"sample_name": "B6cC3-1", "value": 7.51879, "variance": None, "ndata": None}, + "BXD1": {"sample_name": "BXD1", "value": 7.77141, "variance": None, "ndata": None}, + "BXD12": {"sample_name": "BXD12", "value": 8.39265, "variance": None, "ndata": None}, + "BXD16": {"sample_name": "BXD16", "value": 8.17443, "variance": None, "ndata": None}, + "BXD19": {"sample_name": "BXD19", "value": 8.30401, "variance": None, "ndata": None}, + "BXD2": {"sample_name": "BXD2", "value": 7.80944, "variance": None, "ndata": None}, + "BXD21": {"sample_name": "BXD21", "value": 8.93809, "variance": None, "ndata": None}, + "BXD24": {"sample_name": "BXD24", "value": 7.99415, "variance": None, "ndata": None}, + "BXD27": {"sample_name": "BXD27", "value": 8.12177, "variance": None, "ndata": None}, + "BXD28": {"sample_name": "BXD28", "value": 7.67688, "variance": None, "ndata": None}, + "BXD32": {"sample_name": "BXD32", "value": 7.79062, "variance": None, "ndata": None}, + "BXD39": {"sample_name": "BXD39", "value": 8.27641, "variance": None, "ndata": None}, + "BXD40": {"sample_name": "BXD40", "value": 8.18012, "variance": None, "ndata": None}, + "BXD42": {"sample_name": "BXD42", "value": 7.82433, "variance": None, "ndata": None}, + "BXD6": {"sample_name": "BXD6", "value": 8.09718, "variance": None, "ndata": None}, + "BXH14": {"sample_name": "BXH14", "value": 7.97475, "variance": None, "ndata": None}, + "BXH19": {"sample_name": "BXH19", "value": 7.67223, "variance": None, "ndata": None}, + "BXH2": {"sample_name": "BXH2", "value": 7.93622, "variance": None, "ndata": None}, + "BXH22": {"sample_name": "BXH22", "value": 7.43692, "variance": None, "ndata": None}, + "BXH4": {"sample_name": "BXH4", "value": 7.96336, "variance": None, "ndata": None}, + "BXH6": {"sample_name": "BXH6", "value": 7.75132, "variance": None, "ndata": None}, + "BXH7": {"sample_name": "BXH7", "value": 8.12927, "variance": None, "ndata": None}, + "BXH8": {"sample_name": "BXH8", "value": 6.77338, "variance": None, "ndata": None}, + "BXH9": {"sample_name": "BXH9", "value": 8.03836, "variance": None, "ndata": None}, + "C3H/HeJ": {"sample_name": "C3H/HeJ", "value": 7.42795, "variance": None, "ndata": None}, + "C57BL/6J": {"sample_name": "C57BL/6J", "value": 7.50606, "variance": None, "ndata": None}, + "DBA/2J": {"sample_name": "DBA/2J", "value": 7.72588, "variance": None, "ndata": None}}} slinked = ( (((0, 2, 0.16381088984330505), @@ -66,7 +66,7 @@ class TestHeatmap(TestCase): ["all", (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944)]]: with self.subTest(dtype=dtype): self.assertEqual( - export_trait_data(trait_data, strainlist, dtype=dtype), + export_trait_data(trait_data, samplelist, dtype=dtype), expected) def test_export_trait_data_dtype_all_flags(self): @@ -106,7 +106,7 @@ class TestHeatmap(TestCase): with self.subTest(dtype=dtype, vflag=vflag, nflag=nflag): self.assertEqual( export_trait_data( - trait_data, strainlist, dtype=dtype, var_exists=vflag, + trait_data, samplelist, dtype=dtype, var_exists=vflag, n_exists=nflag), expected) @@ -164,8 +164,8 @@ class TestHeatmap(TestCase): self.assertEqual( compute_traits_order(slinked), (0, 2, 1, 7, 5, 9, 3, 6, 8, 4)) - def test_retrieve_strains_and_values(self): - """Test retrieval of strains and values.""" + def test_retrieve_samples_and_values(self): + """Test retrieval of samples and values.""" for orders, slist, tdata, expected in [ [ [2], @@ -185,9 +185,9 @@ class TestHeatmap(TestCase): [6, None, None, 4, None]], [[3, ["s1", "s4"], [6, 4]]] ]]: - with self.subTest(strainlist=slist, traitdata=tdata): + with self.subTest(samplelist=slist, traitdata=tdata): self.assertEqual( - retrieve_strains_and_values(orders, slist, tdata), expected) + retrieve_samples_and_values(orders, slist, tdata), expected) def test_get_lrs_from_chr(self): """Check that function gets correct LRS values""" -- cgit v1.2.3 From 1d09a9222f8c661da3abd6d61c09ae19eeb5d793 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Mon, 27 Sep 2021 05:02:09 +0300 Subject: Update terminology: `riset` to `group` Issue: https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi * Update terminology to use the appropriate domain terminology according to Zachary's direction at https://github.com/genenetwork/genenetwork3/pull/37#issuecomment-926041744 --- gn3/db/datasets.py | 52 +++++++++++++++++++++--------------------- gn3/db/traits.py | 16 ++++++------- gn3/heatmaps.py | 2 +- tests/unit/db/test_datasets.py | 42 +++++++++++++++++----------------- 4 files changed, 56 insertions(+), 56 deletions(-) (limited to 'gn3/heatmaps.py') diff --git a/gn3/db/datasets.py b/gn3/db/datasets.py index 4a05499..6c328f5 100644 --- a/gn3/db/datasets.py +++ b/gn3/db/datasets.py @@ -119,9 +119,9 @@ def retrieve_dataset_name( return fn_map[trait_type](threshold, dataset_name, conn) -def retrieve_geno_riset_fields(name, conn): +def retrieve_geno_group_fields(name, conn): """ - Retrieve the RISet, and RISetID values for various Geno trait types. + Retrieve the Group, and GroupID values for various Geno trait types. """ query = ( "SELECT InbredSet.Name, InbredSet.Id " @@ -130,12 +130,12 @@ def retrieve_geno_riset_fields(name, conn): "AND GenoFreeze.Name = %(name)s") with conn.cursor() as cursor: cursor.execute(query, {"name": name}) - return dict(zip(["riset", "risetid"], cursor.fetchone())) + return dict(zip(["group", "groupid"], cursor.fetchone())) return {} -def retrieve_publish_riset_fields(name, conn): +def retrieve_publish_group_fields(name, conn): """ - Retrieve the RISet, and RISetID values for various Publish trait types. + Retrieve the Group, and GroupID values for various Publish trait types. """ query = ( "SELECT InbredSet.Name, InbredSet.Id " @@ -144,12 +144,12 @@ def retrieve_publish_riset_fields(name, conn): "AND PublishFreeze.Name = %(name)s") with conn.cursor() as cursor: cursor.execute(query, {"name": name}) - return dict(zip(["riset", "risetid"], cursor.fetchone())) + return dict(zip(["group", "groupid"], cursor.fetchone())) return {} -def retrieve_probeset_riset_fields(name, conn): +def retrieve_probeset_group_fields(name, conn): """ - Retrieve the RISet, and RISetID values for various ProbeSet trait types. + Retrieve the Group, and GroupID values for various ProbeSet trait types. """ query = ( "SELECT InbredSet.Name, InbredSet.Id " @@ -159,12 +159,12 @@ def retrieve_probeset_riset_fields(name, conn): "AND ProbeSetFreeze.Name = %(name)s") with conn.cursor() as cursor: cursor.execute(query, {"name": name}) - return dict(zip(["riset", "risetid"], cursor.fetchone())) + return dict(zip(["group", "groupid"], cursor.fetchone())) return {} -def retrieve_temp_riset_fields(name, conn): +def retrieve_temp_group_fields(name, conn): """ - Retrieve the RISet, and RISetID values for `Temp` trait types. + Retrieve the Group, and GroupID values for `Temp` trait types. """ query = ( "SELECT InbredSet.Name, InbredSet.Id " @@ -173,30 +173,30 @@ def retrieve_temp_riset_fields(name, conn): "AND Temp.Name = %(name)s") with conn.cursor() as cursor: cursor.execute(query, {"name": name}) - return dict(zip(["riset", "risetid"], cursor.fetchone())) + return dict(zip(["group", "groupid"], cursor.fetchone())) return {} -def retrieve_riset_fields(trait_type, trait_name, dataset_info, conn): +def retrieve_group_fields(trait_type, trait_name, dataset_info, conn): """ - Retrieve the RISet, and RISetID values for various trait types. + Retrieve the Group, and GroupID values for various trait types. """ - riset_fns_map = { - "Geno": retrieve_geno_riset_fields, - "Publish": retrieve_publish_riset_fields, - "ProbeSet": retrieve_probeset_riset_fields + group_fns_map = { + "Geno": retrieve_geno_group_fields, + "Publish": retrieve_publish_group_fields, + "ProbeSet": retrieve_probeset_group_fields } if trait_type == "Temp": - riset_info = retrieve_temp_riset_fields(trait_name, conn) + group_info = retrieve_temp_group_fields(trait_name, conn) else: - riset_info = riset_fns_map[trait_type](dataset_info["dataset_name"], conn) + group_info = group_fns_map[trait_type](dataset_info["dataset_name"], conn) return { **dataset_info, - **riset_info, - "riset": ( - "BXD" if riset_info.get("riset") == "BXD300" - else riset_info.get("riset", "")) + **group_info, + "group": ( + "BXD" if group_info.get("group") == "BXD300" + else group_info.get("group", "")) } def retrieve_temp_trait_dataset(): @@ -281,11 +281,11 @@ def retrieve_trait_dataset(trait_type, trait, threshold, conn): trait_type, threshold, trait["trait_name"], trait["db"]["dataset_name"], conn) } - riset = retrieve_riset_fields( + group = retrieve_group_fields( trait_type, trait["trait_name"], dataset_name_info, conn) return { "display_name": dataset_name_info["dataset_name"], **dataset_name_info, **dataset_fns[trait_type](), - **riset + **group } diff --git a/gn3/db/traits.py b/gn3/db/traits.py index c9d05d7..f2673c8 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -226,7 +226,7 @@ def set_homologene_id_field_probeset(trait_info, conn): """ query = ( "SELECT HomologeneId FROM Homologene, Species, InbredSet" - " WHERE Homologene.GeneId = %(geneid)s AND InbredSet.Name = %(riset)s" + " WHERE Homologene.GeneId = %(geneid)s AND InbredSet.Name = %(group)s" " AND InbredSet.SpeciesId = Species.Id AND" " Species.TaxonomyId = Homologene.TaxonomyId") with conn.cursor() as cursor: @@ -234,7 +234,7 @@ def set_homologene_id_field_probeset(trait_info, conn): query, { k:v for k, v in trait_info.items() - if k in ["geneid", "riset"] + if k in ["geneid", "group"] }) res = cursor.fetchone() if res: @@ -422,7 +422,7 @@ def retrieve_trait_info( if trait_info["haveinfo"]: return { **trait_post_processing_functions_table[trait_dataset_type]( - {**trait_info, "riset": trait_dataset["riset"]}), + {**trait_info, "group": trait_dataset["group"]}), "db": {**trait["db"], **trait_dataset} } return trait_info @@ -449,14 +449,14 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any): for row in cursor.fetchall()] return [] -def retrieve_species_id(riset, conn: Any): +def retrieve_species_id(group, conn: Any): """ - Retrieve a species id given the RISet value + Retrieve a species id given the Group value """ with conn.cursor as cursor: cursor.execute( - "SELECT SpeciesId from InbredSet WHERE Name = %(riset)s", - {"riset": riset}) + "SELECT SpeciesId from InbredSet WHERE Name = %(group)s", + {"group": group}) return cursor.fetchone()[0] return None @@ -482,7 +482,7 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any): {"trait_name": trait_info["trait_name"], "dataset_name": trait_info["db"]["dataset_name"], "species_id": retrieve_species_id( - trait_info["db"]["riset"], conn)}) + trait_info["db"]["group"], conn)}) return [dict(zip( ["sample_name", "value", "se_error", "id"], row)) for row in cursor.fetchall()] diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index b6fc6d3..a36940d 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -164,7 +164,7 @@ def build_heatmap(traits_names, conn: Any): retrieve_trait_info(threshold, fullname, conn) for fullname in traits_names] traits_data_list = [retrieve_trait_data(t, conn) for t in traits] - genotype_filename = build_genotype_file(traits[0]["riset"]) + genotype_filename = build_genotype_file(traits[0]["group"]) samples = load_genotype_samples(genotype_filename) exported_traits_data_list = [ export_trait_data(td, samples) for td in traits_data_list] diff --git a/tests/unit/db/test_datasets.py b/tests/unit/db/test_datasets.py index 38de0e2..39f4af9 100644 --- a/tests/unit/db/test_datasets.py +++ b/tests/unit/db/test_datasets.py @@ -3,10 +3,10 @@ from unittest import mock, TestCase from gn3.db.datasets import ( retrieve_dataset_name, - retrieve_riset_fields, - retrieve_geno_riset_fields, - retrieve_publish_riset_fields, - retrieve_probeset_riset_fields) + retrieve_group_fields, + retrieve_geno_group_fields, + retrieve_publish_group_fields, + retrieve_probeset_group_fields) class TestDatasetsDBFunctions(TestCase): """Test cases for datasets functions.""" @@ -40,9 +40,9 @@ class TestDatasetsDBFunctions(TestCase): table=table, cols=columns), {"threshold": thresh, "name": dataset_name}) - def test_retrieve_probeset_riset_fields(self): + def test_retrieve_probeset_group_fields(self): """ - Test that the `riset` and `riset_id` fields are retrieved appropriately + Test that the `group` and `group_id` fields are retrieved appropriately for the 'ProbeSet' trait type. """ for trait_name, expected in [ @@ -52,7 +52,7 @@ class TestDatasetsDBFunctions(TestCase): with db_mock.cursor() as cursor: cursor.execute.return_value = () self.assertEqual( - retrieve_probeset_riset_fields(trait_name, db_mock), + retrieve_probeset_group_fields(trait_name, db_mock), expected) cursor.execute.assert_called_once_with( ( @@ -63,34 +63,34 @@ class TestDatasetsDBFunctions(TestCase): " AND ProbeSetFreeze.Name = %(name)s"), {"name": trait_name}) - def test_retrieve_riset_fields(self): + def test_retrieve_group_fields(self): """ - Test that the riset fields are set up correctly for the different trait + Test that the group fields are set up correctly for the different trait types. """ for trait_type, trait_name, dataset_info, expected in [ ["Publish", "pubTraitName01", {"dataset_name": "pubDBName01"}, - {"dataset_name": "pubDBName01", "riset": ""}], + {"dataset_name": "pubDBName01", "group": ""}], ["ProbeSet", "prbTraitName01", {"dataset_name": "prbDBName01"}, - {"dataset_name": "prbDBName01", "riset": ""}], + {"dataset_name": "prbDBName01", "group": ""}], ["Geno", "genoTraitName01", {"dataset_name": "genoDBName01"}, - {"dataset_name": "genoDBName01", "riset": ""}], - ["Temp", "tempTraitName01", {}, {"riset": ""}], + {"dataset_name": "genoDBName01", "group": ""}], + ["Temp", "tempTraitName01", {}, {"group": ""}], ]: db_mock = mock.MagicMock() with self.subTest( trait_type=trait_type, trait_name=trait_name, dataset_info=dataset_info): with db_mock.cursor() as cursor: - cursor.execute.return_value = ("riset_name", 0) + cursor.execute.return_value = ("group_name", 0) self.assertEqual( - retrieve_riset_fields( + retrieve_group_fields( trait_type, trait_name, dataset_info, db_mock), expected) - def test_retrieve_publish_riset_fields(self): + def test_retrieve_publish_group_fields(self): """ - Test that the `riset` and `riset_id` fields are retrieved appropriately + Test that the `group` and `group_id` fields are retrieved appropriately for the 'Publish' trait type. """ for trait_name, expected in [ @@ -100,7 +100,7 @@ class TestDatasetsDBFunctions(TestCase): with db_mock.cursor() as cursor: cursor.execute.return_value = () self.assertEqual( - retrieve_publish_riset_fields(trait_name, db_mock), + retrieve_publish_group_fields(trait_name, db_mock), expected) cursor.execute.assert_called_once_with( ( @@ -110,9 +110,9 @@ class TestDatasetsDBFunctions(TestCase): " AND PublishFreeze.Name = %(name)s"), {"name": trait_name}) - def test_retrieve_geno_riset_fields(self): + def test_retrieve_geno_group_fields(self): """ - Test that the `riset` and `riset_id` fields are retrieved appropriately + Test that the `group` and `group_id` fields are retrieved appropriately for the 'Geno' trait type. """ for trait_name, expected in [ @@ -122,7 +122,7 @@ class TestDatasetsDBFunctions(TestCase): with db_mock.cursor() as cursor: cursor.execute.return_value = () self.assertEqual( - retrieve_geno_riset_fields(trait_name, db_mock), + retrieve_geno_group_fields(trait_name, db_mock), expected) cursor.execute.assert_called_once_with( ( -- cgit v1.2.3