aboutsummaryrefslogtreecommitdiff
path: root/gn2/wqflask/correlation
diff options
context:
space:
mode:
authorArun Isaac2023-12-29 18:55:37 +0000
committerArun Isaac2023-12-29 19:01:46 +0000
commit204a308be0f741726b9a620d88fbc22b22124c81 (patch)
treeb3cf66906674020b530c844c2bb4982c8a0e2d39 /gn2/wqflask/correlation
parent83062c75442160427b50420161bfcae2c5c34c84 (diff)
downloadgenenetwork2-204a308be0f741726b9a620d88fbc22b22124c81.tar.gz
Namespace all modules under gn2.
We move all modules under a gn2 directory. This is important for "correct" packaging and deployment as a Guix service.
Diffstat (limited to 'gn2/wqflask/correlation')
-rw-r--r--gn2/wqflask/correlation/__init__.py0
-rw-r--r--gn2/wqflask/correlation/corr_scatter_plot.py158
-rw-r--r--gn2/wqflask/correlation/correlation_functions.py68
-rw-r--r--gn2/wqflask/correlation/correlation_gn3_api.py262
-rw-r--r--gn2/wqflask/correlation/exceptions.py16
-rw-r--r--gn2/wqflask/correlation/pre_computes.py178
-rw-r--r--gn2/wqflask/correlation/rust_correlation.py408
-rw-r--r--gn2/wqflask/correlation/show_corr_results.py406
8 files changed, 1496 insertions, 0 deletions
diff --git a/gn2/wqflask/correlation/__init__.py b/gn2/wqflask/correlation/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/gn2/wqflask/correlation/__init__.py
diff --git a/gn2/wqflask/correlation/corr_scatter_plot.py b/gn2/wqflask/correlation/corr_scatter_plot.py
new file mode 100644
index 00000000..59e9ac4a
--- /dev/null
+++ b/gn2/wqflask/correlation/corr_scatter_plot.py
@@ -0,0 +1,158 @@
+import json
+import math
+
+from redis import Redis
+Redis = Redis()
+
+from gn2.base.trait import create_trait, retrieve_sample_data
+from gn2.base import data_set, webqtlCaseData
+from gn2.utility import corr_result_helpers
+from gn2.wqflask.oauth2.collections import num_collections
+
+from scipy import stats
+import numpy as np
+
+import logging
+logger = logging.getLogger(__name__)
+
+class CorrScatterPlot:
+ """Page that displays a correlation scatterplot with a line fitted to it"""
+
+ def __init__(self, params):
+ if "Temp" in params['dataset_1']:
+ self.dataset_1 = data_set.create_dataset(
+ dataset_name="Temp", dataset_type="Temp", group_name=params['dataset_1'].split("_")[1])
+ else:
+ self.dataset_1 = data_set.create_dataset(params['dataset_1'])
+ if "Temp" in params['dataset_2']:
+ self.dataset_2 = data_set.create_dataset(
+ dataset_name="Temp", dataset_type="Temp", group_name=params['dataset_2'].split("_")[1])
+ else:
+ self.dataset_2 = data_set.create_dataset(params['dataset_2'])
+
+ self.trait_1 = create_trait(
+ name=params['trait_1'], dataset=self.dataset_1)
+ self.trait_2 = create_trait(
+ name=params['trait_2'], dataset=self.dataset_2)
+
+ self.method = params['method']
+
+ primary_samples = self.dataset_1.group.samplelist
+ if self.dataset_1.group.parlist != None:
+ primary_samples += self.dataset_1.group.parlist
+ if self.dataset_1.group.f1list != None:
+ primary_samples += self.dataset_1.group.f1list
+
+ if 'dataid' in params:
+ trait_data_dict = json.loads(Redis.get(params['dataid']))
+ trait_data = {key:webqtlCaseData.webqtlCaseData(key, float(trait_data_dict[key])) for (key, value) in trait_data_dict.items() if trait_data_dict[key] != "x"}
+ trait_1_data = trait_data
+ trait_2_data = self.trait_2.data
+ # Check if the cached data should be used for the second trait instead
+ if 'cached_trait' in params:
+ if params['cached_trait'] == 'trait_2':
+ trait_2_data = trait_data
+ trait_1_data = self.trait_1.data
+ samples_1, samples_2, num_overlap = corr_result_helpers.normalize_values_with_samples(
+ trait_1_data, trait_2_data)
+ else:
+ samples_1, samples_2, num_overlap = corr_result_helpers.normalize_values_with_samples(
+ self.trait_1.data, self.trait_2.data)
+
+ self.data = []
+ self.indIDs = list(samples_1.keys())
+ vals_1 = []
+ for sample in list(samples_1.keys()):
+ vals_1.append(samples_1[sample].value)
+ self.data.append(vals_1)
+ vals_2 = []
+ for sample in list(samples_2.keys()):
+ vals_2.append(samples_2[sample].value)
+ self.data.append(vals_2)
+
+ slope, intercept, r_value, p_value, std_err = stats.linregress(
+ vals_1, vals_2)
+
+ if slope < 0.001:
+ slope_string = '%.3E' % slope
+ else:
+ slope_string = '%.3f' % slope
+
+ x_buffer = (max(vals_1) - min(vals_1)) * 0.1
+ y_buffer = (max(vals_2) - min(vals_2)) * 0.1
+
+ x_range = [min(vals_1) - x_buffer, max(vals_1) + x_buffer]
+ y_range = [min(vals_2) - y_buffer, max(vals_2) + y_buffer]
+
+ intercept_coords = get_intercept_coords(
+ slope, intercept, x_range, y_range)
+
+ rx = stats.rankdata(vals_1)
+ ry = stats.rankdata(vals_2)
+ self.rdata = []
+ self.rdata.append(rx.tolist())
+ self.rdata.append(ry.tolist())
+ srslope, srintercept, srr_value, srp_value, srstd_err = stats.linregress(
+ rx, ry)
+
+ if srslope < 0.001:
+ srslope_string = '%.3E' % srslope
+ else:
+ srslope_string = '%.3f' % srslope
+
+ x_buffer = (max(rx) - min(rx)) * 0.1
+ y_buffer = (max(ry) - min(ry)) * 0.1
+
+ sr_range = [min(rx) - x_buffer, max(rx) + x_buffer]
+
+ sr_intercept_coords = get_intercept_coords(
+ srslope, srintercept, sr_range, sr_range)
+
+ self.collections_exist = "False"
+ if num_collections() > 0:
+ self.collections_exist = "True"
+
+ self.js_data = dict(
+ data=self.data,
+ rdata=self.rdata,
+ indIDs=self.indIDs,
+ trait_1=self.trait_1.dataset.name + ": " + str(self.trait_1.name),
+ trait_2=self.trait_2.dataset.name + ": " + str(self.trait_2.name),
+ samples_1=samples_1,
+ samples_2=samples_2,
+ num_overlap=num_overlap,
+ vals_1=vals_1,
+ vals_2=vals_2,
+ x_range=x_range,
+ y_range=y_range,
+ sr_range=sr_range,
+ intercept_coords=intercept_coords,
+ sr_intercept_coords=sr_intercept_coords,
+
+ slope=slope,
+ slope_string=slope_string,
+ intercept=intercept,
+ r_value=r_value,
+ p_value=p_value,
+
+ srslope=srslope,
+ srslope_string=srslope_string,
+ srintercept=srintercept,
+ srr_value=srr_value,
+ srp_value=srp_value
+ )
+ self.jsdata = self.js_data
+
+
+def get_intercept_coords(slope, intercept, x_range, y_range):
+ intercept_coords = []
+
+ y1 = slope * x_range[0] + intercept
+ y2 = slope * x_range[1] + intercept
+ x1 = (y1 - intercept) / slope
+ x2 = (y2 - intercept) / slope
+
+ intercept_coords.append([x1, y1])
+ intercept_coords.append([x2, y2])
+
+ return intercept_coords
diff --git a/gn2/wqflask/correlation/correlation_functions.py b/gn2/wqflask/correlation/correlation_functions.py
new file mode 100644
index 00000000..911f6dc8
--- /dev/null
+++ b/gn2/wqflask/correlation/correlation_functions.py
@@ -0,0 +1,68 @@
+# Copyright (C) University of Tennessee Health Science Center, Memphis, TN.
+#
+# This program is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Affero General Public License
+# as published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# See the GNU Affero General Public License for more details.
+#
+# This program is available from Source Forge: at GeneNetwork Project
+# (sourceforge.net/projects/genenetwork/).
+#
+# Contact Drs. Robert W. Williams and Xiaodong Zhou (2010)
+# at rwilliams@uthsc.edu and xzhou15@uthsc.edu
+#
+#
+#
+# This module is used by GeneNetwork project (www.genenetwork.org)
+#
+# Created by GeneNetwork Core Team 2010/08/10
+
+
+from gn2.base.mrna_assay_tissue_data import MrnaAssayTissueData
+from gn3.computations.correlations import compute_corr_coeff_p_value
+from gn2.wqflask.database import database_connection
+from gn2.utility.tools import get_setting
+
+#####################################################################################
+# Input: primaryValue(list): one list of expression values of one probeSet,
+# targetValue(list): one list of expression values of one probeSet,
+# method(string): indicate correlation method ('pearson' or 'spearman')
+# Output: corr_result(list): first item is Correlation Value, second item is tissue number,
+# third item is PValue
+# Function: get correlation value,Tissue quantity ,p value result by using R;
+# Note : This function is special case since both primaryValue and targetValue are from
+# the same dataset. So the length of these two parameters is the same. They are pairs.
+# Also, in the datatable TissueProbeSetData, all Tissue values are loaded based on
+# the same tissue order
+#####################################################################################
+
+
+def cal_zero_order_corr_for_tiss(primary_values, target_values, method="pearson"):
+ """function use calls gn3 to compute corr,p_val"""
+
+ (corr_coeff, p_val) = compute_corr_coeff_p_value(
+ primary_values=primary_values, target_values=target_values, corr_method=method)
+
+ return (corr_coeff, len(primary_values), p_val)
+
+########################################################################################################
+# input: cursor, symbolList (list), dataIdDict(Dict): key is symbol
+# output: SymbolValuePairDict(dictionary):one dictionary of Symbol and Value Pair.
+# key is symbol, value is one list of expression values of one probeSet.
+# function: wrapper function for getSymbolValuePairDict function
+# build gene symbol list if necessary, cut it into small lists if necessary,
+# then call getSymbolValuePairDict function and merge the results.
+########################################################################################################
+
+
+def get_trait_symbol_and_tissue_values(symbol_list=None):
+ with database_connection(get_setting("SQL_URI")) as conn:
+ tissue_data = MrnaAssayTissueData(gene_symbols=symbol_list, conn=conn)
+ if len(tissue_data.gene_symbols) > 0:
+ results = tissue_data.get_symbol_values_pairs()
+ return results
diff --git a/gn2/wqflask/correlation/correlation_gn3_api.py b/gn2/wqflask/correlation/correlation_gn3_api.py
new file mode 100644
index 00000000..76c75ec3
--- /dev/null
+++ b/gn2/wqflask/correlation/correlation_gn3_api.py
@@ -0,0 +1,262 @@
+"""module that calls the gn3 api's to do the correlation """
+import json
+import time
+from functools import wraps
+
+from gn2.utility.tools import SQL_URI
+
+from gn2.wqflask.correlation import correlation_functions
+from gn2.base import data_set
+
+from gn2.base.trait import create_trait
+from gn2.base.trait import retrieve_sample_data
+
+from gn3.db_utils import database_connection
+from gn3.commands import run_sample_corr_cmd
+from gn3.computations.correlations import map_shared_keys_to_values
+from gn3.computations.correlations import compute_all_lit_correlation
+from gn3.computations.correlations import compute_tissue_correlation
+from gn3.computations.correlations import fast_compute_all_sample_correlation
+
+
+def create_target_this_trait(start_vars):
+ """this function creates the required trait and target dataset for correlation"""
+
+ if start_vars['dataset'] == "Temp":
+ this_dataset = data_set.create_dataset(
+ dataset_name="Temp", dataset_type="Temp", group_name=start_vars['group'])
+ else:
+ this_dataset = data_set.create_dataset(
+ dataset_name=start_vars['dataset'])
+ target_dataset = data_set.create_dataset(
+ dataset_name=start_vars['corr_dataset'])
+ this_trait = create_trait(dataset=this_dataset,
+ name=start_vars['trait_id'])
+ sample_data = ()
+ return (this_dataset, this_trait, target_dataset, sample_data)
+
+
+def test_process_data(this_trait, dataset, start_vars):
+ """test function for bxd,all and other sample data"""
+
+ corr_samples_group = start_vars["corr_samples_group"]
+
+ primary_samples = dataset.group.samplelist
+ if dataset.group.parlist != None:
+ primary_samples += dataset.group.parlist
+ if dataset.group.f1list != None:
+ primary_samples += dataset.group.f1list
+
+ # If either BXD/whatever Only or All Samples, append all of that group's samplelist
+ if corr_samples_group != 'samples_other':
+ sample_data = process_samples(start_vars, primary_samples)
+
+ # If either Non-BXD/whatever or All Samples, get all samples from this_trait.data and
+ # exclude the primary samples (because they would have been added in the previous
+ # if statement if the user selected All Samples)
+ if corr_samples_group != 'samples_primary':
+ if corr_samples_group == 'samples_other':
+ primary_samples = [x for x in primary_samples if x not in (
+ dataset.group.parlist + dataset.group.f1list)]
+ sample_data = process_samples(start_vars, list(
+ this_trait.data.keys()), primary_samples)
+
+ return sample_data
+
+
+def process_samples(start_vars, sample_names=[], excluded_samples=[]):
+ """code to fetch correct samples"""
+ sample_data = {}
+ sample_vals_dict = json.loads(start_vars["sample_vals"])
+ if sample_names:
+ for sample in sample_names:
+ if sample in sample_vals_dict and sample not in excluded_samples:
+ val = sample_vals_dict[sample]
+ if not val.strip().lower() == "x":
+ sample_data[str(sample)] = float(val)
+
+ else:
+ for sample in sample_vals_dict.keys():
+ if sample not in excluded_samples:
+ val = sample_vals_dict[sample]
+ if not val.strip().lower() == "x":
+ sample_data[str(sample)] = float(val)
+ return sample_data
+
+
+def merge_correlation_results(correlation_results, target_correlation_results):
+
+ corr_dict = {}
+
+ for trait_dict in target_correlation_results:
+ for trait_name, values in trait_dict.items():
+
+ corr_dict[trait_name] = values
+ for trait_dict in correlation_results:
+ for trait_name, values in trait_dict.items():
+
+ if corr_dict.get(trait_name):
+
+ trait_dict[trait_name].update(corr_dict.get(trait_name))
+
+ return correlation_results
+
+
+def sample_for_trait_lists(corr_results, target_dataset,
+ this_trait, this_dataset, start_vars):
+ """interface function for correlation on top results"""
+
+ (this_trait_data, target_dataset) = fetch_sample_data(
+ start_vars, this_trait, this_dataset, target_dataset)
+ correlation_results = run_sample_corr_cmd(
+ corr_method="pearson", this_trait=this_trait_data,
+ target_dataset=target_dataset)
+
+ return correlation_results
+
+
+def tissue_for_trait_lists(corr_results, this_dataset, this_trait):
+ """interface function for doing tissue corr_results on trait_list"""
+ trait_lists = dict([(list(corr_result)[0], True)
+ for corr_result in corr_results])
+ # trait_lists = {list(corr_results)[0]: 1 for corr_result in corr_results}
+ traits_symbol_dict = this_dataset.retrieve_genes("Symbol")
+ traits_symbol_dict = dict({trait_name: symbol for (
+ trait_name, symbol) in traits_symbol_dict.items() if trait_lists.get(trait_name)})
+ tissue_input = get_tissue_correlation_input(
+ this_trait, traits_symbol_dict)
+
+ if tissue_input is not None:
+ (primary_tissue_data, target_tissue_data) = tissue_input
+ corr_results = compute_tissue_correlation(
+ primary_tissue_dict=primary_tissue_data,
+ target_tissues_data=target_tissue_data,
+ corr_method="pearson")
+ return corr_results
+
+
+def lit_for_trait_list(corr_results, this_dataset, this_trait):
+ (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
+ this_trait, this_dataset)
+
+ # trait_lists = {list(corr_results)[0]: 1 for corr_result in corr_results}
+ trait_lists = dict([(list(corr_result)[0], True)
+ for corr_result in corr_results])
+
+ geneid_dict = {trait_name: geneid for (trait_name, geneid) in geneid_dict.items() if
+ trait_lists.get(trait_name)}
+
+ with database_connection(SQL_URI) as conn:
+ correlation_results = compute_all_lit_correlation(
+ conn=conn, trait_lists=list(geneid_dict.items()),
+ species=species, gene_id=this_trait_geneid)
+
+ return correlation_results
+
+
+def fetch_sample_data(start_vars, this_trait, this_dataset, target_dataset):
+
+ corr_samples_group = start_vars["corr_samples_group"]
+ if corr_samples_group == "samples_primary":
+ sample_data = process_samples(
+ start_vars, this_dataset.group.samplelist)
+
+ elif corr_samples_group == "samples_other":
+ sample_data = process_samples(
+ start_vars, excluded_samples=this_dataset.group.samplelist)
+
+ else:
+ sample_data = process_samples(start_vars,
+ this_dataset.group.all_samples_ordered())
+
+ target_dataset.get_trait_data(list(sample_data.keys()))
+ this_trait = retrieve_sample_data(this_trait, this_dataset)
+ this_trait_data = {
+ "trait_sample_data": sample_data,
+ "trait_id": start_vars["trait_id"]
+ }
+ results = map_shared_keys_to_values(
+ target_dataset.samplelist, target_dataset.trait_data)
+
+ return (this_trait_data, results)
+
+
+def compute_correlation(start_vars, method="pearson", compute_all=False):
+ """Compute correlations using GN3 API
+
+ Keyword arguments:
+ start_vars -- All input from form; includes things like the trait/dataset names
+ method -- Correlation method to be used (pearson, spearman, or bicor)
+ compute_all -- Include sample, tissue, and literature correlations (when applicable)
+ """
+ from gn2.wqflask.correlation.rust_correlation import compute_correlation_rust
+
+ corr_type = start_vars['corr_type']
+ method = start_vars['corr_sample_method']
+ corr_return_results = int(start_vars.get("corr_return_results", 100))
+ return compute_correlation_rust(
+ start_vars, corr_type, method, corr_return_results, compute_all)
+
+
+def compute_corr_for_top_results(start_vars,
+ correlation_results,
+ this_trait,
+ this_dataset,
+ target_dataset,
+ corr_type):
+ if corr_type != "tissue" and this_dataset.type == "ProbeSet" and target_dataset.type == "ProbeSet":
+ tissue_result = tissue_for_trait_lists(
+ correlation_results, this_dataset, this_trait)
+
+ if tissue_result:
+ correlation_results = merge_correlation_results(
+ correlation_results, tissue_result)
+
+ if corr_type != "lit" and this_dataset.type == "ProbeSet" and target_dataset.type == "ProbeSet":
+ lit_result = lit_for_trait_list(
+ correlation_results, this_dataset, this_trait)
+
+ if lit_result:
+ correlation_results = merge_correlation_results(
+ correlation_results, lit_result)
+
+ if corr_type != "sample" and this_dataset.type == "ProbeSet" and target_dataset.type == "ProbeSet":
+ sample_result = sample_for_trait_lists(
+ correlation_results, target_dataset, this_trait, this_dataset, start_vars)
+ if sample_result:
+ correlation_results = merge_correlation_results(
+ correlation_results, sample_result)
+
+ return correlation_results
+
+
+def do_lit_correlation(this_trait, this_dataset):
+ """function for fetching lit inputs"""
+ geneid_dict = this_dataset.retrieve_genes("GeneId")
+ species = this_dataset.group.species
+ if species:
+ species = species.lower()
+ trait_geneid = this_trait.geneid
+ return (trait_geneid, geneid_dict, species)
+
+
+def get_tissue_correlation_input(this_trait, trait_symbol_dict):
+ """Gets tissue expression values for the primary trait and target tissues values"""
+ primary_trait_tissue_vals_dict = correlation_functions.get_trait_symbol_and_tissue_values(
+ symbol_list=[this_trait.symbol])
+ if this_trait.symbol and this_trait.symbol.lower() in primary_trait_tissue_vals_dict:
+
+ primary_trait_tissue_values = primary_trait_tissue_vals_dict[this_trait.symbol.lower(
+ )]
+ corr_result_tissue_vals_dict = correlation_functions.get_trait_symbol_and_tissue_values(
+ symbol_list=list(trait_symbol_dict.values()))
+ primary_tissue_data = {
+ "this_id": this_trait.name,
+ "tissue_values": primary_trait_tissue_values
+
+ }
+ target_tissue_data = {
+ "trait_symbol_dict": trait_symbol_dict,
+ "symbol_tissue_vals_dict": corr_result_tissue_vals_dict
+ }
+ return (primary_tissue_data, target_tissue_data)
diff --git a/gn2/wqflask/correlation/exceptions.py b/gn2/wqflask/correlation/exceptions.py
new file mode 100644
index 00000000..f4e2b72b
--- /dev/null
+++ b/gn2/wqflask/correlation/exceptions.py
@@ -0,0 +1,16 @@
+"""Correlation-Specific Exceptions"""
+
+class WrongCorrelationType(Exception):
+ """Raised when a correlation is requested for incompatible datasets."""
+
+ def __init__(self, trait, target_dataset, corr_method):
+ corr_method = {
+ "lit": "Literature",
+ "tissue": "Tissue"
+ }[corr_method]
+ message = (
+ f"It is not possible to compute the '{corr_method}' correlations "
+ f"between trait '{trait.name}' and the data in the "
+ f"'{target_dataset.fullname}' dataset. "
+ "Please try again after selecting another type of correlation.")
+ super().__init__(message)
diff --git a/gn2/wqflask/correlation/pre_computes.py b/gn2/wqflask/correlation/pre_computes.py
new file mode 100644
index 00000000..4bd888ad
--- /dev/null
+++ b/gn2/wqflask/correlation/pre_computes.py
@@ -0,0 +1,178 @@
+import csv
+import json
+import os
+import hashlib
+import datetime
+
+import lmdb
+import pickle
+from pathlib import Path
+
+from gn2.base.data_set import query_table_timestamp
+from gn2.base.webqtlConfig import TEXTDIR
+from gn2.base.webqtlConfig import TMPDIR
+
+from json.decoder import JSONDecodeError
+
+def cache_trait_metadata(dataset_name, data):
+
+
+ try:
+ with lmdb.open(os.path.join(TMPDIR,f"metadata_{dataset_name}"),map_size=20971520) as env:
+ with env.begin(write=True) as txn:
+ data_bytes = pickle.dumps(data)
+ txn.put(f"{dataset_name}".encode(), data_bytes)
+ current_date = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+ txn.put(b"creation_date", current_date.encode())
+ return "success"
+
+ except lmdb.Error as error:
+ pass
+
+def read_trait_metadata(dataset_name):
+ try:
+ with lmdb.open(os.path.join(TMPDIR,f"metadata_{dataset_name}"),
+ readonly=True, lock=False) as env:
+ with env.begin() as txn:
+ db_name = txn.get(dataset_name.encode())
+ return (pickle.loads(db_name) if db_name else {})
+ except lmdb.Error as error:
+ return {}
+
+
+def fetch_all_cached_metadata(dataset_name):
+ """in a gvein dataset fetch all the traits metadata"""
+ file_name = generate_filename(dataset_name, suffix="metadata")
+
+ file_path = Path(TMPDIR, file_name)
+
+ try:
+ with open(file_path, "r+") as file_handler:
+ dataset_metadata = json.load(file_handler)
+
+ return (file_path, dataset_metadata)
+
+ except FileNotFoundError:
+ pass
+
+ except JSONDecodeError:
+ file_path.unlink()
+
+ file_path.touch(exist_ok=True)
+
+ return (file_path, {})
+
+
+def cache_new_traits_metadata(dataset_metadata: dict, new_traits_metadata, file_path: str):
+ """function to cache the new traits metadata"""
+
+ if (dataset_metadata == {} and new_traits_metadata == {}):
+ return
+
+ dataset_metadata.update(new_traits_metadata)
+
+ with open(file_path, "w+") as file_handler:
+ json.dump(dataset_metadata, file_handler)
+
+
+def generate_filename(*args, suffix="", file_ext="json"):
+ """given a list of args generate a unique filename"""
+
+ string_unicode = f"{*args,}".encode()
+ return f"{hashlib.md5(string_unicode).hexdigest()}_{suffix}.{file_ext}"
+
+
+
+
+def fetch_text_file(dataset_name, conn, text_dir=TMPDIR):
+ """fetch textfiles with strain vals if exists"""
+
+ def __file_scanner__(text_dir, target_file):
+ for file in os.listdir(text_dir):
+ if file.startswith(f"ProbeSetFreezeId_{target_file}_"):
+ return os.path.join(text_dir, file)
+
+ with conn.cursor() as cursor:
+ cursor.execute(
+ 'SELECT Id, FullName FROM ProbeSetFreeze WHERE Name = %s', (dataset_name,))
+ results = cursor.fetchone()
+ if results:
+ try:
+ # checks first for recently generated textfiles if not use gn1 datamatrix
+
+ return __file_scanner__(text_dir, results[0]) or __file_scanner__(TEXTDIR, results[0])
+
+ except Exception:
+ pass
+
+
+def read_text_file(sample_dict, file_path):
+
+ def __fetch_id_positions__(all_ids, target_ids):
+ _vals = []
+ _posit = [0] # alternative for parsing
+
+ for (idx, strain) in enumerate(all_ids, 1):
+ if strain in target_ids:
+ _vals.append(target_ids[strain])
+ _posit.append(idx)
+
+ return (_posit, _vals)
+
+ with open(file_path) as csv_file:
+ csv_reader = csv.reader(csv_file, delimiter=',')
+ _posit, sample_vals = __fetch_id_positions__(
+ next(csv_reader)[1:], sample_dict)
+ return (sample_vals, [[line[i] for i in _posit] for line in csv_reader])
+
+
+def write_db_to_textfile(db_name, conn, text_dir=TMPDIR):
+
+ def __sanitise_filename__(filename):
+ ttable = str.maketrans({" ": "_", "/": "_", "\\": "_"})
+ return str.translate(filename, ttable)
+
+ def __generate_file_name__(db_name):
+ # todo add expiry time and checker
+ with conn.cursor() as cursor:
+ cursor.execute(
+ 'SELECT Id, FullName FROM ProbeSetFreeze WHERE Name = %s', (db_name,))
+ results = cursor.fetchone()
+ if (results):
+ return __sanitise_filename__(
+ f"ProbeSetFreezeId_{results[0]}_{results[1]}")
+
+ def __parse_to_dict__(results):
+ ids = ["ID"]
+ data = {}
+ for (trait, strain, val) in results:
+ if strain not in ids:
+ ids.append(strain)
+ if trait in data:
+ data[trait].append(val)
+ else:
+ data[trait] = [trait, val]
+ return (data, ids)
+
+ def __write_to_file__(file_path, data, col_names):
+ with open(file_path, 'w+', encoding='UTF8') as file_handler:
+ writer = csv.writer(file_handler)
+ writer.writerow(col_names)
+ writer.writerows(data.values())
+
+ with conn.cursor() as cursor:
+ cursor.execute(
+ "SELECT ProbeSet.Name, Strain.Name, ProbeSetData.value "
+ "FROM Strain LEFT JOIN ProbeSetData "
+ "ON Strain.Id = ProbeSetData.StrainId "
+ "LEFT JOIN ProbeSetXRef ON ProbeSetData.Id = ProbeSetXRef.DataId "
+ "LEFT JOIN ProbeSet ON ProbeSetXRef.ProbeSetId = ProbeSet.Id "
+ "WHERE ProbeSetXRef.ProbeSetFreezeId IN "
+ "(SELECT Id FROM ProbeSetFreeze WHERE Name = %s) "
+ "ORDER BY Strain.Name",
+ (db_name,))
+ results = cursor.fetchall()
+ file_name = __generate_file_name__(db_name)
+ if (results and file_name):
+ __write_to_file__(os.path.join(text_dir, file_name),
+ *__parse_to_dict__(results))
diff --git a/gn2/wqflask/correlation/rust_correlation.py b/gn2/wqflask/correlation/rust_correlation.py
new file mode 100644
index 00000000..a0dcbcb4
--- /dev/null
+++ b/gn2/wqflask/correlation/rust_correlation.py
@@ -0,0 +1,408 @@
+"""module contains integration code for rust-gn3"""
+import json
+from functools import reduce
+
+from gn2.utility.tools import SQL_URI
+from gn2.utility.db_tools import mescape
+from gn2.utility.db_tools import create_in_clause
+from gn2.wqflask.correlation.correlation_functions\
+ import get_trait_symbol_and_tissue_values
+from gn2.wqflask.correlation.correlation_gn3_api import create_target_this_trait
+from gn2.wqflask.correlation.correlation_gn3_api import lit_for_trait_list
+from gn2.wqflask.correlation.correlation_gn3_api import do_lit_correlation
+from gn2.wqflask.correlation.pre_computes import fetch_text_file
+from gn2.wqflask.correlation.pre_computes import read_text_file
+from gn2.wqflask.correlation.pre_computes import write_db_to_textfile
+from gn2.wqflask.correlation.pre_computes import read_trait_metadata
+from gn2.wqflask.correlation.pre_computes import cache_trait_metadata
+from gn3.computations.correlations import compute_all_lit_correlation
+from gn3.computations.rust_correlation import run_correlation
+from gn3.computations.rust_correlation import get_sample_corr_data
+from gn3.computations.rust_correlation import parse_tissue_corr_data
+from gn3.db_utils import database_connection
+
+from gn2.wqflask.correlation.exceptions import WrongCorrelationType
+
+
+def query_probes_metadata(dataset, trait_list):
+ """query traits metadata in bulk for probeset"""
+
+ if not bool(trait_list) or dataset.type != "ProbeSet":
+ return []
+
+ with database_connection(SQL_URI) as conn:
+ with conn.cursor() as cursor:
+
+ query = """
+ SELECT ProbeSet.Name,ProbeSet.Chr,ProbeSet.Mb,
+ ProbeSet.Symbol,ProbeSetXRef.mean,
+ CONCAT_WS('; ', ProbeSet.description, ProbeSet.Probe_Target_Description) AS description,
+ ProbeSetXRef.additive,ProbeSetXRef.LRS,Geno.Chr, Geno.Mb
+ FROM ProbeSet INNER JOIN ProbeSetXRef
+ ON ProbeSet.Id=ProbeSetXRef.ProbeSetId
+ INNER JOIN Geno
+ ON ProbeSetXRef.Locus = Geno.Name
+ INNER JOIN Species
+ ON Geno.SpeciesId = Species.Id
+ WHERE ProbeSet.Name in ({}) AND
+ Species.Name = %s AND
+ ProbeSetXRef.ProbeSetFreezeId IN (
+ SELECT ProbeSetFreeze.Id
+ FROM ProbeSetFreeze WHERE ProbeSetFreeze.Name = %s)
+ """.format(", ".join(["%s"] * len(trait_list)))
+
+ cursor.execute(query,
+ (tuple(trait_list) +
+ (dataset.group.species,) + (dataset.name,))
+ )
+
+ return cursor.fetchall()
+
+
+def get_metadata(dataset, traits):
+ """Retrieve the metadata"""
+ def __location__(probe_chr, probe_mb):
+ if probe_mb:
+ return f"Chr{probe_chr}: {probe_mb:.6f}"
+ return f"Chr{probe_chr}: ???"
+ cached_metadata = read_trait_metadata(dataset.name)
+ to_fetch_metadata = list(
+ set(traits).difference(list(cached_metadata.keys())))
+ if to_fetch_metadata:
+ results = {**({trait_name: {
+ "name": trait_name,
+ "view": True,
+ "symbol": symbol,
+ "dataset": dataset.name,
+ "dataset_name": dataset.shortname,
+ "mean": mean,
+ "description": description,
+ "additive": additive,
+ "lrs_score": f"{lrs:3.1f}" if lrs else "",
+ "location": __location__(probe_chr, probe_mb),
+ "chr": probe_chr,
+ "mb": probe_mb,
+ "lrs_location": f'Chr{chr_score}: {mb:{".6f" if mb else ""}}',
+ "lrs_chr": chr_score,
+ "lrs_mb": mb
+
+ } for trait_name, probe_chr, probe_mb, symbol, mean, description,
+ additive, lrs, chr_score, mb
+ in query_probes_metadata(dataset, to_fetch_metadata)}), **cached_metadata}
+ cache_trait_metadata(dataset.name, results)
+ return results
+ return cached_metadata
+
+
+def chunk_dataset(dataset, steps, name):
+
+ results = []
+
+ query = """
+ SELECT ProbeSetXRef.DataId,ProbeSet.Name
+ FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze
+ WHERE ProbeSetFreeze.Name = '{}' AND
+ ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND
+ ProbeSetXRef.ProbeSetId = ProbeSet.Id
+ """.format(name)
+
+ with database_connection(SQL_URI) as conn:
+ with conn.cursor() as curr:
+ curr.execute(query)
+ traits_name_dict = dict(curr.fetchall())
+
+ for i in range(0, len(dataset), steps):
+ matrix = list(dataset[i:i + steps])
+ results.append([traits_name_dict[matrix[0][0]]] + [str(value)
+ for (trait_name, strain, value) in matrix])
+ return results
+
+
+def compute_top_n_sample(start_vars, dataset, trait_list):
+ """check if dataset is of type probeset"""
+
+ if dataset.type.lower() != "probeset":
+ return {}
+
+ def __fetch_sample_ids__(samples_vals, samples_group):
+ sample_data = get_sample_corr_data(
+ sample_type=samples_group,
+ sample_data=json.loads(samples_vals),
+ dataset_samples=dataset.group.all_samples_ordered())
+
+ with database_connection(SQL_URI) as conn:
+ with conn.cursor() as curr:
+ curr.execute(
+ """
+ SELECT Strain.Name, Strain.Id FROM Strain, Species
+ WHERE Strain.Name IN {}
+ and Strain.SpeciesId=Species.Id
+ and Species.name = '{}'
+ """.format(create_in_clause(list(sample_data.keys())),
+ *mescape(dataset.group.species)))
+ return (sample_data, dict(curr.fetchall()))
+
+ (sample_data, sample_ids) = __fetch_sample_ids__(
+ start_vars["sample_vals"], start_vars["corr_samples_group"])
+
+ if len(trait_list) == 0:
+ return {}
+
+ with database_connection(SQL_URI) as conn:
+ with conn.cursor() as curr:
+ # fetching strain data in bulk
+ query = (
+ "SELECT * from ProbeSetData "
+ f"WHERE StrainID IN ({', '.join(['%s'] * len(sample_ids))}) "
+ "AND Id IN ("
+ " SELECT ProbeSetXRef.DataId "
+ " FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) "
+ " WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id "
+ " AND ProbeSetFreeze.Name = %s "
+ " AND ProbeSet.Name "
+ f" IN ({', '.join(['%s'] * len(trait_list))}) "
+ " AND ProbeSet.Id = ProbeSetXRef.ProbeSetId"
+ ")")
+ curr.execute(
+ query,
+ tuple(sample_ids.values()) + (dataset.name,) + tuple(trait_list))
+
+ corr_data = chunk_dataset(
+ list(curr.fetchall()), len(sample_ids.values()), dataset.name)
+
+ return run_correlation(
+ corr_data, list(sample_data.values()), "pearson", ",")
+
+
+def compute_top_n_lit(corr_results, target_dataset, this_trait) -> dict:
+ if not __datasets_compatible_p__(this_trait.dataset, target_dataset, "lit"):
+ return {}
+
+ (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
+ this_trait, target_dataset)
+
+ geneid_dict = {trait_name: geneid for (trait_name, geneid)
+ in geneid_dict.items() if
+ corr_results.get(trait_name)}
+ with database_connection(SQL_URI) as conn:
+ return reduce(
+ lambda acc, corr: {**acc, **corr},
+ compute_all_lit_correlation(
+ conn=conn, trait_lists=list(geneid_dict.items()),
+ species=species, gene_id=this_trait_geneid),
+ {})
+
+ return {}
+
+
+def compute_top_n_tissue(target_dataset, this_trait, traits, method):
+ # refactor lots of rpt
+ if not __datasets_compatible_p__(this_trait.dataset, target_dataset, "tissue"):
+ return {}
+
+ trait_symbol_dict = dict({
+ trait_name: symbol
+ for (trait_name, symbol)
+ in target_dataset.retrieve_genes("Symbol").items()
+ if traits.get(trait_name)})
+
+ corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
+ symbol_list=list(trait_symbol_dict.values()))
+
+ data = parse_tissue_corr_data(symbol_name=this_trait.symbol,
+ symbol_dict=get_trait_symbol_and_tissue_values(
+ symbol_list=[this_trait.symbol]),
+ dataset_symbols=trait_symbol_dict,
+ dataset_vals=corr_result_tissue_vals_dict)
+
+ if data and data[0]:
+ return run_correlation(
+ data[1], data[0], method, ",", "tissue")
+
+ return {}
+
+
+def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
+ """code to merge diff corr into individual dicts
+ a"""
+
+ def __merge__(trait_name, trait_corrs):
+ return {
+ trait_name: {
+ **trait_corrs,
+ **dict_b.get(trait_name, {}),
+ **dict_c.get(trait_name, {})
+ }
+ }
+ return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
+
+
+def __compute_sample_corr__(
+ start_vars: dict, corr_type: str, method: str, n_top: int,
+ target_trait_info: tuple):
+ """Compute the sample correlations"""
+ (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+
+ if this_dataset.group.f1list != None:
+ this_dataset.group.samplelist += this_dataset.group.f1list
+
+ if this_dataset.group.parlist != None:
+ this_dataset.group.samplelist += this_dataset.group.parlist
+
+ sample_data = get_sample_corr_data(
+ sample_type=start_vars["corr_samples_group"],
+ sample_data=json.loads(start_vars["sample_vals"]),
+ dataset_samples=this_dataset.group.all_samples_ordered())
+
+ if not bool(sample_data):
+ return {}
+
+ if target_dataset.type == "ProbeSet" and start_vars.get("use_cache") == "true":
+ with database_connection(SQL_URI) as conn:
+ file_path = fetch_text_file(target_dataset.name, conn)
+ if file_path:
+ (sample_vals, target_data) = read_text_file(
+ sample_data, file_path)
+
+ return run_correlation(target_data, sample_vals,
+ method, ",", corr_type, n_top)
+
+ write_db_to_textfile(target_dataset.name, conn)
+ file_path = fetch_text_file(target_dataset.name, conn)
+ if file_path:
+ (sample_vals, target_data) = read_text_file(
+ sample_data, file_path)
+
+ return run_correlation(target_data, sample_vals,
+ method, ",", corr_type, n_top)
+
+ target_dataset.get_trait_data(list(sample_data.keys()))
+
+ def __merge_key_and_values__(rows, current):
+ wo_nones = [value for value in current[1]]
+ if len(wo_nones) > 0:
+ return rows + [[current[0]] + wo_nones]
+ return rows
+
+ target_data = reduce(
+ __merge_key_and_values__, target_dataset.trait_data.items(), [])
+
+ if len(target_data) == 0:
+ return {}
+
+ return run_correlation(
+ target_data, list(sample_data.values()), method, ",", corr_type,
+ n_top)
+
+
+def __datasets_compatible_p__(trait_dataset, target_dataset, corr_method):
+ return not (
+ corr_method in ("tissue", "Tissue r", "Literature r", "lit")
+ and (trait_dataset.type == "ProbeSet" and
+ target_dataset.type in ("Publish", "Geno")))
+
+
+def __compute_tissue_corr__(
+ start_vars: dict, corr_type: str, method: str, n_top: int,
+ target_trait_info: tuple):
+ """Compute the tissue correlations"""
+ (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+ if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type):
+ raise WrongCorrelationType(this_trait, target_dataset, corr_type)
+
+ trait_symbol_dict = target_dataset.retrieve_genes("Symbol")
+ corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
+ symbol_list=list(trait_symbol_dict.values()))
+
+ data = parse_tissue_corr_data(
+ symbol_name=this_trait.symbol,
+ symbol_dict=get_trait_symbol_and_tissue_values(
+ symbol_list=[this_trait.symbol]),
+ dataset_symbols=trait_symbol_dict,
+ dataset_vals=corr_result_tissue_vals_dict)
+
+ if data:
+ return run_correlation(data[1], data[0], method, ",", "tissue")
+ return {}
+
+
+def __compute_lit_corr__(
+ start_vars: dict, corr_type: str, method: str, n_top: int,
+ target_trait_info: tuple):
+ """Compute the literature correlations"""
+ (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+ if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type):
+ raise WrongCorrelationType(this_trait, target_dataset, corr_type)
+
+ target_dataset_type = target_dataset.type
+ this_dataset_type = this_dataset.type
+ (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
+ this_trait, target_dataset)
+
+ with database_connection(SQL_URI) as conn:
+ return reduce(
+ lambda acc, lit: {**acc, **lit},
+ compute_all_lit_correlation(
+ conn=conn, trait_lists=list(geneid_dict.items()),
+ species=species, gene_id=this_trait_geneid)[:n_top],
+ {})
+ return {}
+
+
+def compute_correlation_rust(
+ start_vars: dict, corr_type: str, method: str = "pearson",
+ n_top: int = 500, should_compute_all: bool = False):
+ """function to compute correlation"""
+ target_trait_info = create_target_this_trait(start_vars)
+ (this_dataset, this_trait, target_dataset, sample_data) = (
+ target_trait_info)
+ if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type):
+ raise WrongCorrelationType(this_trait, target_dataset, corr_type)
+
+ # Replace this with `match ...` once we hit Python 3.10
+ corr_type_fns = {
+ "sample": __compute_sample_corr__,
+ "tissue": __compute_tissue_corr__,
+ "lit": __compute_lit_corr__
+ }
+
+ results = corr_type_fns[corr_type](
+ start_vars, corr_type, method, n_top, target_trait_info)
+
+ # END: Replace this with `match ...` once we hit Python 3.10
+
+ top_a = top_b = {}
+
+ if should_compute_all:
+
+ if corr_type == "sample":
+ if this_dataset.type == "ProbeSet":
+ top_a = compute_top_n_tissue(
+ target_dataset, this_trait, results, method)
+
+ top_b = compute_top_n_lit(results, target_dataset, this_trait)
+ else:
+ pass
+
+ elif corr_type == "lit":
+
+ # currently fails for lit
+
+ top_a = compute_top_n_sample(
+ start_vars, target_dataset, list(results.keys()))
+ top_b = compute_top_n_tissue(
+ target_dataset, this_trait, results, method)
+
+ else:
+
+ top_a = compute_top_n_sample(
+ start_vars, target_dataset, list(results.keys()))
+
+ return {
+ "correlation_results": merge_results(
+ results, top_a, top_b),
+ "this_trait": this_trait.name,
+ "target_dataset": start_vars['corr_dataset'],
+ "traits_metadata": get_metadata(target_dataset, list(results.keys())),
+ "return_results": n_top
+ }
diff --git a/gn2/wqflask/correlation/show_corr_results.py b/gn2/wqflask/correlation/show_corr_results.py
new file mode 100644
index 00000000..c8625222
--- /dev/null
+++ b/gn2/wqflask/correlation/show_corr_results.py
@@ -0,0 +1,406 @@
+# Copyright (C) University of Tennessee Health Science Center, Memphis, TN.
+#
+# This program is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Affero General Public License
+# as published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# See the GNU Affero General Public License for more details.
+#
+# This program is available from Source Forge: at GeneNetwork Project
+# (sourceforge.net/projects/genenetwork/).
+#
+# Contact Dr. Robert W. Williams at rwilliams@uthsc.edu
+#
+#
+# This module is used by GeneNetwork project (www.genenetwork.org)
+
+import hashlib
+import html
+import json
+
+from gn2.base.trait import create_trait, jsonable
+from gn2.base.data_set import create_dataset
+
+from gn2.utility import hmac
+from gn2.utility.type_checking import get_float, get_int, get_string
+from gn2.utility.redis_tools import get_redis_conn
+Redis = get_redis_conn()
+
+def set_template_vars(start_vars, correlation_data):
+ corr_type = start_vars['corr_type']
+ corr_method = start_vars['corr_sample_method']
+
+ if start_vars['dataset'] == "Temp":
+ this_dataset_ob = create_dataset(
+ dataset_name="Temp", dataset_type="Temp", group_name=start_vars['group'])
+ else:
+ this_dataset_ob = create_dataset(dataset_name=start_vars['dataset'])
+ this_trait = create_trait(dataset=this_dataset_ob,
+ name=start_vars['trait_id'])
+
+ # Store trait sample data in Redis, so additive effect scatterplots can include edited values
+ dhash = hashlib.md5()
+ dhash.update(start_vars['sample_vals'].encode())
+ samples_hash = dhash.hexdigest()
+ Redis.set(samples_hash, start_vars['sample_vals'], ex=7*24*60*60)
+ correlation_data['dataid'] = samples_hash
+
+ correlation_data['this_trait'] = jsonable(this_trait, this_dataset_ob)
+ correlation_data['this_dataset'] = this_dataset_ob.as_monadic_dict().data
+
+ target_dataset_ob = create_dataset(correlation_data['target_dataset'])
+ correlation_data['target_dataset'] = target_dataset_ob.as_monadic_dict().data
+ correlation_data['table_json'] = correlation_json_for_table(
+ start_vars,
+ correlation_data,
+ target_dataset_ob)
+
+ if target_dataset_ob.type == "ProbeSet":
+ filter_cols = [7, 6]
+ elif target_dataset_ob.type == "Publish":
+ filter_cols = [8, 5]
+ else:
+ filter_cols = [4, 0]
+
+ correlation_data['corr_method'] = corr_method
+ correlation_data['filter_cols'] = filter_cols
+ correlation_data['header_fields'] = get_header_fields(
+ target_dataset_ob.type, correlation_data['corr_method'])
+ correlation_data['formatted_corr_type'] = get_formatted_corr_type(
+ corr_type, corr_method)
+
+ return correlation_data
+
+
+def apply_filters(trait, target_trait, target_dataset, **filters):
+ def __p_val_filter__(p_lower, p_upper):
+
+ return not (p_lower <= float(trait.get("corr_coefficient",0.0)) <= p_upper)
+
+ def __min_filter__(min_expr):
+ if (target_dataset['type'] in ["ProbeSet", "Publish"] and target_trait['mean']):
+ return (min_expr != None) and (float(target_trait['mean']) < min_expr)
+
+ return False
+
+ def __location_filter__(location_type, location_chr,
+ min_location_mb, max_location_mb):
+
+ if target_dataset["type"] in ["ProbeSet", "Geno"] and location_type == "gene":
+ return (
+ ((location_chr!=None) and (target_trait["chr"]!=location_chr))
+ or
+ ((min_location_mb!= None) and (
+ float(target_trait['mb']) < min_location_mb)
+ )
+
+ or
+ ((max_location_mb != None) and
+ (float(target_trait['mb']) > float(max_location_mb)
+ ))
+
+ )
+ elif target_dataset["type"] in ["ProbeSet", "Publish"]:
+
+ return ((location_chr!=None) and (target_trait["lrs_chr"] != location_chr)
+ or
+ ((min_location_mb != None) and (
+ float(target_trait['lrs_mb']) < float(min_location_mb)))
+ or
+ ((max_location_mb != None) and (
+ float(target_trait['lrs_mb']) > float(max_location_mb))
+ )
+
+ )
+
+ return True
+
+ if not target_trait:
+ return True
+ else:
+ # check if one of the condition is not met i.e One is True
+ return (__p_val_filter__(
+ filters.get("p_range_lower"),
+ filters.get("p_range_upper")
+ )
+ or
+ (
+ __min_filter__(
+ filters.get("min_expr")
+ )
+ )
+ or
+ __location_filter__(
+ filters.get("location_type"),
+ filters.get("location_chr"),
+ filters.get("min_location_mb"),
+ filters.get("max_location_mb")
+
+
+ )
+ )
+
+
+def get_user_filters(start_vars):
+ (min_expr, p_min, p_max) = (
+ get_float(start_vars, 'min_expr'),
+ get_float(start_vars, 'p_range_lower', -1.0),
+ get_float(start_vars, 'p_range_upper', 1.0)
+ )
+
+ if all(keys in start_vars for keys in ["loc_chr",
+ "min_loc_mb",
+ "max_location_mb"]):
+
+ location_chr = get_string(start_vars, "loc_chr")
+ min_location_mb = get_int(start_vars, "min_loc_mb")
+ max_location_mb = get_int(start_vars, "max_loc_mb")
+
+ else:
+ location_chr = min_location_mb = max_location_mb = None
+
+ return {
+
+ "min_expr": min_expr,
+ "p_range_lower": p_min,
+ "p_range_upper": p_max,
+ "location_chr": location_chr,
+ "location_type": start_vars['location_type'],
+ "min_location_mb": min_location_mb,
+ "max_location_mb": max_location_mb
+
+ }
+
+
+def generate_table_metadata(all_traits, dataset_metadata, dataset_obj):
+
+ def __fetch_trait_data__(trait, dataset_obj):
+ target_trait_ob = create_trait(dataset=dataset_obj,
+ name=trait,
+ get_qtl_info=True)
+ return jsonable(target_trait_ob, dataset_obj)
+
+ metadata = [__fetch_trait_data__(trait, dataset_obj) for
+ trait in (all_traits)]
+
+ return (dataset_metadata | ({str(trait["name"]): trait for trait in metadata}))
+
+
+def populate_table(dataset_metadata, target_dataset, this_dataset, corr_results, filters):
+
+ def __populate_trait__(idx, trait):
+
+ trait_name = list(trait.keys())[0]
+ target_trait = dataset_metadata.get(trait_name)
+ trait = trait[trait_name]
+ if not apply_filters(trait, target_trait, target_dataset, **filters):
+ results_dict = {}
+ results_dict['index'] = idx + 1 #
+ results_dict['trait_id'] = target_trait['name']
+ results_dict['dataset'] = target_dataset['name']
+ results_dict['hmac'] = hmac.data_hmac(
+ '{}:{}'.format(target_trait['name'], target_dataset['name']))
+ results_dict['sample_r'] = f"{float(trait.get('corr_coefficient',0.0)):.3f}"
+ results_dict['num_overlap'] = trait.get('num_overlap', 0)
+ results_dict['sample_p'] = f"{float(trait.get('p_value',0)):.2e}"
+ if target_dataset['type'] == "ProbeSet":
+ results_dict['symbol'] = target_trait['symbol']
+ results_dict['description'] = "N/A"
+ results_dict['location'] = target_trait['location']
+ results_dict['mean'] = "N/A"
+ results_dict['additive'] = "N/A"
+ if target_trait['description'].strip():
+ results_dict['description'] = html.escape(
+ target_trait['description'].strip(), quote=True)
+ if target_trait['mean']:
+ results_dict['mean'] = f"{float(target_trait['mean']):.3f}"
+ try:
+ results_dict['lod_score'] = f"{float(target_trait['lrs_score']) / 4.61:.1f}"
+ except:
+ results_dict['lod_score'] = "N/A"
+ results_dict['lrs_location'] = target_trait['lrs_location']
+ if target_trait['additive']:
+ results_dict['additive'] = f"{float(target_trait['additive']):.3f}"
+ results_dict['lit_corr'] = "--"
+ results_dict['tissue_corr'] = "--"
+ results_dict['tissue_pvalue'] = "--"
+ if this_dataset['type'] == "ProbeSet":
+ if 'lit_corr' in trait:
+ results_dict['lit_corr'] = (
+ f"{float(trait['lit_corr']):.3f}"
+ if trait["lit_corr"] else "--")
+ if 'tissue_corr' in trait:
+ results_dict['tissue_corr'] = f"{float(trait['tissue_corr']):.3f}"
+ results_dict['tissue_pvalue'] = f"{float(trait['tissue_p_val']):.3e}"
+ elif target_dataset['type'] == "Publish":
+ results_dict['abbreviation_display'] = "N/A"
+ results_dict['description'] = "N/A"
+ results_dict['mean'] = "N/A"
+ results_dict['authors_display'] = "N/A"
+ results_dict['additive'] = "N/A"
+ results_dict['pubmed_link'] = "N/A"
+ results_dict['pubmed_text'] = target_trait["pubmed_text"]
+
+ if target_trait["abbreviation"]:
+ results_dict['abbreviation'] = target_trait['abbreviation']
+
+ if target_trait["description"].strip():
+ results_dict['description'] = html.escape(
+ target_trait['description'].strip(), quote=True)
+
+ if target_trait["mean"] != "N/A":
+ results_dict['mean'] = f"{float(target_trait['mean']):.3f}"
+
+ results_dict['lrs_location'] = target_trait['lrs_location']
+
+ if target_trait["authors"]:
+ authors_list = target_trait['authors'].split(',')
+ results_dict['authors_display'] = ", ".join(
+ authors_list[:6]) + ", et al." if len(authors_list) > 6 else target_trait['authors']
+
+ if "pubmed_id" in target_trait:
+ results_dict['pubmed_link'] = target_trait['pubmed_link']
+ results_dict['pubmed_text'] = target_trait['pubmed_text']
+ try:
+ results_dict["lod_score"] = f"{float(target_trait['lrs_score']) / 4.61:.1f}"
+ except ValueError:
+ results_dict['lod_score'] = "N/A"
+ else:
+ results_dict['location'] = target_trait['location']
+
+ return results_dict
+
+ return [__populate_trait__(idx, trait)
+ for (idx, trait) in enumerate(corr_results)]
+
+
+def correlation_json_for_table(start_vars, correlation_data, target_dataset_ob):
+ """Return JSON data for use with the DataTable in the correlation result page
+
+ Keyword arguments:
+ correlation_data -- Correlation results
+ this_trait -- Trait being correlated against a dataset, as a dict
+ this_dataset -- Dataset of this_trait, as a monadic dict
+ target_dataset_ob - Target dataset, as a Dataset ob
+ """
+ this_dataset = correlation_data['this_dataset']
+
+ traits = set()
+ for trait in correlation_data["correlation_results"]:
+ traits.add(list(trait)[0])
+
+ dataset_metadata = generate_table_metadata(traits,
+ correlation_data["traits_metadata"],
+ target_dataset_ob)
+ return json.dumps([result for result in (
+ populate_table(dataset_metadata=dataset_metadata,
+ target_dataset=target_dataset_ob.as_monadic_dict().data,
+ this_dataset=correlation_data['this_dataset'],
+ corr_results=correlation_data['correlation_results'],
+ filters=get_user_filters(start_vars))) if result])
+
+
+def get_formatted_corr_type(corr_type, corr_method):
+ formatted_corr_type = ""
+ if corr_type == "lit":
+ formatted_corr_type += "Literature Correlation "
+ elif corr_type == "tissue":
+ formatted_corr_type += "Tissue Correlation "
+ elif corr_type == "sample":
+ formatted_corr_type += "Genetic Correlation "
+
+ if corr_method == "pearson":
+ formatted_corr_type += "(Pearson's r)"
+ elif corr_method == "spearman":
+ formatted_corr_type += "(Spearman's rho)"
+ elif corr_method == "bicor":
+ formatted_corr_type += "(Biweight r)"
+
+ return formatted_corr_type
+
+
+def get_header_fields(data_type, corr_method):
+ if data_type == "ProbeSet":
+ if corr_method == "spearman":
+ header_fields = ['Index',
+ 'Record',
+ 'Symbol',
+ 'Description',
+ 'Location',
+ 'Mean',
+ 'Sample rho',
+ 'N',
+ 'Sample p(rho)',
+ 'Lit rho',
+ 'Tissue rho',
+ 'Tissue p(rho)',
+ 'Max LRS',
+ 'Max LRS Location',
+ 'Additive Effect']
+ else:
+ header_fields = ['Index',
+ 'Record',
+ 'Symbol',
+ 'Description',
+ 'Location',
+ 'Mean',
+ 'Sample r',
+ 'N',
+ 'Sample p(r)',
+ 'Lit r',
+ 'Tissue r',
+ 'Tissue p(r)',
+ 'Max LRS',
+ 'Max LRS Location',
+ 'Additive Effect']
+ elif data_type == "Publish":
+ if corr_method == "spearman":
+ header_fields = ['Index',
+ 'Record',
+ 'Abbreviation',
+ 'Description',
+ 'Mean',
+ 'Authors',
+ 'Year',
+ 'Sample rho',
+ 'N',
+ 'Sample p(rho)',
+ 'Max LRS',
+ 'Max LRS Location',
+ 'Additive Effect']
+ else:
+ header_fields = ['Index',
+ 'Record',
+ 'Abbreviation',
+ 'Description',
+ 'Mean',
+ 'Authors',
+ 'Year',
+ 'Sample r',
+ 'N',
+ 'Sample p(r)',
+ 'Max LRS',
+ 'Max LRS Location',
+ 'Additive Effect']
+
+ else:
+ if corr_method == "spearman":
+ header_fields = ['Index',
+ 'ID',
+ 'Location',
+ 'Sample rho',
+ 'N',
+ 'Sample p(rho)']
+ else:
+ header_fields = ['Index',
+ 'ID',
+ 'Location',
+ 'Sample r',
+ 'N',
+ 'Sample p(r)']
+
+ return header_fields