diff options
author | zsloan | 2021-10-12 18:56:34 +0000 |
---|---|---|
committer | zsloan | 2021-10-12 18:56:34 +0000 |
commit | 0f396f4a1a753d449cf2975fc425d587d9350689 (patch) | |
tree | c9dac243dc05e5cb90ccb7f1d96fd599986bf60a /gn3 | |
parent | 976660ce9ef915426c7ce5ff9077b439e4102a2c (diff) | |
parent | 77c274b79c3ec01de60e90db3299763cb58f715b (diff) | |
download | genenetwork3-0f396f4a1a753d449cf2975fc425d587d9350689.tar.gz |
Merge branch 'main' of https://github.com/genenetwork/genenetwork3 into feature/add_rqtl_pairscan
Diffstat (limited to 'gn3')
-rw-r--r-- | gn3/api/heatmaps.py | 36 | ||||
-rw-r--r-- | gn3/api/wgcna.py | 25 | ||||
-rw-r--r-- | gn3/app.py | 14 | ||||
-rw-r--r-- | gn3/computations/heatmap.py | 177 | ||||
-rw-r--r-- | gn3/computations/parsers.py | 10 | ||||
-rw-r--r-- | gn3/computations/qtlreaper.py | 170 | ||||
-rw-r--r-- | gn3/computations/wgcna.py | 51 | ||||
-rw-r--r-- | gn3/db/datasets.py | 52 | ||||
-rw-r--r-- | gn3/db/genotypes.py | 184 | ||||
-rw-r--r-- | gn3/db/traits.py | 78 | ||||
-rw-r--r-- | gn3/heatmaps.py | 424 | ||||
-rw-r--r-- | gn3/heatmaps/heatmaps.py | 54 | ||||
-rw-r--r-- | gn3/random.py | 11 | ||||
-rw-r--r-- | gn3/settings.py | 21 |
14 files changed, 1009 insertions, 298 deletions
diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py new file mode 100644 index 0000000..62ca2ad --- /dev/null +++ b/gn3/api/heatmaps.py @@ -0,0 +1,36 @@ +""" +Module to hold the entrypoint functions that generate heatmaps +""" + +import io +from flask import jsonify +from flask import request +from flask import Blueprint +from gn3.heatmaps import build_heatmap +from gn3.db_utils import database_connector + +heatmaps = Blueprint("heatmaps", __name__) + +@heatmaps.route("/clustered", methods=("POST",)) +def clustered_heatmaps(): + """ + Parses the incoming data and responds with the JSON-serialized plotly figure + representing the clustered heatmap. + """ + traits_names = request.get_json().get("traits_names", tuple()) + if len(traits_names) < 2: + return jsonify({ + "message": "You need to provide at least two trait names." + }), 400 + conn, _cursor = database_connector() + def parse_trait_fullname(trait): + name_parts = trait.split(":") + return "{dataset_name}::{trait_name}".format( + dataset_name=name_parts[1], trait_name=name_parts[0]) + traits_fullnames = [parse_trait_fullname(trait) for trait in traits_names] + + with io.StringIO() as io_str: + _filename, figure = build_heatmap(traits_fullnames, conn) + figure.write_json(io_str) + fig_json = io_str.getvalue() + return fig_json, 200 diff --git a/gn3/api/wgcna.py b/gn3/api/wgcna.py new file mode 100644 index 0000000..fa044cf --- /dev/null +++ b/gn3/api/wgcna.py @@ -0,0 +1,25 @@ +"""endpoint to run wgcna analysis""" +from flask import Blueprint +from flask import request +from flask import current_app +from flask import jsonify + +from gn3.computations.wgcna import call_wgcna_script + +wgcna = Blueprint("wgcna", __name__) + + +@wgcna.route("/run_wgcna", methods=["POST"]) +def run_wgcna(): + """run wgcna:output should be a json with a the data""" + + wgcna_data = request.json + + wgcna_script = current_app.config["WGCNA_RSCRIPT"] + + results = call_wgcna_script(wgcna_script, wgcna_data) + + if results.get("data") is None: + return jsonify(results), 401 + + return jsonify(results), 200 @@ -3,13 +3,17 @@ import os from typing import Dict from typing import Union + from flask import Flask +from flask_cors import CORS # type: ignore + from gn3.api.gemma import gemma from gn3.api.rqtl import rqtl from gn3.api.general import general +from gn3.api.heatmaps import heatmaps from gn3.api.correlation import correlation from gn3.api.data_entry import data_entry - +from gn3.api.wgcna import wgcna def create_app(config: Union[Dict, str, None] = None) -> Flask: """Create a new flask object""" @@ -17,6 +21,12 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask: # Load default configuration app.config.from_object("gn3.settings") + CORS( + app, + origins=app.config["CORS_ORIGINS"], + allow_headers=app.config["CORS_HEADERS"], + supports_credentials=True, intercept_exceptions=False) + # Load environment configuration if "GN3_CONF" in os.environ: app.config.from_envvar('GN3_CONF') @@ -30,6 +40,8 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask: app.register_blueprint(general, url_prefix="/api/") app.register_blueprint(gemma, url_prefix="/api/gemma") app.register_blueprint(rqtl, url_prefix="/api/rqtl") + app.register_blueprint(heatmaps, url_prefix="/api/heatmaps") app.register_blueprint(correlation, url_prefix="/api/correlation") app.register_blueprint(data_entry, url_prefix="/api/dataentry") + app.register_blueprint(wgcna, url_prefix="/api/wgcna") return app diff --git a/gn3/computations/heatmap.py b/gn3/computations/heatmap.py deleted file mode 100644 index 3c35029..0000000 --- a/gn3/computations/heatmap.py +++ /dev/null @@ -1,177 +0,0 @@ -""" -This module will contain functions to be used in computation of the data used to -generate various kinds of heatmaps. -""" - -from functools import reduce -from typing import Any, Dict, Sequence -from gn3.computations.slink import slink -from gn3.db.traits import retrieve_trait_data, retrieve_trait_info -from gn3.computations.correlations2 import compute_correlation - -def export_trait_data( - trait_data: dict, strainlist: Sequence[str], dtype: str = "val", - var_exists: bool = False, n_exists: bool = False): - """ - Export data according to `strainlist`. Mostly used in calculating - correlations. - - DESCRIPTION: - Migrated from - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211 - - PARAMETERS - trait: (dict) - The dictionary of key-value pairs representing a trait - strainlist: (list) - A list of strain names - type: (str) - ... verify what this is ... - var_exists: (bool) - A flag indicating existence of variance - n_exists: (bool) - A flag indicating existence of ndata - """ - def __export_all_types(tdata, strain): - sample_data = [] - if tdata[strain]["value"]: - sample_data.append(tdata[strain]["value"]) - if var_exists: - if tdata[strain]["variance"]: - sample_data.append(tdata[strain]["variance"]) - else: - sample_data.append(None) - if n_exists: - if tdata[strain]["ndata"]: - sample_data.append(tdata[strain]["ndata"]) - else: - sample_data.append(None) - else: - if var_exists and n_exists: - sample_data += [None, None, None] - elif var_exists or n_exists: - sample_data += [None, None] - else: - sample_data.append(None) - - return tuple(sample_data) - - def __exporter(accumulator, strain): - # pylint: disable=[R0911] - if strain in trait_data["data"]: - if dtype == "val": - return accumulator + (trait_data["data"][strain]["value"], ) - if dtype == "var": - return accumulator + (trait_data["data"][strain]["variance"], ) - if dtype == "N": - return accumulator + (trait_data["data"][strain]["ndata"], ) - if dtype == "all": - return accumulator + __export_all_types(trait_data["data"], strain) - raise KeyError("Type `%s` is incorrect" % dtype) - if var_exists and n_exists: - return accumulator + (None, None, None) - if var_exists or n_exists: - return accumulator + (None, None) - return accumulator + (None,) - - return reduce(__exporter, strainlist, tuple()) - -def trait_display_name(trait: Dict): - """ - Given a trait, return a name to use to display the trait on a heatmap. - - DESCRIPTION - Migrated from - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L141-L157 - """ - if trait.get("db", None) and trait.get("trait_name", None): - if trait["db"]["dataset_type"] == "Temp": - desc = trait["description"] - if desc.find("PCA") >= 0: - return "%s::%s" % ( - trait["db"]["displayname"], - desc[desc.rindex(':')+1:].strip()) - return "%s::%s" % ( - trait["db"]["displayname"], - desc[:desc.index('entered')].strip()) - prefix = "%s::%s" % ( - trait["db"]["dataset_name"], trait["trait_name"]) - if trait["cellid"]: - return "%s::%s" % (prefix, trait["cellid"]) - return prefix - return trait["description"] - -def cluster_traits(traits_data_list: Sequence[Dict]): - """ - Clusters the trait values. - - DESCRIPTION - Attempts to replicate the clustering of the traits, as done at - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L162 - """ - def __compute_corr(tdata_i, tdata_j): - if tdata_i[0] == tdata_j[0]: - return 0.0 - corr_vals = compute_correlation(tdata_i[1], tdata_j[1]) - corr = corr_vals[0] - if (1 - corr) < 0: - return 0.0 - return 1 - corr - - def __cluster(tdata_i): - return tuple( - __compute_corr(tdata_i, tdata_j) - for tdata_j in enumerate(traits_data_list)) - - return tuple(__cluster(tdata_i) for tdata_i in enumerate(traits_data_list)) - -def heatmap_data(formd, search_result, conn: Any): - """ - heatmap function - - DESCRIPTION - This function is an attempt to reproduce the initialisation at - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L46-L64 - and also the clustering and slink computations at - https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L165 - with the help of the `gn3.computations.heatmap.cluster_traits` function. - - It does not try to actually draw the heatmap image. - - PARAMETERS: - TODO: Elaborate on the parameters here... - """ - threshold = 0 # webqtlConfig.PUBLICTHRESH - cluster_checked = formd.formdata.getvalue("clusterCheck", "") - strainlist = [ - strain for strain in formd.strainlist if strain not in formd.parlist] - genotype = formd.genotype - - def __retrieve_traitlist_and_datalist(threshold, fullname): - trait = retrieve_trait_info(threshold, fullname, conn) - return ( - trait, - export_trait_data(retrieve_trait_data(trait, conn), strainlist)) - - traits_details = [ - __retrieve_traitlist_and_datalist(threshold, fullname) - for fullname in search_result] - traits_list = map(lambda x: x[0], traits_details) - traits_data_list = map(lambda x: x[1], traits_details) - - return { - "target_description_checked": formd.formdata.getvalue( - "targetDescriptionCheck", ""), - "cluster_checked": cluster_checked, - "slink_data": ( - slink(cluster_traits(traits_data_list)) - if cluster_checked else False), - "sessionfile": formd.formdata.getvalue("session"), - "genotype": genotype, - "nLoci": sum(map(len, genotype)), - "strainlist": strainlist, - "ppolar": formd.ppolar, - "mpolar":formd.mpolar, - "traits_list": traits_list, - "traits_data_list": traits_data_list - } diff --git a/gn3/computations/parsers.py b/gn3/computations/parsers.py index 94387ff..1af35d6 100644 --- a/gn3/computations/parsers.py +++ b/gn3/computations/parsers.py @@ -14,7 +14,7 @@ def parse_genofile(file_path: str) -> Tuple[List[str], 'h': 0, 'u': None, } - genotypes, strains = [], [] + genotypes, samples = [], [] with open(file_path, "r") as _genofile: for line in _genofile: line = line.strip() @@ -22,8 +22,8 @@ def parse_genofile(file_path: str) -> Tuple[List[str], continue cells = line.split() if line.startswith("Chr"): - strains = cells[4:] - strains = [strain.lower() for strain in strains] + samples = cells[4:] + samples = [sample.lower() for sample in samples] continue values = [__map.get(value.lower(), None) for value in cells[4:]] genotype = { @@ -32,7 +32,7 @@ def parse_genofile(file_path: str) -> Tuple[List[str], "cm": cells[2], "mb": cells[3], "values": values, - "dicvalues": dict(zip(strains, values)), + "dicvalues": dict(zip(samples, values)), } genotypes.append(genotype) - return strains, genotypes + return samples, genotypes diff --git a/gn3/computations/qtlreaper.py b/gn3/computations/qtlreaper.py new file mode 100644 index 0000000..d1ff4ac --- /dev/null +++ b/gn3/computations/qtlreaper.py @@ -0,0 +1,170 @@ +""" +This module contains functions to interact with the `qtlreaper` utility for +computation of QTLs. +""" +import os +import subprocess +from typing import Union + +from gn3.random import random_string +from gn3.settings import TMPDIR, REAPER_COMMAND + +def generate_traits_file(samples, trait_values, traits_filename): + """ + Generate a traits file for use with `qtlreaper`. + + PARAMETERS: + samples: A list of samples to use as the headers for the various columns. + trait_values: A list of lists of values for each trait and sample. + traits_filename: The tab-separated value to put the values in for + computation of QTLs. + """ + header = "Trait\t{}\n".format("\t".join(samples)) + data = ( + [header] + + ["{}\t{}\n".format(i+1, "\t".join([str(i) for i in t])) + for i, t in enumerate(trait_values[:-1])] + + ["{}\t{}".format( + len(trait_values), "\t".join([str(i) for i in t])) + for t in trait_values[-1:]]) + with open(traits_filename, "w") as outfile: + outfile.writelines(data) + +def create_output_directory(path: str): + """Create the output directory at `path` if it does not exist.""" + try: + os.mkdir(path) + except FileExistsError: + # If the directory already exists, do nothing. + pass + +def run_reaper( + genotype_filename: str, traits_filename: str, + other_options: tuple = ("--n_permutations", "1000"), + separate_nperm_output: bool = False, + output_dir: str = TMPDIR): + """ + Run the QTLReaper command to compute the QTLs. + + PARAMETERS: + genotype_filename: The complete path to a genotype file to use in the QTL + computation. + traits_filename: A path to a file previously generated with the + `generate_traits_file` function in this module, to be used in the QTL + computation. + other_options: Other options to pass to the `qtlreaper` command to modify + the QTL computations. + separate_nperm_output: A flag indicating whether or not to provide a + separate output for the permutations computation. The default is False, + which means by default, no separate output file is created. + output_dir: A path to the directory where the outputs are put + + RETURNS: + The function returns a tuple of the main output file, and the output file + for the permutation computations. If the `separate_nperm_output` is `False`, + the second value in the tuple returned is `None`. + + RAISES: + The function will raise a `subprocess.CalledProcessError` exception in case + of any errors running the `qtlreaper` command. + """ + create_output_directory("{}/qtlreaper".format(output_dir)) + output_filename = "{}/qtlreaper/main_output_{}.txt".format( + output_dir, random_string(10)) + output_list = ["--main_output", output_filename] + if separate_nperm_output: + permu_output_filename: Union[None, str] = "{}/qtlreaper/permu_output_{}.txt".format( + output_dir, random_string(10)) + output_list = output_list + [ + "--permu_output", permu_output_filename] # type: ignore[list-item] + else: + permu_output_filename = None + + command_list = [ + REAPER_COMMAND, "--geno", genotype_filename, + *other_options, # this splices the `other_options` list here + "--traits", traits_filename, + *output_list # this splices the `output_list` list here + ] + + subprocess.run(command_list, check=True) + return (output_filename, permu_output_filename) + +def chromosome_sorter_key_fn(val): + """ + Useful for sorting the chromosomes + """ + if isinstance(val, int): + return val + return ord(val) + +def organise_reaper_main_results(parsed_results): + """ + Provide the results of running reaper in a format that is easier to use. + """ + def __organise_by_chromosome(chr_name, items): + chr_items = [item for item in items if item["Chr"] == chr_name] + return { + "Chr": chr_name, + "loci": [{ + "Locus": locus["Locus"], + "cM": locus["cM"], + "Mb": locus["Mb"], + "LRS": locus["LRS"], + "Additive": locus["Additive"], + "pValue": locus["pValue"] + } for locus in chr_items]} + + def __organise_by_id(identifier, items): + id_items = [item for item in items if item["ID"] == identifier] + unique_chromosomes = {item["Chr"] for item in id_items} + return { + "ID": identifier, + "chromosomes": { + _chr["Chr"]: _chr for _chr in [ + __organise_by_chromosome(chromo, id_items) + for chromo in sorted( + unique_chromosomes, key=chromosome_sorter_key_fn)]}} + + unique_ids = {res["ID"] for res in parsed_results} + return { + trait["ID"]: trait for trait in + [__organise_by_id(_id, parsed_results) for _id in sorted(unique_ids)]} + +def parse_reaper_main_results(results_file): + """ + Parse the results file of running QTLReaper into a list of dicts. + """ + with open(results_file, "r") as infile: + lines = infile.readlines() + + def __parse_column_float_value(value): + # pylint: disable=W0702 + try: + return float(value) + except: + return value + + def __parse_column_int_value(value): + # pylint: disable=W0702 + try: + return int(value) + except: + return value + + def __parse_line(line): + items = line.strip().split("\t") + return items[0:2] + [__parse_column_int_value(items[2])] + [ + __parse_column_float_value(item) for item in items[3:]] + + header = lines[0].strip().split("\t") + return [dict(zip(header, __parse_line(line))) for line in lines[1:]] + +def parse_reaper_permutation_results(results_file): + """ + Parse the results QTLReaper permutations into a list of values. + """ + with open(results_file, "r") as infile: + lines = infile.readlines() + + return [float(line.strip()) for line in lines] diff --git a/gn3/computations/wgcna.py b/gn3/computations/wgcna.py new file mode 100644 index 0000000..fd508fa --- /dev/null +++ b/gn3/computations/wgcna.py @@ -0,0 +1,51 @@ +"""module contains code to preprocess and call wgcna script""" + +import os +import json +import uuid +from gn3.settings import TMPDIR + +from gn3.commands import run_cmd + + +def dump_wgcna_data(request_data: dict): + """function to dump request data to json file""" + filename = f"{str(uuid.uuid4())}.json" + + temp_file_path = os.path.join(TMPDIR, filename) + + with open(temp_file_path, "w") as output_file: + json.dump(request_data, output_file) + + return temp_file_path + + +def compose_wgcna_cmd(rscript_path: str, temp_file_path: str): + """function to componse wgcna cmd""" + # (todo):issue relative paths to abs paths + cmd = f"Rscript ./scripts/{rscript_path} {temp_file_path}" + return cmd + + +def call_wgcna_script(rscript_path: str, request_data: dict): + """function to call wgcna script""" + generated_file = dump_wgcna_data(request_data) + cmd = compose_wgcna_cmd(rscript_path, generated_file) + + try: + + run_cmd_results = run_cmd(cmd) + + with open(generated_file, "r") as outputfile: + + if run_cmd_results["code"] != 0: + return run_cmd_results + return { + "data": json.load(outputfile), + **run_cmd_results + } + except FileNotFoundError: + # relook at handling errors gn3 + return { + "output": "output file not found" + } diff --git a/gn3/db/datasets.py b/gn3/db/datasets.py index 4a05499..6c328f5 100644 --- a/gn3/db/datasets.py +++ b/gn3/db/datasets.py @@ -119,9 +119,9 @@ def retrieve_dataset_name( return fn_map[trait_type](threshold, dataset_name, conn) -def retrieve_geno_riset_fields(name, conn): +def retrieve_geno_group_fields(name, conn): """ - Retrieve the RISet, and RISetID values for various Geno trait types. + Retrieve the Group, and GroupID values for various Geno trait types. """ query = ( "SELECT InbredSet.Name, InbredSet.Id " @@ -130,12 +130,12 @@ def retrieve_geno_riset_fields(name, conn): "AND GenoFreeze.Name = %(name)s") with conn.cursor() as cursor: cursor.execute(query, {"name": name}) - return dict(zip(["riset", "risetid"], cursor.fetchone())) + return dict(zip(["group", "groupid"], cursor.fetchone())) return {} -def retrieve_publish_riset_fields(name, conn): +def retrieve_publish_group_fields(name, conn): """ - Retrieve the RISet, and RISetID values for various Publish trait types. + Retrieve the Group, and GroupID values for various Publish trait types. """ query = ( "SELECT InbredSet.Name, InbredSet.Id " @@ -144,12 +144,12 @@ def retrieve_publish_riset_fields(name, conn): "AND PublishFreeze.Name = %(name)s") with conn.cursor() as cursor: cursor.execute(query, {"name": name}) - return dict(zip(["riset", "risetid"], cursor.fetchone())) + return dict(zip(["group", "groupid"], cursor.fetchone())) return {} -def retrieve_probeset_riset_fields(name, conn): +def retrieve_probeset_group_fields(name, conn): """ - Retrieve the RISet, and RISetID values for various ProbeSet trait types. + Retrieve the Group, and GroupID values for various ProbeSet trait types. """ query = ( "SELECT InbredSet.Name, InbredSet.Id " @@ -159,12 +159,12 @@ def retrieve_probeset_riset_fields(name, conn): "AND ProbeSetFreeze.Name = %(name)s") with conn.cursor() as cursor: cursor.execute(query, {"name": name}) - return dict(zip(["riset", "risetid"], cursor.fetchone())) + return dict(zip(["group", "groupid"], cursor.fetchone())) return {} -def retrieve_temp_riset_fields(name, conn): +def retrieve_temp_group_fields(name, conn): """ - Retrieve the RISet, and RISetID values for `Temp` trait types. + Retrieve the Group, and GroupID values for `Temp` trait types. """ query = ( "SELECT InbredSet.Name, InbredSet.Id " @@ -173,30 +173,30 @@ def retrieve_temp_riset_fields(name, conn): "AND Temp.Name = %(name)s") with conn.cursor() as cursor: cursor.execute(query, {"name": name}) - return dict(zip(["riset", "risetid"], cursor.fetchone())) + return dict(zip(["group", "groupid"], cursor.fetchone())) return {} -def retrieve_riset_fields(trait_type, trait_name, dataset_info, conn): +def retrieve_group_fields(trait_type, trait_name, dataset_info, conn): """ - Retrieve the RISet, and RISetID values for various trait types. + Retrieve the Group, and GroupID values for various trait types. """ - riset_fns_map = { - "Geno": retrieve_geno_riset_fields, - "Publish": retrieve_publish_riset_fields, - "ProbeSet": retrieve_probeset_riset_fields + group_fns_map = { + "Geno": retrieve_geno_group_fields, + "Publish": retrieve_publish_group_fields, + "ProbeSet": retrieve_probeset_group_fields } if trait_type == "Temp": - riset_info = retrieve_temp_riset_fields(trait_name, conn) + group_info = retrieve_temp_group_fields(trait_name, conn) else: - riset_info = riset_fns_map[trait_type](dataset_info["dataset_name"], conn) + group_info = group_fns_map[trait_type](dataset_info["dataset_name"], conn) return { **dataset_info, - **riset_info, - "riset": ( - "BXD" if riset_info.get("riset") == "BXD300" - else riset_info.get("riset", "")) + **group_info, + "group": ( + "BXD" if group_info.get("group") == "BXD300" + else group_info.get("group", "")) } def retrieve_temp_trait_dataset(): @@ -281,11 +281,11 @@ def retrieve_trait_dataset(trait_type, trait, threshold, conn): trait_type, threshold, trait["trait_name"], trait["db"]["dataset_name"], conn) } - riset = retrieve_riset_fields( + group = retrieve_group_fields( trait_type, trait["trait_name"], dataset_name_info, conn) return { "display_name": dataset_name_info["dataset_name"], **dataset_name_info, **dataset_fns[trait_type](), - **riset + **group } diff --git a/gn3/db/genotypes.py b/gn3/db/genotypes.py new file mode 100644 index 0000000..8f18cac --- /dev/null +++ b/gn3/db/genotypes.py @@ -0,0 +1,184 @@ +"""Genotype utilities""" + +import os +import gzip +from typing import Union, TextIO + +from gn3.settings import GENOTYPE_FILES + +def build_genotype_file( + geno_name: str, base_dir: str = GENOTYPE_FILES, + extension: str = "geno"): + """Build the absolute path for the genotype file.""" + return "{}/{}.{}".format(os.path.abspath(base_dir), geno_name, extension) + +def load_genotype_samples(genotype_filename: str, file_type: str = "geno"): + """ + Load sample of samples from genotype files. + + DESCRIPTION: + Traits can contain a varied number of samples, some of which do not exist in + certain genotypes. In order to compute QTLs, GEMMAs, etc, we need to ensure + to pick only those samples that exist in the genotype under consideration + for the traits used in the computation. + + This function loads a list of samples from the genotype files for use in + filtering out unusable samples. + + + PARAMETERS: + genotype_filename: The absolute path to the genotype file to load the + samples from. + file_type: The type of file. Currently supported values are 'geno' and + 'plink'. + """ + file_type_fns = { + "geno": __load_genotype_samples_from_geno, + "plink": __load_genotype_samples_from_plink + } + return file_type_fns[file_type](genotype_filename) + +def __load_genotype_samples_from_geno(genotype_filename: str): + """ + Helper function for `load_genotype_samples` function. + + Loads samples from '.geno' files. + """ + gzipped_filename = "{}.gz".format(genotype_filename) + if os.path.isfile(gzipped_filename): + genofile: Union[TextIO, gzip.GzipFile] = gzip.open(gzipped_filename) + else: + genofile = open(genotype_filename) + + for row in genofile: + line = row.strip() + if (not line) or (line.startswith(("#", "@"))): # type: ignore[arg-type] + continue + break + + headers = line.split("\t") # type: ignore[arg-type] + if headers[3] == "Mb": + return headers[4:] + return headers[3:] + +def __load_genotype_samples_from_plink(genotype_filename: str): + """ + Helper function for `load_genotype_samples` function. + + Loads samples from '.plink' files. + """ + genofile = open(genotype_filename) + return [line.split(" ")[1] for line in genofile] + +def parse_genotype_labels(lines: list): + """ + Parse label lines into usable genotype values + + DESCRIPTION: + Reworks + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/utility/gen_geno_ob.py#L75-L93 + """ + acceptable_labels = ["name", "filler", "type", "mat", "pat", "het", "unk"] + def __parse_label(line): + label, value = [l.strip() for l in line[1:].split(":")] + if label not in acceptable_labels: + return None + if label == "name": + return ("group", value) + return (label, value) + return tuple( + item for item in (__parse_label(line) for line in lines) + if item is not None) + +def parse_genotype_header(line: str, parlist: tuple = tuple()): + """ + Parse the genotype file header line + + DESCRIPTION: + Reworks + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/utility/gen_geno_ob.py#L94-L114 + """ + items = [item.strip() for item in line.split("\t")] + mbmap = "Mb" in items + prgy = ((parlist + tuple(items[4:])) if mbmap + else (parlist + tuple(items[3:]))) + return ( + ("Mbmap", mbmap), + ("cm_column", items.index("cM")), + ("mb_column", None if not mbmap else items.index("Mb")), + ("prgy", prgy), + ("nprgy", len(prgy))) + +def parse_genotype_marker(line: str, geno_obj: dict, parlist: tuple): + """ + Parse a data line in a genotype file + + DESCRIPTION: + Reworks + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/utility/gen_geno_ob.py#L143-L190 + """ + # pylint: disable=W0702 + marker_row = [item.strip() for item in line.split("\t")] + geno_table = { + geno_obj["mat"]: -1, geno_obj["pat"]: 1, geno_obj["het"]: 0, + geno_obj["unk"]: "U" + } + start_pos = 4 if geno_obj["Mbmap"] else 3 + if len(parlist) > 0: + start_pos = start_pos + 2 + + alleles = marker_row[start_pos:] + genotype = tuple( + (geno_table[allele] if allele in geno_table.keys() else "U") + for allele in alleles) + if len(parlist) > 0: + genotype = (-1, 1) + genotype + try: + cm_val = float(geno_obj["cm_column"]) + except: + if geno_obj["Mbmap"]: + cm_val = float(geno_obj["mb_column"]) + else: + cm_val = 0 + return ( + ("chr", marker_row[0]), + ("name", marker_row[1]), + ("cM", cm_val), + ("Mb", float(geno_obj["mb_column"]) if geno_obj["Mbmap"] else None), + ("genotype", genotype)) + +def build_genotype_chromosomes(geno_obj, markers): + """ + Build up the chromosomes from the given markers and partially built geno + object + """ + mrks = [dict(marker) for marker in markers] + chr_names = {marker["chr"] for marker in mrks} + return tuple(( + ("name", chr_name), ("mb_exists", geno_obj["Mbmap"]), ("cm_column", 2), + ("mb_column", geno_obj["mb_column"]), + ("loci", tuple(marker for marker in mrks if marker["chr"] == chr_name))) + for chr_name in sorted(chr_names)) + +def parse_genotype_file(filename: str, parlist: tuple = tuple()): + """ + Parse the provided genotype file into a usable pytho3 data structure. + """ + with open(filename, "r") as infile: + contents = infile.readlines() + + lines = tuple(line for line in contents if + ((not line.strip().startswith("#")) and + (not line.strip() == ""))) + labels = parse_genotype_labels( + [line for line in lines if line.startswith("@")]) + data_lines = tuple(line for line in lines if not line.startswith("@")) + header = parse_genotype_header(data_lines[0], parlist) + geno_obj = dict(labels + header) + markers = tuple( + [parse_genotype_marker(line, geno_obj, parlist) + for line in data_lines[1:]]) + chromosomes = tuple( + dict(chromosome) for chromosome in + build_genotype_chromosomes(geno_obj, markers)) + return {**geno_obj, "chromosomes": chromosomes} diff --git a/gn3/db/traits.py b/gn3/db/traits.py index 1031e44..f2673c8 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -1,5 +1,8 @@ """This class contains functions relating to trait data manipulation""" +import os from typing import Any, Dict, Union, Sequence +from gn3.settings import TMPDIR +from gn3.random import random_string from gn3.function_helpers import compose from gn3.db.datasets import retrieve_trait_dataset @@ -43,7 +46,7 @@ def update_sample_data(conn: Any, count: Union[int, str]): """Given the right parameters, update sample-data from the relevant table.""" - # pylint: disable=[R0913, R0914] + # pylint: disable=[R0913, R0914, C0103] STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s" PUBLISH_DATA_SQL: str = ("UPDATE PublishData SET value = %s " "WHERE StrainId = %s AND Id = %s") @@ -60,22 +63,22 @@ def update_sample_data(conn: Any, with conn.cursor() as cursor: # Update the Strains table cursor.execute(STRAIN_ID_SQL, (strain_name, strain_id)) - updated_strains: int = cursor.rowcount + updated_strains = cursor.rowcount # Update the PublishData table cursor.execute(PUBLISH_DATA_SQL, (None if value == "x" else value, strain_id, publish_data_id)) - updated_published_data: int = cursor.rowcount + updated_published_data = cursor.rowcount # Update the PublishSE table cursor.execute(PUBLISH_SE_SQL, (None if error == "x" else error, strain_id, publish_data_id)) - updated_se_data: int = cursor.rowcount + updated_se_data = cursor.rowcount # Update the NStrain table cursor.execute(N_STRAIN_SQL, (None if count == "x" else count, strain_id, publish_data_id)) - updated_n_strains: int = cursor.rowcount + updated_n_strains = cursor.rowcount return (updated_strains, updated_published_data, updated_se_data, updated_n_strains) @@ -223,7 +226,7 @@ def set_homologene_id_field_probeset(trait_info, conn): """ query = ( "SELECT HomologeneId FROM Homologene, Species, InbredSet" - " WHERE Homologene.GeneId = %(geneid)s AND InbredSet.Name = %(riset)s" + " WHERE Homologene.GeneId = %(geneid)s AND InbredSet.Name = %(group)s" " AND InbredSet.SpeciesId = Species.Id AND" " Species.TaxonomyId = Homologene.TaxonomyId") with conn.cursor() as cursor: @@ -231,7 +234,7 @@ def set_homologene_id_field_probeset(trait_info, conn): query, { k:v for k, v in trait_info.items() - if k in ["geneid", "riset"] + if k in ["geneid", "group"] }) res = cursor.fetchone() if res: @@ -419,7 +422,7 @@ def retrieve_trait_info( if trait_info["haveinfo"]: return { **trait_post_processing_functions_table[trait_dataset_type]( - {**trait_info, "riset": trait_dataset["riset"]}), + {**trait_info, "group": trait_dataset["group"]}), "db": {**trait["db"], **trait_dataset} } return trait_info @@ -442,18 +445,18 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any): query, {"trait_name": trait_info["trait_name"]}) return [dict(zip( - ["strain_name", "value", "se_error", "nstrain", "id"], row)) + ["sample_name", "value", "se_error", "nstrain", "id"], row)) for row in cursor.fetchall()] return [] -def retrieve_species_id(riset, conn: Any): +def retrieve_species_id(group, conn: Any): """ - Retrieve a species id given the RISet value + Retrieve a species id given the Group value """ with conn.cursor as cursor: cursor.execute( - "SELECT SpeciesId from InbredSet WHERE Name = %(riset)s", - {"riset": riset}) + "SELECT SpeciesId from InbredSet WHERE Name = %(group)s", + {"group": group}) return cursor.fetchone()[0] return None @@ -479,9 +482,9 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any): {"trait_name": trait_info["trait_name"], "dataset_name": trait_info["db"]["dataset_name"], "species_id": retrieve_species_id( - trait_info["db"]["riset"], conn)}) + trait_info["db"]["group"], conn)}) return [dict(zip( - ["strain_name", "value", "se_error", "id"], row)) + ["sample_name", "value", "se_error", "id"], row)) for row in cursor.fetchall()] return [] @@ -512,7 +515,7 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any): {"trait_name": trait_info["trait_name"], "dataset_id": trait_info["db"]["dataset_id"]}) return [dict(zip( - ["strain_name", "value", "se_error", "nstrain", "id"], row)) + ["sample_name", "value", "se_error", "nstrain", "id"], row)) for row in cursor.fetchall()] return [] @@ -545,7 +548,7 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any): "trait_name": trait_info["trait_name"], "dataset_id": trait_info["db"]["dataset_id"]}) return [dict(zip( - ["strain_name", "value", "se_error", "id"], row)) + ["sample_name", "value", "se_error", "id"], row)) for row in cursor.fetchall()] return [] @@ -574,29 +577,29 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any): {"trait_name": trait_info["trait_name"], "dataset_name": trait_info["db"]["dataset_name"]}) return [dict(zip( - ["strain_name", "value", "se_error", "id"], row)) + ["sample_name", "value", "se_error", "id"], row)) for row in cursor.fetchall()] return [] -def with_strainlist_data_setup(strainlist: Sequence[str]): +def with_samplelist_data_setup(samplelist: Sequence[str]): """ - Build function that computes the trait data from provided list of strains. + Build function that computes the trait data from provided list of samples. PARAMETERS - strainlist: (list) - A list of strain names + samplelist: (list) + A list of sample names RETURNS: Returns a function that given some data from the database, computes the - strain's value, variance and ndata values, only if the strain is present - in the provided `strainlist` variable. + sample's value, variance and ndata values, only if the sample is present + in the provided `samplelist` variable. """ def setup_fn(tdata): - if tdata["strain_name"] in strainlist: + if tdata["sample_name"] in samplelist: val = tdata["value"] if val is not None: return { - "strain_name": tdata["strain_name"], + "sample_name": tdata["sample_name"], "value": val, "variance": tdata["se_error"], "ndata": tdata.get("nstrain", None) @@ -604,19 +607,19 @@ def with_strainlist_data_setup(strainlist: Sequence[str]): return None return setup_fn -def without_strainlist_data_setup(): +def without_samplelist_data_setup(): """ Build function that computes the trait data. RETURNS: Returns a function that given some data from the database, computes the - strain's value, variance and ndata values. + sample's value, variance and ndata values. """ def setup_fn(tdata): val = tdata["value"] if val is not None: return { - "strain_name": tdata["strain_name"], + "sample_name": tdata["sample_name"], "value": val, "variance": tdata["se_error"], "ndata": tdata.get("nstrain", None) @@ -624,7 +627,7 @@ def without_strainlist_data_setup(): return None return setup_fn -def retrieve_trait_data(trait: dict, conn: Any, strainlist: Sequence[str] = tuple()): +def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tuple()): """ Retrieve trait data @@ -647,22 +650,27 @@ def retrieve_trait_data(trait: dict, conn: Any, strainlist: Sequence[str] = tupl if results: # do something with mysqlid mysqlid = results[0]["id"] - if strainlist: + if samplelist: data = [ item for item in - map(with_strainlist_data_setup(strainlist), results) + map(with_samplelist_data_setup(samplelist), results) if item is not None] else: data = [ item for item in - map(without_strainlist_data_setup(), results) + map(without_samplelist_data_setup(), results) if item is not None] return { "mysqlid": mysqlid, "data": dict(map( lambda x: ( - x["strain_name"], - {k:v for k, v in x.items() if x != "strain_name"}), + x["sample_name"], + {k:v for k, v in x.items() if x != "sample_name"}), data))} return {} + +def generate_traits_filename(base_path: str = TMPDIR): + """Generate a unique filename for use with generated traits files.""" + return "{}/traits_test_file_{}.txt".format( + os.path.abspath(base_path), random_string(10)) diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py new file mode 100644 index 0000000..adbfbc6 --- /dev/null +++ b/gn3/heatmaps.py @@ -0,0 +1,424 @@ +""" +This module will contain functions to be used in computation of the data used to +generate various kinds of heatmaps. +""" + +from functools import reduce +from typing import Any, Dict, Union, Sequence + +import numpy as np +import plotly.graph_objects as go # type: ignore +import plotly.figure_factory as ff # type: ignore +from plotly.subplots import make_subplots # type: ignore + +from gn3.settings import TMPDIR +from gn3.random import random_string +from gn3.computations.slink import slink +from gn3.computations.correlations2 import compute_correlation +from gn3.db.genotypes import ( + build_genotype_file, load_genotype_samples) +from gn3.db.traits import ( + retrieve_trait_data, retrieve_trait_info) +from gn3.computations.qtlreaper import ( + run_reaper, + generate_traits_file, + chromosome_sorter_key_fn, + parse_reaper_main_results, + organise_reaper_main_results) + +def export_trait_data( + trait_data: dict, samplelist: Sequence[str], dtype: str = "val", + var_exists: bool = False, n_exists: bool = False): + """ + Export data according to `samplelist`. Mostly used in calculating + correlations. + + DESCRIPTION: + Migrated from + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211 + + PARAMETERS + trait: (dict) + The dictionary of key-value pairs representing a trait + samplelist: (list) + A list of sample names + dtype: (str) + ... verify what this is ... + var_exists: (bool) + A flag indicating existence of variance + n_exists: (bool) + A flag indicating existence of ndata + """ + def __export_all_types(tdata, sample): + sample_data = [] + if tdata[sample]["value"]: + sample_data.append(tdata[sample]["value"]) + if var_exists: + if tdata[sample]["variance"]: + sample_data.append(tdata[sample]["variance"]) + else: + sample_data.append(None) + if n_exists: + if tdata[sample]["ndata"]: + sample_data.append(tdata[sample]["ndata"]) + else: + sample_data.append(None) + else: + if var_exists and n_exists: + sample_data += [None, None, None] + elif var_exists or n_exists: + sample_data += [None, None] + else: + sample_data.append(None) + + return tuple(sample_data) + + def __exporter(accumulator, sample): + # pylint: disable=[R0911] + if sample in trait_data["data"]: + if dtype == "val": + return accumulator + (trait_data["data"][sample]["value"], ) + if dtype == "var": + return accumulator + (trait_data["data"][sample]["variance"], ) + if dtype == "N": + return accumulator + (trait_data["data"][sample]["ndata"], ) + if dtype == "all": + return accumulator + __export_all_types(trait_data["data"], sample) + raise KeyError("Type `%s` is incorrect" % dtype) + if var_exists and n_exists: + return accumulator + (None, None, None) + if var_exists or n_exists: + return accumulator + (None, None) + return accumulator + (None,) + + return reduce(__exporter, samplelist, tuple()) + +def trait_display_name(trait: Dict): + """ + Given a trait, return a name to use to display the trait on a heatmap. + + DESCRIPTION + Migrated from + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L141-L157 + """ + if trait.get("db", None) and trait.get("trait_name", None): + if trait["db"]["dataset_type"] == "Temp": + desc = trait["description"] + if desc.find("PCA") >= 0: + return "%s::%s" % ( + trait["db"]["displayname"], + desc[desc.rindex(':')+1:].strip()) + return "%s::%s" % ( + trait["db"]["displayname"], + desc[:desc.index('entered')].strip()) + prefix = "%s::%s" % ( + trait["db"]["dataset_name"], trait["trait_name"]) + if trait["cellid"]: + return "%s::%s" % (prefix, trait["cellid"]) + return prefix + return trait["description"] + +def cluster_traits(traits_data_list: Sequence[Dict]): + """ + Clusters the trait values. + + DESCRIPTION + Attempts to replicate the clustering of the traits, as done at + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L162 + """ + def __compute_corr(tdata_i, tdata_j): + if tdata_i[0] == tdata_j[0]: + return 0.0 + corr_vals = compute_correlation(tdata_i[1], tdata_j[1]) + corr = corr_vals[0] + if (1 - corr) < 0: + return 0.0 + return 1 - corr + + def __cluster(tdata_i): + return tuple( + __compute_corr(tdata_i, tdata_j) + for tdata_j in enumerate(traits_data_list)) + + return tuple(__cluster(tdata_i) for tdata_i in enumerate(traits_data_list)) + +def get_loci_names( + organised: dict, + chromosome_names: Sequence[str]) -> Sequence[Sequence[str]]: + """ + Get the loci names organised by the same order as the `chromosome_names`. + """ + def __get_trait_loci(accumulator, trait): + chrs = tuple(trait["chromosomes"].keys()) + trait_loci = { + _chr: tuple( + locus["Locus"] + for locus in trait["chromosomes"][_chr]["loci"] + ) for _chr in chrs + } + return { + **accumulator, + **{ + _chr: tuple(sorted(set( + trait_loci[_chr] + accumulator.get(_chr, tuple())))) + for _chr in trait_loci.keys() + } + } + loci_dict: Dict[Union[str, int], Sequence[str]] = reduce( + __get_trait_loci, [v[1] for v in organised.items()], {}) + return tuple(loci_dict[_chr] for _chr in chromosome_names) + +def build_heatmap(traits_names, conn: Any): + """ + heatmap function + + DESCRIPTION + This function is an attempt to reproduce the initialisation at + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L46-L64 + and also the clustering and slink computations at + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L138-L165 + with the help of the `gn3.computations.heatmap.cluster_traits` function. + + It does not try to actually draw the heatmap image. + + PARAMETERS: + TODO: Elaborate on the parameters here... + """ + # pylint: disable=[R0914] + threshold = 0 # webqtlConfig.PUBLICTHRESH + traits = [ + retrieve_trait_info(threshold, fullname, conn) + for fullname in traits_names] + traits_data_list = [retrieve_trait_data(t, conn) for t in traits] + genotype_filename = build_genotype_file(traits[0]["group"]) + samples = load_genotype_samples(genotype_filename) + exported_traits_data_list = [ + export_trait_data(td, samples) for td in traits_data_list] + clustered = cluster_traits(exported_traits_data_list) + slinked = slink(clustered) + traits_order = compute_traits_order(slinked) + samples_and_values = retrieve_samples_and_values( + traits_order, samples, exported_traits_data_list) + traits_filename = "{}/traits_test_file_{}.txt".format( + TMPDIR, random_string(10)) + generate_traits_file( + samples_and_values[0][1], + [t[2] for t in samples_and_values], + traits_filename) + + main_output, _permutations_output = run_reaper( + genotype_filename, traits_filename, separate_nperm_output=True) + + qtlresults = parse_reaper_main_results(main_output) + organised = organise_reaper_main_results(qtlresults) + + traits_ids = [# sort numerically, but retain the ids as strings + str(i) for i in sorted({int(row["ID"]) for row in qtlresults})] + chromosome_names = sorted( + {row["Chr"] for row in qtlresults}, key=chromosome_sorter_key_fn) + ordered_traits_names = dict( + zip(traits_ids, + [traits[idx]["trait_fullname"] for idx in traits_order])) + + return generate_clustered_heatmap( + process_traits_data_for_heatmap( + organised, traits_ids, chromosome_names), + clustered, + "single_heatmap_{}".format(random_string(10)), + y_axis=tuple( + ordered_traits_names[traits_ids[order]] + for order in traits_order), + y_label="Traits", + x_axis=chromosome_names, + x_label="Chromosomes", + loci_names=get_loci_names(organised, chromosome_names)) + +def compute_traits_order(slink_data, neworder: tuple = tuple()): + """ + Compute the order of the traits for clustering from `slink_data`. + + This function tries to reproduce the creation and update of the `neworder` + variable in + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L120 + and in the `web.webqtl.heatmap.Heatmap.draw` function in GN1 + """ + def __order_maker(norder, slnk_dt): + if isinstance(slnk_dt[0], int) and isinstance(slnk_dt[1], int): + return norder + (slnk_dt[0], slnk_dt[1]) + + if isinstance(slnk_dt[0], int): + return __order_maker((norder + (slnk_dt[0], )), slnk_dt[1]) + + if isinstance(slnk_dt[1], int): + return __order_maker(norder, slnk_dt[0]) + (slnk_dt[1], ) + + return __order_maker(__order_maker(norder, slnk_dt[0]), slnk_dt[1]) + + return __order_maker(neworder, slink_data) + +def retrieve_samples_and_values(orders, samplelist, traits_data_list): + """ + Get the samples and their corresponding values from `samplelist` and + `traits_data_list`. + + This migrates the code in + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L215-221 + """ + # This feels nasty! There's a lot of mutation of values here, that might + # indicate something untoward in the design of this function and its + # dependents ==> Review + samples = [] + values = [] + rets = [] + for order in orders: + temp_val = traits_data_list[order] + for i, sample in enumerate(samplelist): + if temp_val[i] is not None: + samples.append(sample) + values.append(temp_val[i]) + rets.append([order, samples[:], values[:]]) + samples = [] + values = [] + + return rets + +def nearest_marker_finder(genotype): + """ + Returns a function to be used with `genotype` to compute the nearest marker + to the trait passed to the returned function. + + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L425-434 + """ + def __compute_distances(chromo, trait): + loci = chromo.get("loci", None) + if not loci: + return None + return tuple( + { + "name": locus["name"], + "distance": abs(locus["Mb"] - trait["mb"]) + } for locus in loci) + + def __finder(trait): + _chrs = tuple( + _chr for _chr in genotype["chromosomes"] + if str(_chr["name"]) == str(trait["chr"])) + if len(_chrs) == 0: + return None + distances = tuple( + distance for dists in + filter( + lambda x: x is not None, + (__compute_distances(_chr, trait) for _chr in _chrs)) + for distance in dists) + nearest = min(distances, key=lambda d: d["distance"]) + return nearest["name"] + return __finder + +def get_nearest_marker(traits_list, genotype): + """ + Retrieves the nearest marker for each of the traits in the list. + + DESCRIPTION: + This migrates the code in + https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/heatmap/Heatmap.py#L419-L438 + """ + if not genotype["Mbmap"]: + return [None] * len(traits_list) + + marker_finder = nearest_marker_finder(genotype) + return [marker_finder(trait) for trait in traits_list] + +def get_lrs_from_chr(trait, chr_name): + """ + Retrieve the LRS values for a specific chromosome in the given trait. + """ + chromosome = trait["chromosomes"].get(chr_name) + if chromosome: + return [ + locus["LRS"] for locus in + sorted(chromosome["loci"], key=lambda loc: loc["Locus"])] + return [None] + +def process_traits_data_for_heatmap(data, trait_names, chromosome_names): + """ + Process the traits data in a format useful for generating heatmap diagrams. + """ + hdata = [ + [get_lrs_from_chr(data[trait], chr_name) for trait in trait_names] + for chr_name in chromosome_names] + return hdata + +def generate_clustered_heatmap( + data, clustering_data, image_filename_prefix, x_axis=None, + x_label: str = "", y_axis=None, y_label: str = "", + loci_names: Sequence[Sequence[str]] = tuple(), + output_dir: str = TMPDIR, + colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))): + """ + Generate a dendrogram, and heatmaps for each chromosome, and put them all + into one plot. + """ + # pylint: disable=[R0913, R0914] + num_cols = 1 + len(x_axis) + fig = make_subplots( + rows=1, + cols=num_cols, + shared_yaxes="rows", + horizontal_spacing=0.001, + subplot_titles=["distance"] + x_axis, + figure=ff.create_dendrogram( + np.array(clustering_data), orientation="right", labels=y_axis)) + hms = [go.Heatmap( + name=chromo, + x=loci, + y=y_axis, + z=data_array, + showscale=False) + for chromo, data_array, loci + in zip(x_axis, data, loci_names)] + for i, heatmap in enumerate(hms): + fig.add_trace(heatmap, row=1, col=(i + 2)) + + fig.update_layout( + { + "width": 1500, + "height": 800, + "xaxis": { + "mirror": False, + "showgrid": True, + "title": x_label + }, + "yaxis": { + "title": y_label + } + }) + + x_axes_layouts = { + "xaxis{}".format(i+1 if i > 0 else ""): { + "mirror": False, + "showticklabels": i == 0, + "ticks": "outside" if i == 0 else "" + } + for i in range(num_cols)} + + fig.update_layout( + { + "width": 4000, + "height": 800, + "yaxis": { + "mirror": False, + "ticks": "" + }, + **x_axes_layouts}) + fig.update_traces( + showlegend=False, + colorscale=colorscale, + selector={"type": "heatmap"}) + fig.update_traces( + showlegend=True, + showscale=True, + selector={"name": x_axis[-1]}) + image_filename = "{}/{}.html".format(output_dir, image_filename_prefix) + fig.write_html(image_filename) + return image_filename, fig diff --git a/gn3/heatmaps/heatmaps.py b/gn3/heatmaps/heatmaps.py deleted file mode 100644 index 3bf7917..0000000 --- a/gn3/heatmaps/heatmaps.py +++ /dev/null @@ -1,54 +0,0 @@ -import random -import plotly.express as px - -#### Remove these #### - -heatmap_dir = "heatmap_images" - -def generate_random_data(data_stop: float = 2, width: int = 10, height: int = 30): - """ - This is mostly a utility function to be used to generate random data, useful - for development of the heatmap generation code, without access to the actual - database data. - """ - return [[random.uniform(0,data_stop) for i in range(0, width)] - for j in range(0, height)] - -def heatmap_x_axis_names(): - return [ - "UCLA_BXDBXH_CARTILAGE_V2::ILM103710672", - "UCLA_BXDBXH_CARTILAGE_V2::ILM2260338", - "UCLA_BXDBXH_CARTILAGE_V2::ILM3140576", - "UCLA_BXDBXH_CARTILAGE_V2::ILM5670577", - "UCLA_BXDBXH_CARTILAGE_V2::ILM2070121", - "UCLA_BXDBXH_CARTILAGE_V2::ILM103990541", - "UCLA_BXDBXH_CARTILAGE_V2::ILM1190722", - "UCLA_BXDBXH_CARTILAGE_V2::ILM6590722", - "UCLA_BXDBXH_CARTILAGE_V2::ILM4200064", - "UCLA_BXDBXH_CARTILAGE_V2::ILM3140463"] -#### END: Remove these #### - -# Grey + Blue + Red -def generate_heatmap(): - rows = 20 - data = generate_random_data(height=rows) - y = (["%s"%x for x in range(1, rows+1)][:-1] + ["X"]) #replace last item with x for now - fig = px.imshow( - data, - x=heatmap_x_axis_names(), - y=y, - width=500) - fig.update_traces(xtype="array") - fig.update_traces(ytype="array") - # fig.update_traces(xgap=10) - fig.update_xaxes( - visible=True, - title_text="Traits", - title_font_size=16) - fig.update_layout( - coloraxis_colorscale=[ - [0.0, '#3B3B3B'], [0.4999999999999999, '#ABABAB'], - [0.5, '#F5DE11'], [1.0, '#FF0D00']]) - - fig.write_html("%s/%s"%(heatmap_dir, "test_image.html")) - return fig diff --git a/gn3/random.py b/gn3/random.py new file mode 100644 index 0000000..f0ba574 --- /dev/null +++ b/gn3/random.py @@ -0,0 +1,11 @@ +""" +Functions to generate complex random data. +""" +import random +import string + +def random_string(length): + """Generate a random string of length `length`.""" + return "".join( + random.choices( + string.ascii_letters + string.digits, k=length)) diff --git a/gn3/settings.py b/gn3/settings.py index f4866d5..150d96d 100644 --- a/gn3/settings.py +++ b/gn3/settings.py @@ -24,3 +24,24 @@ GN2_BASE_URL = "http://www.genenetwork.org/" # biweight script BIWEIGHT_RSCRIPT = "~/genenetwork3/scripts/calculate_biweight.R" + +# wgcna script +WGCNA_RSCRIPT = "wgcna_analysis.R" +# qtlreaper command +REAPER_COMMAND = "{}/bin/qtlreaper".format(os.environ.get("GUIX_ENVIRONMENT")) + +# genotype files +GENOTYPE_FILES = os.environ.get( + "GENOTYPE_FILES", "{}/genotype_files/genotype".format(os.environ.get("HOME"))) + +# CROSS-ORIGIN SETUP +CORS_ORIGINS = [ + "http://localhost:*", + "http://127.0.0.1:*" +] + +CORS_HEADERS = [ + "Content-Type", + "Authorization", + "Access-Control-Allow-Credentials" +] |