aboutsummaryrefslogtreecommitdiff
path: root/wqflask
diff options
context:
space:
mode:
Diffstat (limited to 'wqflask')
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py63
1 files changed, 30 insertions, 33 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 95354994..7d796e70 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -3,7 +3,8 @@ import json
from functools import reduce
from utility.db_tools import mescape
from utility.db_tools import create_in_clause
-from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values
+from wqflask.correlation.correlation_functions\
+ import get_trait_symbol_and_tissue_values
from wqflask.correlation.correlation_gn3_api import create_target_this_trait
from wqflask.correlation.correlation_gn3_api import lit_for_trait_list
from wqflask.correlation.correlation_gn3_api import do_lit_correlation
@@ -14,9 +15,7 @@ from gn3.computations.rust_correlation import parse_tissue_corr_data
from gn3.db_utils import database_connector
-
-
-def chunk_dataset(dataset,steps,name):
+def chunk_dataset(dataset, steps, name):
results = []
@@ -39,7 +38,8 @@ def chunk_dataset(dataset,steps,name):
matrix = list(dataset[i:i + steps])
trait_name = traits_name_dict[matrix[0][0]]
- strains = [trait_name] + [str(value) for (trait_name, strain, value) in matrix]
+ strains = [trait_name] + [str(value)
+ for (trait_name, strain, value) in matrix]
results.append(",".join(strains))
return results
@@ -48,18 +48,16 @@ def chunk_dataset(dataset,steps,name):
def compute_top_n_sample(start_vars, dataset, trait_list):
"""check if dataset is of type probeset"""
- if dataset.type.lower()!= "probeset":
- return {}
+ if dataset.type.lower() != "probeset":
+ return {}
def __fetch_sample_ids__(samples_vals, samples_group):
-
all_samples = json.loads(samples_vals)
sample_data = get_sample_corr_data(
sample_type=samples_group, all_samples=all_samples,
dataset_samples=dataset.group.all_samples_ordered())
-
with database_connector() as conn:
curr = conn.cursor()
@@ -75,21 +73,20 @@ def compute_top_n_sample(start_vars, dataset, trait_list):
)
- return (sample_data,dict(curr.fetchall()))
-
- (sample_data,sample_ids) = __fetch_sample_ids__(start_vars["sample_vals"], start_vars["corr_samples_group"])
-
+ return (sample_data, dict(curr.fetchall()))
+ (sample_data, sample_ids) = __fetch_sample_ids__(
+ start_vars["sample_vals"], start_vars["corr_samples_group"])
with database_connector() as conn:
curr = conn.cursor()
- #fetching strain data in bulk
+ # fetching strain data in bulk
curr.execute(
- """
+ """
SELECT * from ProbeSetData
where StrainID in {}
and id in (SELECT ProbeSetXRef.DataId
@@ -98,21 +95,25 @@ def compute_top_n_sample(start_vars, dataset, trait_list):
and ProbeSetFreeze.Name = '{}'
and ProbeSet.Name in {}
and ProbeSet.Id = ProbeSetXRef.ProbeSetId)
- """.format(create_in_clause(list(sample_ids.values())),dataset.name,create_in_clause(trait_list))
+ """.format(create_in_clause(list(sample_ids.values())), dataset.name, create_in_clause(trait_list))
)
- corr_data = chunk_dataset(list(curr.fetchall()),len(sample_ids.values()),dataset.name)
+ corr_data = chunk_dataset(list(curr.fetchall()), len(
+ sample_ids.values()), dataset.name)
- return run_correlation(corr_data,list(sample_data.values()),"pearson",",")
+ return run_correlation(corr_data,
+ list(sample_data.values()),
+ "pearson", ",")
def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict:
(this_trait_geneid, geneid_dict, species) = do_lit_correlation(
this_trait, this_dataset)
- geneid_dict = {trait_name: geneid for (trait_name, geneid) in geneid_dict.items() if
+ geneid_dict = {trait_name: geneid for (trait_name, geneid)
+ in geneid_dict.items() if
corr_results.get(trait_name)}
with database_connector() as conn:
return reduce(
@@ -166,8 +167,6 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
-
-
def __compute_sample_corr__(
start_vars: dict, corr_type: str, method: str, n_top: int,
target_trait_info: tuple):
@@ -247,7 +246,6 @@ def compute_correlation_rust(
# END: Replace this with `match ...` once we hit Python 3.10
-
top_a = top_b = {}
if compute_all:
@@ -255,28 +253,27 @@ def compute_correlation_rust(
if corr_type == "sample":
top_a = compute_top_n_tissue(
- this_dataset, this_trait, results, method)
-
- top_b = compute_top_n_lit(results, this_dataset, this_trait)
+ this_dataset, this_trait, results, method)
+ top_b = compute_top_n_lit(results, this_dataset, this_trait)
elif corr_type == "lit":
- #currently fails for lit
+ # currently fails for lit
- top_a = compute_top_n_sample(start_vars,target_dataset,list(results.keys()))
- top_b = compute_top_n_tissue(
- this_dataset, this_trait, results, method)
+ top_a = compute_top_n_sample(
+ start_vars, target_dataset, list(results.keys()))
+ top_b = compute_top_n_tissue(
+ this_dataset, this_trait, results, method)
else:
- top_a = compute_top_n_sample(start_vars,target_dataset,list(results.keys()))
+ top_a = compute_top_n_sample(
+ start_vars, target_dataset, list(results.keys()))
top_b = compute_top_n_lit(results, this_dataset, this_trait)
-
-
- return {
+ return {
"correlation_results": merge_results(
results, top_a, top_b),
"this_trait": this_trait.name,