aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/correlation/correlation_gn3_api.py39
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py121
2 files changed, 78 insertions, 82 deletions
diff --git a/wqflask/wqflask/correlation/correlation_gn3_api.py b/wqflask/wqflask/correlation/correlation_gn3_api.py
index 6df4eafe..1a375501 100644
--- a/wqflask/wqflask/correlation/correlation_gn3_api.py
+++ b/wqflask/wqflask/correlation/correlation_gn3_api.py
@@ -194,46 +194,13 @@ def compute_correlation(start_vars, method="pearson", compute_all=False):
method -- Correlation method to be used (pearson, spearman, or bicor)
compute_all -- Include sample, tissue, and literature correlations (when applicable)
"""
- # pylint: disable-msg=too-many-locals
+ from wqflask.correlation.rust_correlation import compute_correlation_rust
corr_type = start_vars['corr_type']
-
method = start_vars['corr_sample_method']
corr_return_results = int(start_vars.get("corr_return_results", 100))
- corr_input_data = {}
-
- from wqflask.correlation.rust_correlation import compute_correlation_rust
- rust_correlation_results = compute_correlation_rust(
- start_vars, corr_type, method, corr_return_results)
- correlation_results = rust_correlation_results["correlation_results"]
-
- if corr_type == "lit":# elif corr_type == "lit":
- (this_dataset, this_trait, target_dataset,
- sample_data) = create_target_this_trait(start_vars)
- target_dataset_type = target_dataset.type
- this_dataset_type = this_dataset.type
- (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
- this_trait, this_dataset)
-
- conn = database_connector()
- with conn:
- correlation_results = compute_all_lit_correlation(
- conn=conn, trait_lists=list(geneid_dict.items()),
- species=species, gene_id=this_trait_geneid)
-
- correlation_results = correlation_results[0:corr_return_results]
-
- if (compute_all):
- correlation_results = compute_corr_for_top_results(
- start_vars, correlation_results, this_trait, this_dataset,
- target_dataset, corr_type)
-
- return {
- "correlation_results": correlation_results,
- "this_trait": this_trait.name,
- "target_dataset": start_vars['corr_dataset'],
- "return_results": corr_return_results
- }
+ return compute_correlation_rust(
+ start_vars, corr_type, method, corr_return_results, compute_all)
def compute_corr_for_top_results(start_vars,
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 4a22af72..b4435887 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -69,62 +69,91 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
}
return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
+def __compute_sample_corr__(
+ start_vars: dict, corr_type: str, method: str, n_top: int,
+ target_trait_info: tuple):
+ """Compute the sample correlations"""
+ (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+ all_samples = json.loads(start_vars["sample_vals"])
+ sample_data = get_sample_corr_data(
+ sample_type=start_vars["corr_samples_group"], all_samples=all_samples,
+ dataset_samples=this_dataset.group.all_samples_ordered())
+ target_dataset.get_trait_data(list(sample_data.keys()))
+
+ target_data = []
+ for (key, val) in target_dataset.trait_data.items():
+ lts = [key] + [str(x) for x in val]
+ r = ",".join(lts)
+ target_data.append(r)
+
+
+ return run_correlation(
+ target_data, list(sample_data.values()), method, ",", corr_type,
+ n_top)
+
+def __compute_tissue_corr__(
+ start_vars: dict, corr_type: str, method: str, n_top: int,
+ target_trait_info: tuple):
+ """Compute the tissue correlations"""
+ (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+ trait_symbol_dict = this_dataset.retrieve_genes("Symbol")
+ corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
+ symbol_list=list(trait_symbol_dict.values()))
-def compute_correlation_rust(
- start_vars: dict, corr_type: str, method: str = "pearson",
- n_top: int = 500):
- """function to compute correlation"""
-
- (this_dataset, this_trait, target_dataset,
- sample_data) = create_target_this_trait(start_vars)
-
- if corr_type == "sample":
-
- all_samples = json.loads(start_vars["sample_vals"])
- sample_data = get_sample_corr_data(sample_type=start_vars["corr_samples_group"],
- all_samples=all_samples,
- dataset_samples=this_dataset.group.all_samples_ordered())
+ data = parse_tissue_corr_data(
+ symbol_name=this_trait.symbol,
+ symbol_dict=get_trait_symbol_and_tissue_values(
+ symbol_list=[this_trait.symbol]),
+ dataset_symbols=trait_symbol_dict,
+ dataset_vals=corr_result_tissue_vals_dict)
- target_dataset.get_trait_data(list(sample_data.keys()))
+ if data:
+ return run_correlation(data[1], data[0], method, ",", "tissue")
+ return {}
- target_data = []
- for (key, val) in target_dataset.trait_data.items():
- lts = [key] + [str(x) for x in val]
- r = ",".join(lts)
- target_data.append(r)
+def __compute_lit_corr__(
+ start_vars: dict, corr_type: str, method: str, n_top: int,
+ target_trait_info: tuple):
+ """Compute the literature correlations"""
+ (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+ target_dataset_type = target_dataset.type
+ this_dataset_type = this_dataset.type
+ (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
+ this_trait, this_dataset)
+ with database_connector() as conn:
+ return compute_all_lit_correlation(
+ conn=conn, trait_lists=list(geneid_dict.items()),
+ species=species, gene_id=this_trait_geneid)
+ return {}
- results = run_correlation(
- target_data, list(sample_data.values()), method, ",", corr_type,
- n_top)
+def compute_correlation_rust(
+ start_vars: dict, corr_type: str, method: str = "pearson",
+ n_top: int = 500, compute_all: bool = False):
+ """function to compute correlation"""
+ target_trait_info = create_target_this_trait(start_vars)
+ (this_dataset, this_trait, target_dataset, sample_data) = (
+ target_trait_info)
+
+ corr_type_fns = {
+ "sample": __compute_sample_corr__,
+ "tissue": __compute_tissue_corr__,
+ "lit": __compute_lit_corr__
+ }
+ results = corr_type_fns[corr_type](
+ start_vars, corr_type, method, n_top, target_trait_info)
+ top_tissue_results = {}
+ top_lit_results = {}
+ if compute_all:
# example compute of compute both correlation
- top_tissue_results = compute_top_n_tissue(this_dataset,this_trait,results,method)
+ top_tissue_results = compute_top_n_tissue(
+ this_dataset,this_trait,results,method)
top_lit_results = compute_top_n_lit(results,this_dataset,this_trait)
- # merging the results
- results = merge_results(results, top_tissue_results, top_lit_results)
-
- if corr_type == "tissue":
-
- trait_symbol_dict = this_dataset.retrieve_genes("Symbol")
- corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
- symbol_list=list(trait_symbol_dict.values()))
-
- data = parse_tissue_corr_data(symbol_name=this_trait.symbol,
- symbol_dict=get_trait_symbol_and_tissue_values(
- symbol_list=[this_trait.symbol]
- ),
- dataset_symbols=trait_symbol_dict,
- dataset_vals=corr_result_tissue_vals_dict)
-
- if data:
- results = merge_results(
- run_correlation(data[1], data[0], method, ",", "tissue"),
- {}, {})
-
return {
- "correlation_results": results,
+ "correlation_results": merge_results(
+ results, top_tissue_results, top_lit_results),
"this_trait": this_trait.name,
"target_dataset": start_vars['corr_dataset'],
"return_results": n_top