aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py124
1 files changed, 119 insertions, 5 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 3628f549..94720f54 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -1,6 +1,9 @@
"""module contains integration code for rust-gn3"""
import json
from functools import reduce
+from flask import g
+from utility.db_tools import mescape
+from utility.db_tools import create_in_clause
from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values
from wqflask.correlation.correlation_gn3_api import create_target_this_trait
from wqflask.correlation.correlation_gn3_api import lit_for_trait_list
@@ -12,6 +15,106 @@ from gn3.computations.rust_correlation import parse_tissue_corr_data
from gn3.db_utils import database_connector
+
+
+def chunk_dataset(dataset,steps,name):
+
+ results = []
+
+ query = """
+ SELECT ProbeSetXRef.DataId,ProbeSet.Name
+ FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze
+ WHERE ProbeSetFreeze.Name = '{}' AND
+ ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND
+ ProbeSetXRef.ProbeSetId = ProbeSet.Id
+ """.format(name)
+
+ traits_name_dict = dict(g.db.execute(query).fetchall())
+
+
+ for i in range(0, len(dataset), steps):
+ matrix = list(dataset[i:i + steps])
+ trait_name = traits_name_dict[matrix[0][0]]
+
+ strains = [trait_name] + [str(value) for (trait_name, strain, value) in matrix]
+ results.append(",".join(strains))
+
+ breakpoint()
+ return results
+
+
+def compute_top_n_sample(start_vars, dataset, trait_list):
+ """only if dataset is of type probeset"""
+
+
+
+
+ def __fetch_sample_ids__(samples_vals, samples_group):
+
+
+ all_samples = json.loads(samples_vals)
+ sample_data = get_sample_corr_data(
+ sample_type=samples_group, all_samples=all_samples,
+ dataset_samples=dataset.group.all_samples_ordered())
+
+
+ with database_connector() as conn:
+
+ curr = conn.cursor()
+
+ curr.execute(
+ """
+ SELECT Strain.Name, Strain.Id FROM Strain, Species
+ WHERE Strain.Name IN {}
+ and Strain.SpeciesId=Species.Id
+ and Species.name = '{}'
+ """.format(create_in_clause(list(sample_data.keys())),
+ *mescape(dataset.group.species))
+
+ )
+
+ return dict(curr.fetchall())
+
+
+
+
+
+
+
+
+
+
+
+ ty = __fetch_sample_ids__(start_vars["sample_vals"], start_vars["corr_samples_group"])
+
+
+
+ with database_connector() as conn:
+
+ curr = conn.cursor()
+
+ curr.execute(
+
+ """
+ SELECT * from ProbeSetData
+ where StrainID in {}
+ and id in (SELECT ProbeSetXRef.DataId
+ FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze)
+ WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id
+ and ProbeSetFreeze.Name = '{}'
+ and ProbeSet.Name in {}
+ and ProbeSet.Id = ProbeSetXRef.ProbeSetId)
+ """.format(create_in_clause(list(ty.values())),dataset.name,create_in_clause(trait_list))
+
+
+ )
+
+
+
+
+ return chunk_dataset(list(curr.fetchall()),len(ty.values()),dataset.name)
+
+
def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict:
(this_trait_geneid, geneid_dict, species) = do_lit_correlation(
this_trait, this_dataset)
@@ -69,6 +172,7 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
}
return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
+
def __compute_sample_corr__(
start_vars: dict, corr_type: str, method: str, n_top: int,
target_trait_info: tuple):
@@ -86,11 +190,11 @@ def __compute_sample_corr__(
r = ",".join(lts)
target_data.append(r)
-
return run_correlation(
target_data, list(sample_data.values()), method, ",", corr_type,
n_top)
+
def __compute_tissue_corr__(
start_vars: dict, corr_type: str, method: str, n_top: int,
target_trait_info: tuple):
@@ -111,6 +215,7 @@ def __compute_tissue_corr__(
return run_correlation(data[1], data[0], method, ",", "tissue")
return {}
+
def __compute_lit_corr__(
start_vars: dict, corr_type: str, method: str, n_top: int,
target_trait_info: tuple):
@@ -127,6 +232,7 @@ def __compute_lit_corr__(
species=species, gene_id=this_trait_geneid)
return {}
+
def compute_correlation_rust(
start_vars: dict, corr_type: str, method: str = "pearson",
n_top: int = 500, compute_all: bool = False):
@@ -135,7 +241,7 @@ def compute_correlation_rust(
(this_dataset, this_trait, target_dataset, sample_data) = (
target_trait_info)
- ## Replace this with `match ...` once we hit Python 3.10
+ # Replace this with `match ...` once we hit Python 3.10
corr_type_fns = {
"sample": __compute_sample_corr__,
"tissue": __compute_tissue_corr__,
@@ -143,15 +249,23 @@ def compute_correlation_rust(
}
results = corr_type_fns[corr_type](
start_vars, corr_type, method, n_top, target_trait_info)
- ## END: Replace this with `match ...` once we hit Python 3.10
+ # END: Replace this with `match ...` once we hit Python 3.10
top_tissue_results = {}
top_lit_results = {}
+
+
+ results = compute_top_n_sample(start_vars,target_dataset,list(results.keys()))
+
+
+
+ breakpoint()
+
if compute_all:
# example compute of compute both correlation
top_tissue_results = compute_top_n_tissue(
- this_dataset,this_trait,results,method)
- top_lit_results = compute_top_n_lit(results,this_dataset,this_trait)
+ this_dataset, this_trait, results, method)
+ top_lit_results = compute_top_n_lit(results, this_dataset, this_trait)
return {
"correlation_results": merge_results(