aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py69
1 files changed, 38 insertions, 31 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 94720f54..2a2ad4a0 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -39,15 +39,14 @@ def chunk_dataset(dataset,steps,name):
strains = [trait_name] + [str(value) for (trait_name, strain, value) in matrix]
results.append(",".join(strains))
- breakpoint()
return results
def compute_top_n_sample(start_vars, dataset, trait_list):
- """only if dataset is of type probeset"""
-
-
+ """check if dataset is of type probeset"""
+ if dataset.type!= "Probeset":
+ return {}
def __fetch_sample_ids__(samples_vals, samples_group):
@@ -73,19 +72,9 @@ def compute_top_n_sample(start_vars, dataset, trait_list):
)
- return dict(curr.fetchall())
-
-
-
-
-
+ return (sample_data,dict(curr.fetchall()))
-
-
-
-
-
- ty = __fetch_sample_ids__(start_vars["sample_vals"], start_vars["corr_samples_group"])
+ (sample_data,sample_ids) = __fetch_sample_ids__(start_vars["sample_vals"], start_vars["corr_samples_group"])
@@ -93,6 +82,8 @@ def compute_top_n_sample(start_vars, dataset, trait_list):
curr = conn.cursor()
+ # Fetch the strain data for all requested traits in a single bulk query
+
curr.execute(
"""
@@ -104,15 +95,14 @@ def compute_top_n_sample(start_vars, dataset, trait_list):
and ProbeSetFreeze.Name = '{}'
and ProbeSet.Name in {}
and ProbeSet.Id = ProbeSetXRef.ProbeSetId)
- """.format(create_in_clause(list(ty.values())),dataset.name,create_in_clause(trait_list))
+ """.format(create_in_clause(list(sample_ids.values())),dataset.name,create_in_clause(trait_list))
)
+ corr_data = chunk_dataset(list(curr.fetchall()),len(sample_ids.values()),dataset.name)
-
-
- return chunk_dataset(list(curr.fetchall()),len(ty.values()),dataset.name)
+ return run_correlation(corr_data,list(sample_data.values()),"pearson",",")
def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict:
@@ -170,7 +160,10 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
**dict_c.get(trait_name, {})
}
}
- return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
+ results = [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
+
+
+ return results
def __compute_sample_corr__(
@@ -249,27 +242,41 @@ def compute_correlation_rust(
}
results = corr_type_fns[corr_type](
start_vars, corr_type, method, n_top, target_trait_info)
+
# END: Replace this with `match ...` once we hit Python 3.10
- top_tissue_results = {}
- top_lit_results = {}
+ top_a = top_b = {}
- results = compute_top_n_sample(start_vars,target_dataset,list(results.keys()))
+ if compute_all:
+ if corr_type == "sample":
+ top_a = compute_top_n_tissue(
+ this_dataset, this_trait, results, method)
+
+ top_b = compute_top_n_lit(results, this_dataset, this_trait)
- breakpoint()
- if compute_all:
- # example compute of compute both correlation
- top_tissue_results = compute_top_n_tissue(
+ elif corr_type == "lit":
+
+ # FIXME: this branch currently fails when corr_type is "lit"
+
+ top_a = compute_top_n_sample(start_vars,target_dataset,list(results.keys()))
+ top_b = compute_top_n_tissue(
this_dataset, this_trait, results, method)
- top_lit_results = compute_top_n_lit(results, this_dataset, this_trait)
- return {
+ else:
+
+ top_a = compute_top_n_sample(start_vars,target_dataset,list(results.keys()))
+
+ top_b = compute_top_n_lit(results, this_dataset, this_trait)
+
+
+
+ return {
"correlation_results": merge_results(
- results, top_tissue_results, top_lit_results),
+ results, top_a, top_b),
"this_trait": this_trait.name,
"target_dataset": start_vars['corr_dataset'],
"return_results": n_top