author      Arun Isaac                                  2023-12-29 18:55:37 +0000
committer   Arun Isaac                                  2023-12-29 19:01:46 +0000
commit      204a308be0f741726b9a620d88fbc22b22124c81 (patch)
tree        b3cf66906674020b530c844c2bb4982c8a0e2d39 /gn2/wqflask/correlation
parent      83062c75442160427b50420161bfcae2c5c34c84 (diff)
download    genenetwork2-204a308be0f741726b9a620d88fbc22b22124c81.tar.gz
Namespace all modules under gn2.
We move all modules under a gn2 directory. This is necessary for correct
packaging, and for deployment as a Guix service.
Diffstat (limited to 'gn2/wqflask/correlation')

-rw-r--r--   gn2/wqflask/correlation/__init__.py                 0
-rw-r--r--   gn2/wqflask/correlation/corr_scatter_plot.py      158
-rw-r--r--   gn2/wqflask/correlation/correlation_functions.py   68
-rw-r--r--   gn2/wqflask/correlation/correlation_gn3_api.py    262
-rw-r--r--   gn2/wqflask/correlation/exceptions.py              16
-rw-r--r--   gn2/wqflask/correlation/pre_computes.py           178
-rw-r--r--   gn2/wqflask/correlation/rust_correlation.py       408
-rw-r--r--   gn2/wqflask/correlation/show_corr_results.py      406

8 files changed, 1496 insertions(+), 0 deletions(-)
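The diff body follows. As orientation, here is a minimal sketch of how the new package's entry point is driven; compute_correlation and the start_vars keys are taken from correlation_gn3_api.py below, while the concrete trait, dataset, and sample values are purely illustrative:

    import json
    from gn2.wqflask.correlation.correlation_gn3_api import compute_correlation

    # All values below are made up for illustration; the keys mirror the
    # form fields read by compute_correlation() and its helpers.
    start_vars = {
        "trait_id": "1427571_at",            # hypothetical trait name
        "dataset": "HC_M2_0606_P",           # hypothetical primary dataset
        "corr_dataset": "HC_M2_0606_P",      # dataset to correlate against
        "corr_type": "sample",               # "sample", "tissue" or "lit"
        "corr_sample_method": "pearson",     # "pearson", "spearman" or "bicor"
        "corr_return_results": 100,          # top-N results to keep
        "corr_samples_group": "samples_primary",
        "sample_vals": json.dumps({"BXD1": "7.2", "BXD2": "x"}),  # "x" = missing
    }

    results = compute_correlation(start_vars, compute_all=True)

Internally this delegates to compute_correlation_rust, which dispatches on corr_type and, when compute_all is set, augments the main results with the other applicable correlation types.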
diff --git a/gn2/wqflask/correlation/__init__.py b/gn2/wqflask/correlation/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/gn2/wqflask/correlation/__init__.py diff --git a/gn2/wqflask/correlation/corr_scatter_plot.py b/gn2/wqflask/correlation/corr_scatter_plot.py new file mode 100644 index 00000000..59e9ac4a --- /dev/null +++ b/gn2/wqflask/correlation/corr_scatter_plot.py @@ -0,0 +1,158 @@ +import json +import math + +from redis import Redis +Redis = Redis() + +from gn2.base.trait import create_trait, retrieve_sample_data +from gn2.base import data_set, webqtlCaseData +from gn2.utility import corr_result_helpers +from gn2.wqflask.oauth2.collections import num_collections + +from scipy import stats +import numpy as np + +import logging +logger = logging.getLogger(__name__) + +class CorrScatterPlot: + """Page that displays a correlation scatterplot with a line fitted to it""" + + def __init__(self, params): + if "Temp" in params['dataset_1']: + self.dataset_1 = data_set.create_dataset( + dataset_name="Temp", dataset_type="Temp", group_name=params['dataset_1'].split("_")[1]) + else: + self.dataset_1 = data_set.create_dataset(params['dataset_1']) + if "Temp" in params['dataset_2']: + self.dataset_2 = data_set.create_dataset( + dataset_name="Temp", dataset_type="Temp", group_name=params['dataset_2'].split("_")[1]) + else: + self.dataset_2 = data_set.create_dataset(params['dataset_2']) + + self.trait_1 = create_trait( + name=params['trait_1'], dataset=self.dataset_1) + self.trait_2 = create_trait( + name=params['trait_2'], dataset=self.dataset_2) + + self.method = params['method'] + + primary_samples = self.dataset_1.group.samplelist + if self.dataset_1.group.parlist != None: + primary_samples += self.dataset_1.group.parlist + if self.dataset_1.group.f1list != None: + primary_samples += self.dataset_1.group.f1list + + if 'dataid' in params: + trait_data_dict = json.loads(Redis.get(params['dataid'])) + trait_data = {key:webqtlCaseData.webqtlCaseData(key, float(trait_data_dict[key])) for (key, value) in trait_data_dict.items() if trait_data_dict[key] != "x"} + trait_1_data = trait_data + trait_2_data = self.trait_2.data + # Check if the cached data should be used for the second trait instead + if 'cached_trait' in params: + if params['cached_trait'] == 'trait_2': + trait_2_data = trait_data + trait_1_data = self.trait_1.data + samples_1, samples_2, num_overlap = corr_result_helpers.normalize_values_with_samples( + trait_1_data, trait_2_data) + else: + samples_1, samples_2, num_overlap = corr_result_helpers.normalize_values_with_samples( + self.trait_1.data, self.trait_2.data) + + self.data = [] + self.indIDs = list(samples_1.keys()) + vals_1 = [] + for sample in list(samples_1.keys()): + vals_1.append(samples_1[sample].value) + self.data.append(vals_1) + vals_2 = [] + for sample in list(samples_2.keys()): + vals_2.append(samples_2[sample].value) + self.data.append(vals_2) + + slope, intercept, r_value, p_value, std_err = stats.linregress( + vals_1, vals_2) + + if slope < 0.001: + slope_string = '%.3E' % slope + else: + slope_string = '%.3f' % slope + + x_buffer = (max(vals_1) - min(vals_1)) * 0.1 + y_buffer = (max(vals_2) - min(vals_2)) * 0.1 + + x_range = [min(vals_1) - x_buffer, max(vals_1) + x_buffer] + y_range = [min(vals_2) - y_buffer, max(vals_2) + y_buffer] + + intercept_coords = get_intercept_coords( + slope, intercept, x_range, y_range) + + rx = stats.rankdata(vals_1) + ry = stats.rankdata(vals_2) + self.rdata = [] + self.rdata.append(rx.tolist()) + 
self.rdata.append(ry.tolist()) + srslope, srintercept, srr_value, srp_value, srstd_err = stats.linregress( + rx, ry) + + if srslope < 0.001: + srslope_string = '%.3E' % srslope + else: + srslope_string = '%.3f' % srslope + + x_buffer = (max(rx) - min(rx)) * 0.1 + y_buffer = (max(ry) - min(ry)) * 0.1 + + sr_range = [min(rx) - x_buffer, max(rx) + x_buffer] + + sr_intercept_coords = get_intercept_coords( + srslope, srintercept, sr_range, sr_range) + + self.collections_exist = "False" + if num_collections() > 0: + self.collections_exist = "True" + + self.js_data = dict( + data=self.data, + rdata=self.rdata, + indIDs=self.indIDs, + trait_1=self.trait_1.dataset.name + ": " + str(self.trait_1.name), + trait_2=self.trait_2.dataset.name + ": " + str(self.trait_2.name), + samples_1=samples_1, + samples_2=samples_2, + num_overlap=num_overlap, + vals_1=vals_1, + vals_2=vals_2, + x_range=x_range, + y_range=y_range, + sr_range=sr_range, + intercept_coords=intercept_coords, + sr_intercept_coords=sr_intercept_coords, + + slope=slope, + slope_string=slope_string, + intercept=intercept, + r_value=r_value, + p_value=p_value, + + srslope=srslope, + srslope_string=srslope_string, + srintercept=srintercept, + srr_value=srr_value, + srp_value=srp_value + ) + self.jsdata = self.js_data + + +def get_intercept_coords(slope, intercept, x_range, y_range): + intercept_coords = [] + + y1 = slope * x_range[0] + intercept + y2 = slope * x_range[1] + intercept + x1 = (y1 - intercept) / slope + x2 = (y2 - intercept) / slope + + intercept_coords.append([x1, y1]) + intercept_coords.append([x2, y2]) + + return intercept_coords diff --git a/gn2/wqflask/correlation/correlation_functions.py b/gn2/wqflask/correlation/correlation_functions.py new file mode 100644 index 00000000..911f6dc8 --- /dev/null +++ b/gn2/wqflask/correlation/correlation_functions.py @@ -0,0 +1,68 @@ +# Copyright (C) University of Tennessee Health Science Center, Memphis, TN. +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License +# as published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +# See the GNU Affero General Public License for more details. +# +# This program is available from Source Forge: at GeneNetwork Project +# (sourceforge.net/projects/genenetwork/). +# +# Contact Drs. Robert W. 
Williams and Xiaodong Zhou (2010) +# at rwilliams@uthsc.edu and xzhou15@uthsc.edu +# +# +# +# This module is used by GeneNetwork project (www.genenetwork.org) +# +# Created by GeneNetwork Core Team 2010/08/10 + + +from gn2.base.mrna_assay_tissue_data import MrnaAssayTissueData +from gn3.computations.correlations import compute_corr_coeff_p_value +from gn2.wqflask.database import database_connection +from gn2.utility.tools import get_setting + +##################################################################################### +# Input: primaryValue(list): one list of expression values of one probeSet, +# targetValue(list): one list of expression values of one probeSet, +# method(string): indicate correlation method ('pearson' or 'spearman') +# Output: corr_result(list): first item is Correlation Value, second item is tissue number, +# third item is PValue +# Function: get correlation value,Tissue quantity ,p value result by using R; +# Note : This function is special case since both primaryValue and targetValue are from +# the same dataset. So the length of these two parameters is the same. They are pairs. +# Also, in the datatable TissueProbeSetData, all Tissue values are loaded based on +# the same tissue order +##################################################################################### + + +def cal_zero_order_corr_for_tiss(primary_values, target_values, method="pearson"): + """function use calls gn3 to compute corr,p_val""" + + (corr_coeff, p_val) = compute_corr_coeff_p_value( + primary_values=primary_values, target_values=target_values, corr_method=method) + + return (corr_coeff, len(primary_values), p_val) + +######################################################################################################## +# input: cursor, symbolList (list), dataIdDict(Dict): key is symbol +# output: SymbolValuePairDict(dictionary):one dictionary of Symbol and Value Pair. +# key is symbol, value is one list of expression values of one probeSet. +# function: wrapper function for getSymbolValuePairDict function +# build gene symbol list if necessary, cut it into small lists if necessary, +# then call getSymbolValuePairDict function and merge the results. 
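# A hedged usage sketch (symbols and values are hypothetical):
#   get_trait_symbol_and_tissue_values(symbol_list=["Shh", "Trp53"])
# could return {"shh": [7.18, 7.29, ...], "trp53": [9.04, 8.91, ...]},
# i.e. one list of tissue expression values per lower-cased gene symbol.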
+######################################################################################################## + + +def get_trait_symbol_and_tissue_values(symbol_list=None): + with database_connection(get_setting("SQL_URI")) as conn: + tissue_data = MrnaAssayTissueData(gene_symbols=symbol_list, conn=conn) + if len(tissue_data.gene_symbols) > 0: + results = tissue_data.get_symbol_values_pairs() + return results diff --git a/gn2/wqflask/correlation/correlation_gn3_api.py b/gn2/wqflask/correlation/correlation_gn3_api.py new file mode 100644 index 00000000..76c75ec3 --- /dev/null +++ b/gn2/wqflask/correlation/correlation_gn3_api.py @@ -0,0 +1,262 @@ +"""module that calls the gn3 api's to do the correlation """ +import json +import time +from functools import wraps + +from gn2.utility.tools import SQL_URI + +from gn2.wqflask.correlation import correlation_functions +from gn2.base import data_set + +from gn2.base.trait import create_trait +from gn2.base.trait import retrieve_sample_data + +from gn3.db_utils import database_connection +from gn3.commands import run_sample_corr_cmd +from gn3.computations.correlations import map_shared_keys_to_values +from gn3.computations.correlations import compute_all_lit_correlation +from gn3.computations.correlations import compute_tissue_correlation +from gn3.computations.correlations import fast_compute_all_sample_correlation + + +def create_target_this_trait(start_vars): + """this function creates the required trait and target dataset for correlation""" + + if start_vars['dataset'] == "Temp": + this_dataset = data_set.create_dataset( + dataset_name="Temp", dataset_type="Temp", group_name=start_vars['group']) + else: + this_dataset = data_set.create_dataset( + dataset_name=start_vars['dataset']) + target_dataset = data_set.create_dataset( + dataset_name=start_vars['corr_dataset']) + this_trait = create_trait(dataset=this_dataset, + name=start_vars['trait_id']) + sample_data = () + return (this_dataset, this_trait, target_dataset, sample_data) + + +def test_process_data(this_trait, dataset, start_vars): + """test function for bxd,all and other sample data""" + + corr_samples_group = start_vars["corr_samples_group"] + + primary_samples = dataset.group.samplelist + if dataset.group.parlist != None: + primary_samples += dataset.group.parlist + if dataset.group.f1list != None: + primary_samples += dataset.group.f1list + + # If either BXD/whatever Only or All Samples, append all of that group's samplelist + if corr_samples_group != 'samples_other': + sample_data = process_samples(start_vars, primary_samples) + + # If either Non-BXD/whatever or All Samples, get all samples from this_trait.data and + # exclude the primary samples (because they would have been added in the previous + # if statement if the user selected All Samples) + if corr_samples_group != 'samples_primary': + if corr_samples_group == 'samples_other': + primary_samples = [x for x in primary_samples if x not in ( + dataset.group.parlist + dataset.group.f1list)] + sample_data = process_samples(start_vars, list( + this_trait.data.keys()), primary_samples) + + return sample_data + + +def process_samples(start_vars, sample_names=[], excluded_samples=[]): + """code to fetch correct samples""" + sample_data = {} + sample_vals_dict = json.loads(start_vars["sample_vals"]) + if sample_names: + for sample in sample_names: + if sample in sample_vals_dict and sample not in excluded_samples: + val = sample_vals_dict[sample] + if not val.strip().lower() == "x": + sample_data[str(sample)] = float(val) + + else: + for 
sample in sample_vals_dict.keys(): + if sample not in excluded_samples: + val = sample_vals_dict[sample] + if not val.strip().lower() == "x": + sample_data[str(sample)] = float(val) + return sample_data + + +def merge_correlation_results(correlation_results, target_correlation_results): + + corr_dict = {} + + for trait_dict in target_correlation_results: + for trait_name, values in trait_dict.items(): + + corr_dict[trait_name] = values + for trait_dict in correlation_results: + for trait_name, values in trait_dict.items(): + + if corr_dict.get(trait_name): + + trait_dict[trait_name].update(corr_dict.get(trait_name)) + + return correlation_results + + +def sample_for_trait_lists(corr_results, target_dataset, + this_trait, this_dataset, start_vars): + """interface function for correlation on top results""" + + (this_trait_data, target_dataset) = fetch_sample_data( + start_vars, this_trait, this_dataset, target_dataset) + correlation_results = run_sample_corr_cmd( + corr_method="pearson", this_trait=this_trait_data, + target_dataset=target_dataset) + + return correlation_results + + +def tissue_for_trait_lists(corr_results, this_dataset, this_trait): + """interface function for doing tissue corr_results on trait_list""" + trait_lists = dict([(list(corr_result)[0], True) + for corr_result in corr_results]) + # trait_lists = {list(corr_results)[0]: 1 for corr_result in corr_results} + traits_symbol_dict = this_dataset.retrieve_genes("Symbol") + traits_symbol_dict = dict({trait_name: symbol for ( + trait_name, symbol) in traits_symbol_dict.items() if trait_lists.get(trait_name)}) + tissue_input = get_tissue_correlation_input( + this_trait, traits_symbol_dict) + + if tissue_input is not None: + (primary_tissue_data, target_tissue_data) = tissue_input + corr_results = compute_tissue_correlation( + primary_tissue_dict=primary_tissue_data, + target_tissues_data=target_tissue_data, + corr_method="pearson") + return corr_results + + +def lit_for_trait_list(corr_results, this_dataset, this_trait): + (this_trait_geneid, geneid_dict, species) = do_lit_correlation( + this_trait, this_dataset) + + # trait_lists = {list(corr_results)[0]: 1 for corr_result in corr_results} + trait_lists = dict([(list(corr_result)[0], True) + for corr_result in corr_results]) + + geneid_dict = {trait_name: geneid for (trait_name, geneid) in geneid_dict.items() if + trait_lists.get(trait_name)} + + with database_connection(SQL_URI) as conn: + correlation_results = compute_all_lit_correlation( + conn=conn, trait_lists=list(geneid_dict.items()), + species=species, gene_id=this_trait_geneid) + + return correlation_results + + +def fetch_sample_data(start_vars, this_trait, this_dataset, target_dataset): + + corr_samples_group = start_vars["corr_samples_group"] + if corr_samples_group == "samples_primary": + sample_data = process_samples( + start_vars, this_dataset.group.samplelist) + + elif corr_samples_group == "samples_other": + sample_data = process_samples( + start_vars, excluded_samples=this_dataset.group.samplelist) + + else: + sample_data = process_samples(start_vars, + this_dataset.group.all_samples_ordered()) + + target_dataset.get_trait_data(list(sample_data.keys())) + this_trait = retrieve_sample_data(this_trait, this_dataset) + this_trait_data = { + "trait_sample_data": sample_data, + "trait_id": start_vars["trait_id"] + } + results = map_shared_keys_to_values( + target_dataset.samplelist, target_dataset.trait_data) + + return (this_trait_data, results) + + +def compute_correlation(start_vars, method="pearson", 
compute_all=False): + """Compute correlations using GN3 API + + Keyword arguments: + start_vars -- All input from form; includes things like the trait/dataset names + method -- Correlation method to be used (pearson, spearman, or bicor) + compute_all -- Include sample, tissue, and literature correlations (when applicable) + """ + from gn2.wqflask.correlation.rust_correlation import compute_correlation_rust + + corr_type = start_vars['corr_type'] + method = start_vars['corr_sample_method'] + corr_return_results = int(start_vars.get("corr_return_results", 100)) + return compute_correlation_rust( + start_vars, corr_type, method, corr_return_results, compute_all) + + +def compute_corr_for_top_results(start_vars, + correlation_results, + this_trait, + this_dataset, + target_dataset, + corr_type): + if corr_type != "tissue" and this_dataset.type == "ProbeSet" and target_dataset.type == "ProbeSet": + tissue_result = tissue_for_trait_lists( + correlation_results, this_dataset, this_trait) + + if tissue_result: + correlation_results = merge_correlation_results( + correlation_results, tissue_result) + + if corr_type != "lit" and this_dataset.type == "ProbeSet" and target_dataset.type == "ProbeSet": + lit_result = lit_for_trait_list( + correlation_results, this_dataset, this_trait) + + if lit_result: + correlation_results = merge_correlation_results( + correlation_results, lit_result) + + if corr_type != "sample" and this_dataset.type == "ProbeSet" and target_dataset.type == "ProbeSet": + sample_result = sample_for_trait_lists( + correlation_results, target_dataset, this_trait, this_dataset, start_vars) + if sample_result: + correlation_results = merge_correlation_results( + correlation_results, sample_result) + + return correlation_results + + +def do_lit_correlation(this_trait, this_dataset): + """function for fetching lit inputs""" + geneid_dict = this_dataset.retrieve_genes("GeneId") + species = this_dataset.group.species + if species: + species = species.lower() + trait_geneid = this_trait.geneid + return (trait_geneid, geneid_dict, species) + + +def get_tissue_correlation_input(this_trait, trait_symbol_dict): + """Gets tissue expression values for the primary trait and target tissues values""" + primary_trait_tissue_vals_dict = correlation_functions.get_trait_symbol_and_tissue_values( + symbol_list=[this_trait.symbol]) + if this_trait.symbol and this_trait.symbol.lower() in primary_trait_tissue_vals_dict: + + primary_trait_tissue_values = primary_trait_tissue_vals_dict[this_trait.symbol.lower( + )] + corr_result_tissue_vals_dict = correlation_functions.get_trait_symbol_and_tissue_values( + symbol_list=list(trait_symbol_dict.values())) + primary_tissue_data = { + "this_id": this_trait.name, + "tissue_values": primary_trait_tissue_values + + } + target_tissue_data = { + "trait_symbol_dict": trait_symbol_dict, + "symbol_tissue_vals_dict": corr_result_tissue_vals_dict + } + return (primary_tissue_data, target_tissue_data) diff --git a/gn2/wqflask/correlation/exceptions.py b/gn2/wqflask/correlation/exceptions.py new file mode 100644 index 00000000..f4e2b72b --- /dev/null +++ b/gn2/wqflask/correlation/exceptions.py @@ -0,0 +1,16 @@ +"""Correlation-Specific Exceptions""" + +class WrongCorrelationType(Exception): + """Raised when a correlation is requested for incompatible datasets.""" + + def __init__(self, trait, target_dataset, corr_method): + corr_method = { + "lit": "Literature", + "tissue": "Tissue" + }[corr_method] + message = ( + f"It is not possible to compute the '{corr_method}' correlations 
" + f"between trait '{trait.name}' and the data in the " + f"'{target_dataset.fullname}' dataset. " + "Please try again after selecting another type of correlation.") + super().__init__(message) diff --git a/gn2/wqflask/correlation/pre_computes.py b/gn2/wqflask/correlation/pre_computes.py new file mode 100644 index 00000000..4bd888ad --- /dev/null +++ b/gn2/wqflask/correlation/pre_computes.py @@ -0,0 +1,178 @@ +import csv +import json +import os +import hashlib +import datetime + +import lmdb +import pickle +from pathlib import Path + +from gn2.base.data_set import query_table_timestamp +from gn2.base.webqtlConfig import TEXTDIR +from gn2.base.webqtlConfig import TMPDIR + +from json.decoder import JSONDecodeError + +def cache_trait_metadata(dataset_name, data): + + + try: + with lmdb.open(os.path.join(TMPDIR,f"metadata_{dataset_name}"),map_size=20971520) as env: + with env.begin(write=True) as txn: + data_bytes = pickle.dumps(data) + txn.put(f"{dataset_name}".encode(), data_bytes) + current_date = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + txn.put(b"creation_date", current_date.encode()) + return "success" + + except lmdb.Error as error: + pass + +def read_trait_metadata(dataset_name): + try: + with lmdb.open(os.path.join(TMPDIR,f"metadata_{dataset_name}"), + readonly=True, lock=False) as env: + with env.begin() as txn: + db_name = txn.get(dataset_name.encode()) + return (pickle.loads(db_name) if db_name else {}) + except lmdb.Error as error: + return {} + + +def fetch_all_cached_metadata(dataset_name): + """in a gvein dataset fetch all the traits metadata""" + file_name = generate_filename(dataset_name, suffix="metadata") + + file_path = Path(TMPDIR, file_name) + + try: + with open(file_path, "r+") as file_handler: + dataset_metadata = json.load(file_handler) + + return (file_path, dataset_metadata) + + except FileNotFoundError: + pass + + except JSONDecodeError: + file_path.unlink() + + file_path.touch(exist_ok=True) + + return (file_path, {}) + + +def cache_new_traits_metadata(dataset_metadata: dict, new_traits_metadata, file_path: str): + """function to cache the new traits metadata""" + + if (dataset_metadata == {} and new_traits_metadata == {}): + return + + dataset_metadata.update(new_traits_metadata) + + with open(file_path, "w+") as file_handler: + json.dump(dataset_metadata, file_handler) + + +def generate_filename(*args, suffix="", file_ext="json"): + """given a list of args generate a unique filename""" + + string_unicode = f"{*args,}".encode() + return f"{hashlib.md5(string_unicode).hexdigest()}_{suffix}.{file_ext}" + + + + +def fetch_text_file(dataset_name, conn, text_dir=TMPDIR): + """fetch textfiles with strain vals if exists""" + + def __file_scanner__(text_dir, target_file): + for file in os.listdir(text_dir): + if file.startswith(f"ProbeSetFreezeId_{target_file}_"): + return os.path.join(text_dir, file) + + with conn.cursor() as cursor: + cursor.execute( + 'SELECT Id, FullName FROM ProbeSetFreeze WHERE Name = %s', (dataset_name,)) + results = cursor.fetchone() + if results: + try: + # checks first for recently generated textfiles if not use gn1 datamatrix + + return __file_scanner__(text_dir, results[0]) or __file_scanner__(TEXTDIR, results[0]) + + except Exception: + pass + + +def read_text_file(sample_dict, file_path): + + def __fetch_id_positions__(all_ids, target_ids): + _vals = [] + _posit = [0] # alternative for parsing + + for (idx, strain) in enumerate(all_ids, 1): + if strain in target_ids: + _vals.append(target_ids[strain]) + _posit.append(idx) + + 
return (_posit, _vals) + + with open(file_path) as csv_file: + csv_reader = csv.reader(csv_file, delimiter=',') + _posit, sample_vals = __fetch_id_positions__( + next(csv_reader)[1:], sample_dict) + return (sample_vals, [[line[i] for i in _posit] for line in csv_reader]) + + +def write_db_to_textfile(db_name, conn, text_dir=TMPDIR): + + def __sanitise_filename__(filename): + ttable = str.maketrans({" ": "_", "/": "_", "\\": "_"}) + return str.translate(filename, ttable) + + def __generate_file_name__(db_name): + # todo add expiry time and checker + with conn.cursor() as cursor: + cursor.execute( + 'SELECT Id, FullName FROM ProbeSetFreeze WHERE Name = %s', (db_name,)) + results = cursor.fetchone() + if (results): + return __sanitise_filename__( + f"ProbeSetFreezeId_{results[0]}_{results[1]}") + + def __parse_to_dict__(results): + ids = ["ID"] + data = {} + for (trait, strain, val) in results: + if strain not in ids: + ids.append(strain) + if trait in data: + data[trait].append(val) + else: + data[trait] = [trait, val] + return (data, ids) + + def __write_to_file__(file_path, data, col_names): + with open(file_path, 'w+', encoding='UTF8') as file_handler: + writer = csv.writer(file_handler) + writer.writerow(col_names) + writer.writerows(data.values()) + + with conn.cursor() as cursor: + cursor.execute( + "SELECT ProbeSet.Name, Strain.Name, ProbeSetData.value " + "FROM Strain LEFT JOIN ProbeSetData " + "ON Strain.Id = ProbeSetData.StrainId " + "LEFT JOIN ProbeSetXRef ON ProbeSetData.Id = ProbeSetXRef.DataId " + "LEFT JOIN ProbeSet ON ProbeSetXRef.ProbeSetId = ProbeSet.Id " + "WHERE ProbeSetXRef.ProbeSetFreezeId IN " + "(SELECT Id FROM ProbeSetFreeze WHERE Name = %s) " + "ORDER BY Strain.Name", + (db_name,)) + results = cursor.fetchall() + file_name = __generate_file_name__(db_name) + if (results and file_name): + __write_to_file__(os.path.join(text_dir, file_name), + *__parse_to_dict__(results)) diff --git a/gn2/wqflask/correlation/rust_correlation.py b/gn2/wqflask/correlation/rust_correlation.py new file mode 100644 index 00000000..a0dcbcb4 --- /dev/null +++ b/gn2/wqflask/correlation/rust_correlation.py @@ -0,0 +1,408 @@ +"""module contains integration code for rust-gn3""" +import json +from functools import reduce + +from gn2.utility.tools import SQL_URI +from gn2.utility.db_tools import mescape +from gn2.utility.db_tools import create_in_clause +from gn2.wqflask.correlation.correlation_functions\ + import get_trait_symbol_and_tissue_values +from gn2.wqflask.correlation.correlation_gn3_api import create_target_this_trait +from gn2.wqflask.correlation.correlation_gn3_api import lit_for_trait_list +from gn2.wqflask.correlation.correlation_gn3_api import do_lit_correlation +from gn2.wqflask.correlation.pre_computes import fetch_text_file +from gn2.wqflask.correlation.pre_computes import read_text_file +from gn2.wqflask.correlation.pre_computes import write_db_to_textfile +from gn2.wqflask.correlation.pre_computes import read_trait_metadata +from gn2.wqflask.correlation.pre_computes import cache_trait_metadata +from gn3.computations.correlations import compute_all_lit_correlation +from gn3.computations.rust_correlation import run_correlation +from gn3.computations.rust_correlation import get_sample_corr_data +from gn3.computations.rust_correlation import parse_tissue_corr_data +from gn3.db_utils import database_connection + +from gn2.wqflask.correlation.exceptions import WrongCorrelationType + + +def query_probes_metadata(dataset, trait_list): + """query traits metadata in bulk for probeset""" 
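# Note: each returned row is unpacked positionally by get_metadata() below as
# (trait_name, probe_chr, probe_mb, symbol, mean, description,
#  additive, lrs, chr_score, mb), matching the SELECT column order.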
+ + if not bool(trait_list) or dataset.type != "ProbeSet": + return [] + + with database_connection(SQL_URI) as conn: + with conn.cursor() as cursor: + + query = """ + SELECT ProbeSet.Name,ProbeSet.Chr,ProbeSet.Mb, + ProbeSet.Symbol,ProbeSetXRef.mean, + CONCAT_WS('; ', ProbeSet.description, ProbeSet.Probe_Target_Description) AS description, + ProbeSetXRef.additive,ProbeSetXRef.LRS,Geno.Chr, Geno.Mb + FROM ProbeSet INNER JOIN ProbeSetXRef + ON ProbeSet.Id=ProbeSetXRef.ProbeSetId + INNER JOIN Geno + ON ProbeSetXRef.Locus = Geno.Name + INNER JOIN Species + ON Geno.SpeciesId = Species.Id + WHERE ProbeSet.Name in ({}) AND + Species.Name = %s AND + ProbeSetXRef.ProbeSetFreezeId IN ( + SELECT ProbeSetFreeze.Id + FROM ProbeSetFreeze WHERE ProbeSetFreeze.Name = %s) + """.format(", ".join(["%s"] * len(trait_list))) + + cursor.execute(query, + (tuple(trait_list) + + (dataset.group.species,) + (dataset.name,)) + ) + + return cursor.fetchall() + + +def get_metadata(dataset, traits): + """Retrieve the metadata""" + def __location__(probe_chr, probe_mb): + if probe_mb: + return f"Chr{probe_chr}: {probe_mb:.6f}" + return f"Chr{probe_chr}: ???" + cached_metadata = read_trait_metadata(dataset.name) + to_fetch_metadata = list( + set(traits).difference(list(cached_metadata.keys()))) + if to_fetch_metadata: + results = {**({trait_name: { + "name": trait_name, + "view": True, + "symbol": symbol, + "dataset": dataset.name, + "dataset_name": dataset.shortname, + "mean": mean, + "description": description, + "additive": additive, + "lrs_score": f"{lrs:3.1f}" if lrs else "", + "location": __location__(probe_chr, probe_mb), + "chr": probe_chr, + "mb": probe_mb, + "lrs_location": f'Chr{chr_score}: {mb:{".6f" if mb else ""}}', + "lrs_chr": chr_score, + "lrs_mb": mb + + } for trait_name, probe_chr, probe_mb, symbol, mean, description, + additive, lrs, chr_score, mb + in query_probes_metadata(dataset, to_fetch_metadata)}), **cached_metadata} + cache_trait_metadata(dataset.name, results) + return results + return cached_metadata + + +def chunk_dataset(dataset, steps, name): + + results = [] + + query = """ + SELECT ProbeSetXRef.DataId,ProbeSet.Name + FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze + WHERE ProbeSetFreeze.Name = '{}' AND + ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND + ProbeSetXRef.ProbeSetId = ProbeSet.Id + """.format(name) + + with database_connection(SQL_URI) as conn: + with conn.cursor() as curr: + curr.execute(query) + traits_name_dict = dict(curr.fetchall()) + + for i in range(0, len(dataset), steps): + matrix = list(dataset[i:i + steps]) + results.append([traits_name_dict[matrix[0][0]]] + [str(value) + for (trait_name, strain, value) in matrix]) + return results + + +def compute_top_n_sample(start_vars, dataset, trait_list): + """check if dataset is of type probeset""" + + if dataset.type.lower() != "probeset": + return {} + + def __fetch_sample_ids__(samples_vals, samples_group): + sample_data = get_sample_corr_data( + sample_type=samples_group, + sample_data=json.loads(samples_vals), + dataset_samples=dataset.group.all_samples_ordered()) + + with database_connection(SQL_URI) as conn: + with conn.cursor() as curr: + curr.execute( + """ + SELECT Strain.Name, Strain.Id FROM Strain, Species + WHERE Strain.Name IN {} + and Strain.SpeciesId=Species.Id + and Species.name = '{}' + """.format(create_in_clause(list(sample_data.keys())), + *mescape(dataset.group.species))) + return (sample_data, dict(curr.fetchall())) + + (sample_data, sample_ids) = __fetch_sample_ids__( + start_vars["sample_vals"], 
start_vars["corr_samples_group"]) + + if len(trait_list) == 0: + return {} + + with database_connection(SQL_URI) as conn: + with conn.cursor() as curr: + # fetching strain data in bulk + query = ( + "SELECT * from ProbeSetData " + f"WHERE StrainID IN ({', '.join(['%s'] * len(sample_ids))}) " + "AND Id IN (" + " SELECT ProbeSetXRef.DataId " + " FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) " + " WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id " + " AND ProbeSetFreeze.Name = %s " + " AND ProbeSet.Name " + f" IN ({', '.join(['%s'] * len(trait_list))}) " + " AND ProbeSet.Id = ProbeSetXRef.ProbeSetId" + ")") + curr.execute( + query, + tuple(sample_ids.values()) + (dataset.name,) + tuple(trait_list)) + + corr_data = chunk_dataset( + list(curr.fetchall()), len(sample_ids.values()), dataset.name) + + return run_correlation( + corr_data, list(sample_data.values()), "pearson", ",") + + +def compute_top_n_lit(corr_results, target_dataset, this_trait) -> dict: + if not __datasets_compatible_p__(this_trait.dataset, target_dataset, "lit"): + return {} + + (this_trait_geneid, geneid_dict, species) = do_lit_correlation( + this_trait, target_dataset) + + geneid_dict = {trait_name: geneid for (trait_name, geneid) + in geneid_dict.items() if + corr_results.get(trait_name)} + with database_connection(SQL_URI) as conn: + return reduce( + lambda acc, corr: {**acc, **corr}, + compute_all_lit_correlation( + conn=conn, trait_lists=list(geneid_dict.items()), + species=species, gene_id=this_trait_geneid), + {}) + + return {} + + +def compute_top_n_tissue(target_dataset, this_trait, traits, method): + # refactor lots of rpt + if not __datasets_compatible_p__(this_trait.dataset, target_dataset, "tissue"): + return {} + + trait_symbol_dict = dict({ + trait_name: symbol + for (trait_name, symbol) + in target_dataset.retrieve_genes("Symbol").items() + if traits.get(trait_name)}) + + corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values( + symbol_list=list(trait_symbol_dict.values())) + + data = parse_tissue_corr_data(symbol_name=this_trait.symbol, + symbol_dict=get_trait_symbol_and_tissue_values( + symbol_list=[this_trait.symbol]), + dataset_symbols=trait_symbol_dict, + dataset_vals=corr_result_tissue_vals_dict) + + if data and data[0]: + return run_correlation( + data[1], data[0], method, ",", "tissue") + + return {} + + +def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]: + """code to merge diff corr into individual dicts + a""" + + def __merge__(trait_name, trait_corrs): + return { + trait_name: { + **trait_corrs, + **dict_b.get(trait_name, {}), + **dict_c.get(trait_name, {}) + } + } + return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()] + + +def __compute_sample_corr__( + start_vars: dict, corr_type: str, method: str, n_top: int, + target_trait_info: tuple): + """Compute the sample correlations""" + (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info + + if this_dataset.group.f1list != None: + this_dataset.group.samplelist += this_dataset.group.f1list + + if this_dataset.group.parlist != None: + this_dataset.group.samplelist += this_dataset.group.parlist + + sample_data = get_sample_corr_data( + sample_type=start_vars["corr_samples_group"], + sample_data=json.loads(start_vars["sample_vals"]), + dataset_samples=this_dataset.group.all_samples_ordered()) + + if not bool(sample_data): + return {} + + if target_dataset.type == "ProbeSet" and start_vars.get("use_cache") == "true": + with database_connection(SQL_URI) as conn: + file_path = 
fetch_text_file(target_dataset.name, conn) + if file_path: + (sample_vals, target_data) = read_text_file( + sample_data, file_path) + + return run_correlation(target_data, sample_vals, + method, ",", corr_type, n_top) + + write_db_to_textfile(target_dataset.name, conn) + file_path = fetch_text_file(target_dataset.name, conn) + if file_path: + (sample_vals, target_data) = read_text_file( + sample_data, file_path) + + return run_correlation(target_data, sample_vals, + method, ",", corr_type, n_top) + + target_dataset.get_trait_data(list(sample_data.keys())) + + def __merge_key_and_values__(rows, current): + wo_nones = [value for value in current[1]] + if len(wo_nones) > 0: + return rows + [[current[0]] + wo_nones] + return rows + + target_data = reduce( + __merge_key_and_values__, target_dataset.trait_data.items(), []) + + if len(target_data) == 0: + return {} + + return run_correlation( + target_data, list(sample_data.values()), method, ",", corr_type, + n_top) + + +def __datasets_compatible_p__(trait_dataset, target_dataset, corr_method): + return not ( + corr_method in ("tissue", "Tissue r", "Literature r", "lit") + and (trait_dataset.type == "ProbeSet" and + target_dataset.type in ("Publish", "Geno"))) + + +def __compute_tissue_corr__( + start_vars: dict, corr_type: str, method: str, n_top: int, + target_trait_info: tuple): + """Compute the tissue correlations""" + (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info + if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type): + raise WrongCorrelationType(this_trait, target_dataset, corr_type) + + trait_symbol_dict = target_dataset.retrieve_genes("Symbol") + corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values( + symbol_list=list(trait_symbol_dict.values())) + + data = parse_tissue_corr_data( + symbol_name=this_trait.symbol, + symbol_dict=get_trait_symbol_and_tissue_values( + symbol_list=[this_trait.symbol]), + dataset_symbols=trait_symbol_dict, + dataset_vals=corr_result_tissue_vals_dict) + + if data: + return run_correlation(data[1], data[0], method, ",", "tissue") + return {} + + +def __compute_lit_corr__( + start_vars: dict, corr_type: str, method: str, n_top: int, + target_trait_info: tuple): + """Compute the literature correlations""" + (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info + if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type): + raise WrongCorrelationType(this_trait, target_dataset, corr_type) + + target_dataset_type = target_dataset.type + this_dataset_type = this_dataset.type + (this_trait_geneid, geneid_dict, species) = do_lit_correlation( + this_trait, target_dataset) + + with database_connection(SQL_URI) as conn: + return reduce( + lambda acc, lit: {**acc, **lit}, + compute_all_lit_correlation( + conn=conn, trait_lists=list(geneid_dict.items()), + species=species, gene_id=this_trait_geneid)[:n_top], + {}) + return {} + + +def compute_correlation_rust( + start_vars: dict, corr_type: str, method: str = "pearson", + n_top: int = 500, should_compute_all: bool = False): + """function to compute correlation""" + target_trait_info = create_target_this_trait(start_vars) + (this_dataset, this_trait, target_dataset, sample_data) = ( + target_trait_info) + if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type): + raise WrongCorrelationType(this_trait, target_dataset, corr_type) + + # Replace this with `match ...` once we hit Python 3.10 + corr_type_fns = { + "sample": __compute_sample_corr__, + "tissue": 
__compute_tissue_corr__, + "lit": __compute_lit_corr__ + } + + results = corr_type_fns[corr_type]( + start_vars, corr_type, method, n_top, target_trait_info) + + # END: Replace this with `match ...` once we hit Python 3.10 + + top_a = top_b = {} + + if should_compute_all: + + if corr_type == "sample": + if this_dataset.type == "ProbeSet": + top_a = compute_top_n_tissue( + target_dataset, this_trait, results, method) + + top_b = compute_top_n_lit(results, target_dataset, this_trait) + else: + pass + + elif corr_type == "lit": + + # currently fails for lit + + top_a = compute_top_n_sample( + start_vars, target_dataset, list(results.keys())) + top_b = compute_top_n_tissue( + target_dataset, this_trait, results, method) + + else: + + top_a = compute_top_n_sample( + start_vars, target_dataset, list(results.keys())) + + return { + "correlation_results": merge_results( + results, top_a, top_b), + "this_trait": this_trait.name, + "target_dataset": start_vars['corr_dataset'], + "traits_metadata": get_metadata(target_dataset, list(results.keys())), + "return_results": n_top + } diff --git a/gn2/wqflask/correlation/show_corr_results.py b/gn2/wqflask/correlation/show_corr_results.py new file mode 100644 index 00000000..c8625222 --- /dev/null +++ b/gn2/wqflask/correlation/show_corr_results.py @@ -0,0 +1,406 @@ +# Copyright (C) University of Tennessee Health Science Center, Memphis, TN. +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License +# as published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +# See the GNU Affero General Public License for more details. +# +# This program is available from Source Forge: at GeneNetwork Project +# (sourceforge.net/projects/genenetwork/). +# +# Contact Dr. Robert W. 
Williams at rwilliams@uthsc.edu +# +# +# This module is used by GeneNetwork project (www.genenetwork.org) + +import hashlib +import html +import json + +from gn2.base.trait import create_trait, jsonable +from gn2.base.data_set import create_dataset + +from gn2.utility import hmac +from gn2.utility.type_checking import get_float, get_int, get_string +from gn2.utility.redis_tools import get_redis_conn +Redis = get_redis_conn() + +def set_template_vars(start_vars, correlation_data): + corr_type = start_vars['corr_type'] + corr_method = start_vars['corr_sample_method'] + + if start_vars['dataset'] == "Temp": + this_dataset_ob = create_dataset( + dataset_name="Temp", dataset_type="Temp", group_name=start_vars['group']) + else: + this_dataset_ob = create_dataset(dataset_name=start_vars['dataset']) + this_trait = create_trait(dataset=this_dataset_ob, + name=start_vars['trait_id']) + + # Store trait sample data in Redis, so additive effect scatterplots can include edited values + dhash = hashlib.md5() + dhash.update(start_vars['sample_vals'].encode()) + samples_hash = dhash.hexdigest() + Redis.set(samples_hash, start_vars['sample_vals'], ex=7*24*60*60) + correlation_data['dataid'] = samples_hash + + correlation_data['this_trait'] = jsonable(this_trait, this_dataset_ob) + correlation_data['this_dataset'] = this_dataset_ob.as_monadic_dict().data + + target_dataset_ob = create_dataset(correlation_data['target_dataset']) + correlation_data['target_dataset'] = target_dataset_ob.as_monadic_dict().data + correlation_data['table_json'] = correlation_json_for_table( + start_vars, + correlation_data, + target_dataset_ob) + + if target_dataset_ob.type == "ProbeSet": + filter_cols = [7, 6] + elif target_dataset_ob.type == "Publish": + filter_cols = [8, 5] + else: + filter_cols = [4, 0] + + correlation_data['corr_method'] = corr_method + correlation_data['filter_cols'] = filter_cols + correlation_data['header_fields'] = get_header_fields( + target_dataset_ob.type, correlation_data['corr_method']) + correlation_data['formatted_corr_type'] = get_formatted_corr_type( + corr_type, corr_method) + + return correlation_data + + +def apply_filters(trait, target_trait, target_dataset, **filters): + def __p_val_filter__(p_lower, p_upper): + + return not (p_lower <= float(trait.get("corr_coefficient",0.0)) <= p_upper) + + def __min_filter__(min_expr): + if (target_dataset['type'] in ["ProbeSet", "Publish"] and target_trait['mean']): + return (min_expr != None) and (float(target_trait['mean']) < min_expr) + + return False + + def __location_filter__(location_type, location_chr, + min_location_mb, max_location_mb): + + if target_dataset["type"] in ["ProbeSet", "Geno"] and location_type == "gene": + return ( + ((location_chr!=None) and (target_trait["chr"]!=location_chr)) + or + ((min_location_mb!= None) and ( + float(target_trait['mb']) < min_location_mb) + ) + + or + ((max_location_mb != None) and + (float(target_trait['mb']) > float(max_location_mb) + )) + + ) + elif target_dataset["type"] in ["ProbeSet", "Publish"]: + + return ((location_chr!=None) and (target_trait["lrs_chr"] != location_chr) + or + ((min_location_mb != None) and ( + float(target_trait['lrs_mb']) < float(min_location_mb))) + or + ((max_location_mb != None) and ( + float(target_trait['lrs_mb']) > float(max_location_mb)) + ) + + ) + + return True + + if not target_trait: + return True + else: + # check if one of the condition is not met i.e One is True + return (__p_val_filter__( + filters.get("p_range_lower"), + filters.get("p_range_upper") + ) + or + ( 
+ __min_filter__( + filters.get("min_expr") + ) + ) + or + __location_filter__( + filters.get("location_type"), + filters.get("location_chr"), + filters.get("min_location_mb"), + filters.get("max_location_mb") + + + ) + ) + + +def get_user_filters(start_vars): + (min_expr, p_min, p_max) = ( + get_float(start_vars, 'min_expr'), + get_float(start_vars, 'p_range_lower', -1.0), + get_float(start_vars, 'p_range_upper', 1.0) + ) + + if all(keys in start_vars for keys in ["loc_chr", + "min_loc_mb", + "max_location_mb"]): + + location_chr = get_string(start_vars, "loc_chr") + min_location_mb = get_int(start_vars, "min_loc_mb") + max_location_mb = get_int(start_vars, "max_loc_mb") + + else: + location_chr = min_location_mb = max_location_mb = None + + return { + + "min_expr": min_expr, + "p_range_lower": p_min, + "p_range_upper": p_max, + "location_chr": location_chr, + "location_type": start_vars['location_type'], + "min_location_mb": min_location_mb, + "max_location_mb": max_location_mb + + } + + +def generate_table_metadata(all_traits, dataset_metadata, dataset_obj): + + def __fetch_trait_data__(trait, dataset_obj): + target_trait_ob = create_trait(dataset=dataset_obj, + name=trait, + get_qtl_info=True) + return jsonable(target_trait_ob, dataset_obj) + + metadata = [__fetch_trait_data__(trait, dataset_obj) for + trait in (all_traits)] + + return (dataset_metadata | ({str(trait["name"]): trait for trait in metadata})) + + +def populate_table(dataset_metadata, target_dataset, this_dataset, corr_results, filters): + + def __populate_trait__(idx, trait): + + trait_name = list(trait.keys())[0] + target_trait = dataset_metadata.get(trait_name) + trait = trait[trait_name] + if not apply_filters(trait, target_trait, target_dataset, **filters): + results_dict = {} + results_dict['index'] = idx + 1 # + results_dict['trait_id'] = target_trait['name'] + results_dict['dataset'] = target_dataset['name'] + results_dict['hmac'] = hmac.data_hmac( + '{}:{}'.format(target_trait['name'], target_dataset['name'])) + results_dict['sample_r'] = f"{float(trait.get('corr_coefficient',0.0)):.3f}" + results_dict['num_overlap'] = trait.get('num_overlap', 0) + results_dict['sample_p'] = f"{float(trait.get('p_value',0)):.2e}" + if target_dataset['type'] == "ProbeSet": + results_dict['symbol'] = target_trait['symbol'] + results_dict['description'] = "N/A" + results_dict['location'] = target_trait['location'] + results_dict['mean'] = "N/A" + results_dict['additive'] = "N/A" + if target_trait['description'].strip(): + results_dict['description'] = html.escape( + target_trait['description'].strip(), quote=True) + if target_trait['mean']: + results_dict['mean'] = f"{float(target_trait['mean']):.3f}" + try: + results_dict['lod_score'] = f"{float(target_trait['lrs_score']) / 4.61:.1f}" + except: + results_dict['lod_score'] = "N/A" + results_dict['lrs_location'] = target_trait['lrs_location'] + if target_trait['additive']: + results_dict['additive'] = f"{float(target_trait['additive']):.3f}" + results_dict['lit_corr'] = "--" + results_dict['tissue_corr'] = "--" + results_dict['tissue_pvalue'] = "--" + if this_dataset['type'] == "ProbeSet": + if 'lit_corr' in trait: + results_dict['lit_corr'] = ( + f"{float(trait['lit_corr']):.3f}" + if trait["lit_corr"] else "--") + if 'tissue_corr' in trait: + results_dict['tissue_corr'] = f"{float(trait['tissue_corr']):.3f}" + results_dict['tissue_pvalue'] = f"{float(trait['tissue_p_val']):.3e}" + elif target_dataset['type'] == "Publish": + results_dict['abbreviation_display'] = "N/A" + 
results_dict['description'] = "N/A" + results_dict['mean'] = "N/A" + results_dict['authors_display'] = "N/A" + results_dict['additive'] = "N/A" + results_dict['pubmed_link'] = "N/A" + results_dict['pubmed_text'] = target_trait["pubmed_text"] + + if target_trait["abbreviation"]: + results_dict['abbreviation'] = target_trait['abbreviation'] + + if target_trait["description"].strip(): + results_dict['description'] = html.escape( + target_trait['description'].strip(), quote=True) + + if target_trait["mean"] != "N/A": + results_dict['mean'] = f"{float(target_trait['mean']):.3f}" + + results_dict['lrs_location'] = target_trait['lrs_location'] + + if target_trait["authors"]: + authors_list = target_trait['authors'].split(',') + results_dict['authors_display'] = ", ".join( + authors_list[:6]) + ", et al." if len(authors_list) > 6 else target_trait['authors'] + + if "pubmed_id" in target_trait: + results_dict['pubmed_link'] = target_trait['pubmed_link'] + results_dict['pubmed_text'] = target_trait['pubmed_text'] + try: + results_dict["lod_score"] = f"{float(target_trait['lrs_score']) / 4.61:.1f}" + except ValueError: + results_dict['lod_score'] = "N/A" + else: + results_dict['location'] = target_trait['location'] + + return results_dict + + return [__populate_trait__(idx, trait) + for (idx, trait) in enumerate(corr_results)] + + +def correlation_json_for_table(start_vars, correlation_data, target_dataset_ob): + """Return JSON data for use with the DataTable in the correlation result page + + Keyword arguments: + correlation_data -- Correlation results + this_trait -- Trait being correlated against a dataset, as a dict + this_dataset -- Dataset of this_trait, as a monadic dict + target_dataset_ob - Target dataset, as a Dataset ob + """ + this_dataset = correlation_data['this_dataset'] + + traits = set() + for trait in correlation_data["correlation_results"]: + traits.add(list(trait)[0]) + + dataset_metadata = generate_table_metadata(traits, + correlation_data["traits_metadata"], + target_dataset_ob) + return json.dumps([result for result in ( + populate_table(dataset_metadata=dataset_metadata, + target_dataset=target_dataset_ob.as_monadic_dict().data, + this_dataset=correlation_data['this_dataset'], + corr_results=correlation_data['correlation_results'], + filters=get_user_filters(start_vars))) if result]) + + +def get_formatted_corr_type(corr_type, corr_method): + formatted_corr_type = "" + if corr_type == "lit": + formatted_corr_type += "Literature Correlation " + elif corr_type == "tissue": + formatted_corr_type += "Tissue Correlation " + elif corr_type == "sample": + formatted_corr_type += "Genetic Correlation " + + if corr_method == "pearson": + formatted_corr_type += "(Pearson's r)" + elif corr_method == "spearman": + formatted_corr_type += "(Spearman's rho)" + elif corr_method == "bicor": + formatted_corr_type += "(Biweight r)" + + return formatted_corr_type + + +def get_header_fields(data_type, corr_method): + if data_type == "ProbeSet": + if corr_method == "spearman": + header_fields = ['Index', + 'Record', + 'Symbol', + 'Description', + 'Location', + 'Mean', + 'Sample rho', + 'N', + 'Sample p(rho)', + 'Lit rho', + 'Tissue rho', + 'Tissue p(rho)', + 'Max LRS', + 'Max LRS Location', + 'Additive Effect'] + else: + header_fields = ['Index', + 'Record', + 'Symbol', + 'Description', + 'Location', + 'Mean', + 'Sample r', + 'N', + 'Sample p(r)', + 'Lit r', + 'Tissue r', + 'Tissue p(r)', + 'Max LRS', + 'Max LRS Location', + 'Additive Effect'] + elif data_type == "Publish": + if corr_method == 
"spearman": + header_fields = ['Index', + 'Record', + 'Abbreviation', + 'Description', + 'Mean', + 'Authors', + 'Year', + 'Sample rho', + 'N', + 'Sample p(rho)', + 'Max LRS', + 'Max LRS Location', + 'Additive Effect'] + else: + header_fields = ['Index', + 'Record', + 'Abbreviation', + 'Description', + 'Mean', + 'Authors', + 'Year', + 'Sample r', + 'N', + 'Sample p(r)', + 'Max LRS', + 'Max LRS Location', + 'Additive Effect'] + + else: + if corr_method == "spearman": + header_fields = ['Index', + 'ID', + 'Location', + 'Sample rho', + 'N', + 'Sample p(rho)'] + else: + header_fields = ['Index', + 'ID', + 'Location', + 'Sample r', + 'N', + 'Sample p(r)'] + + return header_fields |