about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- wqflask/wqflask/correlation/rust_correlation.py | 135
-rw-r--r-- wqflask/wqflask/correlation/show_corr_results.py | 6
-rw-r--r-- wqflask/wqflask/views.py | 2
3 files changed, 125 insertions, 18 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 79d08a59..5c22efbf 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -1,7 +1,10 @@
"""module contains integration code for rust-gn3"""
import json
from functools import reduce
-from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values
+from utility.db_tools import mescape
+from utility.db_tools import create_in_clause
+from wqflask.correlation.correlation_functions\
+ import get_trait_symbol_and_tissue_values
from wqflask.correlation.correlation_gn3_api import create_target_this_trait
from wqflask.correlation.correlation_gn3_api import lit_for_trait_list
from wqflask.correlation.correlation_gn3_api import do_lit_correlation
@@ -12,11 +15,92 @@ from gn3.computations.rust_correlation import parse_tissue_corr_data
from gn3.db_utils import database_connector
+def chunk_dataset(dataset, steps, name):
+    """Turn bulk-fetched rows into one comma-separated line per trait.
+
+    Args:
+        dataset: list of 3-item rows. The first item of a row is looked
+            up as a ProbeSetXRef.DataId and the third item is the value
+            written to the output (the caller passes the rows of a
+            `SELECT * from ProbeSetData` query).
+        steps: number of consecutive rows that are grouped into one chunk.
+        name: ProbeSetFreeze.Name of the dataset the rows belong to.
+
+    Returns:
+        A list of strings shaped like "<trait name>,<value>,<value>,...",
+        one string per chunk of `steps` rows.
+
+    Raises:
+        KeyError: if the first row of a chunk carries a data id that is
+            not linked to the dataset `name`.
+    """
+    if not dataset or steps <= 0:
+        # Nothing to chunk; this also keeps range() from raising on a
+        # zero step and saves a pointless database round trip.
+        return []
+
+    # The dataset name is escaped before being interpolated, the same way
+    # the species name is handled in compute_top_n_sample.
+    query = """
+        SELECT ProbeSetXRef.DataId,ProbeSet.Name
+        FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze
+        WHERE ProbeSetFreeze.Name = '{}' AND
+        ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND
+        ProbeSetXRef.ProbeSetId = ProbeSet.Id
+    """.format(*mescape(name))
+
+    with database_connector() as conn:
+        with conn.cursor() as curr:
+            curr.execute(query)
+            # Maps ProbeSetXRef.DataId -> ProbeSet.Name.
+            traits_name_dict = dict(curr.fetchall())
+
+    # NOTE(review): fixed-size slicing presumably relies on every trait
+    # having exactly `steps` rows stored next to each other -- confirm
+    # against the query that produces `dataset`.
+    results = []
+    for start in range(0, len(dataset), steps):
+        chunk = dataset[start:start + steps]
+        # The data id of the first row identifies the trait of the chunk.
+        trait_name = traits_name_dict[chunk[0][0]]
+        values = [str(value) for (_data_id, _strain_id, value) in chunk]
+        results.append(",".join([trait_name] + values))
+
+    return results
+
+
+def compute_top_n_sample(start_vars, dataset, trait_list):
+    """Run a pearson sample correlation for the traits in `trait_list`.
+
+    Only datasets whose type is "probeset" (case-insensitive) are handled.
+
+    Args:
+        start_vars: request variables; "sample_vals" (a JSON string) and
+            "corr_samples_group" are read from it.
+        dataset: dataset object; its `type`, `name`, `group.species` and
+            `group.all_samples_ordered()` are used.
+        trait_list: names of the ProbeSet traits to fetch data for.
+
+    Returns:
+        The result of `run_correlation` for the fetched data, or an empty
+        dict when the dataset is not a probeset dataset or when there are
+        no matching samples or no traits to look up.
+    """
+    if dataset.type.lower() != "probeset":
+        return {}
+
+    def __fetch_sample_ids__(samples_vals, samples_group):
+        """Return the sample data and a Strain.Name -> Strain.Id mapping."""
+        all_samples = json.loads(samples_vals)
+        sample_data = get_sample_corr_data(
+            sample_type=samples_group, all_samples=all_samples,
+            dataset_samples=dataset.group.all_samples_ordered())
+
+        with database_connector() as conn:
+            with conn.cursor() as curr:
+                curr.execute(
+                    """
+                    SELECT Strain.Name, Strain.Id FROM Strain, Species
+                    WHERE Strain.Name IN {}
+                    and Strain.SpeciesId=Species.Id
+                    and Species.name = '{}'
+                    """.format(create_in_clause(list(sample_data.keys())),
+                               *mescape(dataset.group.species)))
+                return (sample_data, dict(curr.fetchall()))
+
+    (sample_data, sample_ids) = __fetch_sample_ids__(
+        start_vars["sample_vals"], start_vars["corr_samples_group"])
+
+    if not sample_ids or not trait_list:
+        # An empty IN clause is not valid SQL; there is nothing to
+        # correlate in this case anyway.
+        return {}
+
+    with database_connector() as conn:
+        with conn.cursor() as curr:
+            # fetching strain data in bulk; the dataset name is escaped
+            # like the species name above before being interpolated
+            curr.execute(
+                """
+                SELECT * from ProbeSetData
+                where StrainID in {}
+                and id in (SELECT ProbeSetXRef.DataId
+                FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze)
+                WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id
+                and ProbeSetFreeze.Name = '{}'
+                and ProbeSet.Name in {}
+                and ProbeSet.Id = ProbeSetXRef.ProbeSetId)
+                """.format(
+                    create_in_clause(list(sample_ids.values())),
+                    *mescape(dataset.name),
+                    create_in_clause(trait_list)))
+            rows = list(curr.fetchall())
+
+    # chunk_dataset opens its own connection, so it is called only after
+    # the cursor above has been released.
+    corr_data = chunk_dataset(rows, len(sample_ids), dataset.name)
+
+    return run_correlation(
+        corr_data, list(sample_data.values()), "pearson", ",")
+
+
def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict:
(this_trait_geneid, geneid_dict, species) = do_lit_correlation(
this_trait, this_dataset)
- geneid_dict = {trait_name: geneid for (trait_name, geneid) in geneid_dict.items() if
+ geneid_dict = {trait_name: geneid for (trait_name, geneid)
+ in geneid_dict.items() if
corr_results.get(trait_name)}
with database_connector() as conn:
return reduce(
@@ -69,6 +153,7 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
}
return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
+
def __compute_sample_corr__(
start_vars: dict, corr_type: str, method: str, n_top: int,
target_trait_info: tuple):
@@ -86,11 +171,11 @@ def __compute_sample_corr__(
r = ",".join(lts)
target_data.append(r)
-
return run_correlation(
target_data, list(sample_data.values()), method, ",", corr_type,
n_top)
+
def __compute_tissue_corr__(
start_vars: dict, corr_type: str, method: str, n_top: int,
target_trait_info: tuple):
@@ -111,6 +196,7 @@ def __compute_tissue_corr__(
return run_correlation(data[1], data[0], method, ",", "tissue")
return {}
+
def __compute_lit_corr__(
start_vars: dict, corr_type: str, method: str, n_top: int,
target_trait_info: tuple):
@@ -130,15 +216,16 @@ def __compute_lit_corr__(
{})
return {}
+
def compute_correlation_rust(
start_vars: dict, corr_type: str, method: str = "pearson",
- n_top: int = 500, compute_all: bool = False):
+ n_top: int = 500, should_compute_all: bool = False):
"""function to compute correlation"""
target_trait_info = create_target_this_trait(start_vars)
(this_dataset, this_trait, target_dataset, sample_data) = (
target_trait_info)
- ## Replace this with `match ...` once we hit Python 3.10
+ # Replace this with `match ...` once we hit Python 3.10
corr_type_fns = {
"sample": __compute_sample_corr__,
"tissue": __compute_tissue_corr__,
@@ -146,19 +233,39 @@ def compute_correlation_rust(
}
results = corr_type_fns[corr_type](
start_vars, corr_type, method, n_top, target_trait_info)
- ## END: Replace this with `match ...` once we hit Python 3.10
- top_tissue_results = {}
- top_lit_results = {}
- if compute_all:
- # example compute of compute both correlation
- top_tissue_results = compute_top_n_tissue(
- this_dataset,this_trait,results,method)
- top_lit_results = compute_top_n_lit(results,this_dataset,this_trait)
+ # END: Replace this with `match ...` once we hit Python 3.10
+
+ top_a = top_b = {}
+
+ if should_compute_all:
+
+ if corr_type == "sample":
+
+ top_a = compute_top_n_tissue(
+ this_dataset, this_trait, results, method)
+
+ top_b = compute_top_n_lit(results, this_dataset, this_trait)
+
+ elif corr_type == "lit":
+
+ # currently fails for lit
+
+ top_a = compute_top_n_sample(
+ start_vars, target_dataset, list(results.keys()))
+ top_b = compute_top_n_tissue(
+ this_dataset, this_trait, results, method)
+
+ else:
+
+ top_a = compute_top_n_sample(
+ start_vars, target_dataset, list(results.keys()))
+
+ top_b = compute_top_n_lit(results, this_dataset, this_trait)
return {
"correlation_results": merge_results(
- results, top_tissue_results, top_lit_results),
+ results, top_a, top_b),
"this_trait": this_trait.name,
"target_dataset": start_vars['corr_dataset'],
"return_results": n_top
diff --git a/wqflask/wqflask/correlation/show_corr_results.py b/wqflask/wqflask/correlation/show_corr_results.py
index 1c391386..f5fdd9b3 100644
--- a/wqflask/wqflask/correlation/show_corr_results.py
+++ b/wqflask/wqflask/correlation/show_corr_results.py
@@ -121,9 +121,9 @@ def correlation_json_for_table(correlation_data, this_trait, this_dataset, targe
results_dict['dataset'] = target_dataset['name']
results_dict['hmac'] = hmac.data_hmac(
'{}:{}'.format(target_trait['name'], target_dataset['name']))
- results_dict['sample_r'] = f"{float(trait['corr_coefficient']):.3f}"
- results_dict['num_overlap'] = trait['num_overlap']
- results_dict['sample_p'] = f"{float(trait['p_value']):.3e}"
+ results_dict['sample_r'] = f"{float(trait.get('corr_coefficient',0.0)):.3f}"
+ results_dict['num_overlap'] = trait.get('num_overlap',0)
+ results_dict['sample_p'] = f"{float(trait.get('p_value',0)):.3e}"
if target_dataset['type'] == "ProbeSet":
results_dict['symbol'] = target_trait['symbol']
results_dict['description'] = "N/A"
diff --git a/wqflask/wqflask/views.py b/wqflask/wqflask/views.py
index 2e13451d..e054cd49 100644
--- a/wqflask/wqflask/views.py
+++ b/wqflask/wqflask/views.py
@@ -876,7 +876,7 @@ def test_corr_compute_page():
correlation_results = compute_correlation_rust(start_vars,
start_vars["corr_type"],
start_vars['corr_sample_method'],
- int(start_vars.get("corr_return_results", 500)))
+ int(start_vars.get("corr_return_results", 500)),True)
correlation_results = set_template_vars(request.form, correlation_results)