aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-08-11 12:30:18 +0300
committerFrederick Muriuki Muriithi2022-08-11 12:30:18 +0300
commit3c1cb6a94b64dae28c62f481e1f4499f8f5b89e7 (patch)
tree8207d0212a3562d2b375667a11770cb1978e0717
parent309785e95696567a35b42690b032808eaa59c86d (diff)
downloadgenenetwork2-3c1cb6a94b64dae28c62f481e1f4499f8f5b89e7.tar.gz
Refactor: separate the three correlation types
Refactor the code such that each correlation type (sample, tissue, literature) is computed in its own function. This makes the code clearer, and helps reduce repetition.
-rw-r--r--wqflask/wqflask/correlation/correlation_gn3_api.py39
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py121
2 files changed, 78 insertions, 82 deletions
diff --git a/wqflask/wqflask/correlation/correlation_gn3_api.py b/wqflask/wqflask/correlation/correlation_gn3_api.py
index 6df4eafe..1a375501 100644
--- a/wqflask/wqflask/correlation/correlation_gn3_api.py
+++ b/wqflask/wqflask/correlation/correlation_gn3_api.py
@@ -194,46 +194,13 @@ def compute_correlation(start_vars, method="pearson", compute_all=False):
method -- Correlation method to be used (pearson, spearman, or bicor)
compute_all -- Include sample, tissue, and literature correlations (when applicable)
"""
- # pylint: disable-msg=too-many-locals
+ from wqflask.correlation.rust_correlation import compute_correlation_rust
corr_type = start_vars['corr_type']
-
method = start_vars['corr_sample_method']
corr_return_results = int(start_vars.get("corr_return_results", 100))
- corr_input_data = {}
-
- from wqflask.correlation.rust_correlation import compute_correlation_rust
- rust_correlation_results = compute_correlation_rust(
- start_vars, corr_type, method, corr_return_results)
- correlation_results = rust_correlation_results["correlation_results"]
-
- if corr_type == "lit":# elif corr_type == "lit":
- (this_dataset, this_trait, target_dataset,
- sample_data) = create_target_this_trait(start_vars)
- target_dataset_type = target_dataset.type
- this_dataset_type = this_dataset.type
- (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
- this_trait, this_dataset)
-
- conn = database_connector()
- with conn:
- correlation_results = compute_all_lit_correlation(
- conn=conn, trait_lists=list(geneid_dict.items()),
- species=species, gene_id=this_trait_geneid)
-
- correlation_results = correlation_results[0:corr_return_results]
-
- if (compute_all):
- correlation_results = compute_corr_for_top_results(
- start_vars, correlation_results, this_trait, this_dataset,
- target_dataset, corr_type)
-
- return {
- "correlation_results": correlation_results,
- "this_trait": this_trait.name,
- "target_dataset": start_vars['corr_dataset'],
- "return_results": corr_return_results
- }
+ return compute_correlation_rust(
+ start_vars, corr_type, method, corr_return_results, compute_all)
def compute_corr_for_top_results(start_vars,
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 4a22af72..b4435887 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -69,62 +69,91 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
}
return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
+def __compute_sample_corr__(
+ start_vars: dict, corr_type: str, method: str, n_top: int,
+ target_trait_info: tuple):
+ """Compute the sample correlations"""
+ (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+ all_samples = json.loads(start_vars["sample_vals"])
+ sample_data = get_sample_corr_data(
+ sample_type=start_vars["corr_samples_group"], all_samples=all_samples,
+ dataset_samples=this_dataset.group.all_samples_ordered())
+ target_dataset.get_trait_data(list(sample_data.keys()))
+
+ target_data = []
+ for (key, val) in target_dataset.trait_data.items():
+ lts = [key] + [str(x) for x in val]
+ r = ",".join(lts)
+ target_data.append(r)
+
+
+ return run_correlation(
+ target_data, list(sample_data.values()), method, ",", corr_type,
+ n_top)
+
+def __compute_tissue_corr__(
+ start_vars: dict, corr_type: str, method: str, n_top: int,
+ target_trait_info: tuple):
+ """Compute the tissue correlations"""
+ (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+ trait_symbol_dict = this_dataset.retrieve_genes("Symbol")
+ corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
+ symbol_list=list(trait_symbol_dict.values()))
-def compute_correlation_rust(
- start_vars: dict, corr_type: str, method: str = "pearson",
- n_top: int = 500):
- """function to compute correlation"""
-
- (this_dataset, this_trait, target_dataset,
- sample_data) = create_target_this_trait(start_vars)
-
- if corr_type == "sample":
-
- all_samples = json.loads(start_vars["sample_vals"])
- sample_data = get_sample_corr_data(sample_type=start_vars["corr_samples_group"],
- all_samples=all_samples,
- dataset_samples=this_dataset.group.all_samples_ordered())
+ data = parse_tissue_corr_data(
+ symbol_name=this_trait.symbol,
+ symbol_dict=get_trait_symbol_and_tissue_values(
+ symbol_list=[this_trait.symbol]),
+ dataset_symbols=trait_symbol_dict,
+ dataset_vals=corr_result_tissue_vals_dict)
- target_dataset.get_trait_data(list(sample_data.keys()))
+ if data:
+ return run_correlation(data[1], data[0], method, ",", "tissue")
+ return {}
- target_data = []
- for (key, val) in target_dataset.trait_data.items():
- lts = [key] + [str(x) for x in val]
- r = ",".join(lts)
- target_data.append(r)
+def __compute_lit_corr__(
+ start_vars: dict, corr_type: str, method: str, n_top: int,
+ target_trait_info: tuple):
+ """Compute the literature correlations"""
+ (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+ target_dataset_type = target_dataset.type
+ this_dataset_type = this_dataset.type
+ (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
+ this_trait, this_dataset)
+ with database_connector() as conn:
+ return compute_all_lit_correlation(
+ conn=conn, trait_lists=list(geneid_dict.items()),
+ species=species, gene_id=this_trait_geneid)
+ return {}
- results = run_correlation(
- target_data, list(sample_data.values()), method, ",", corr_type,
- n_top)
+def compute_correlation_rust(
+ start_vars: dict, corr_type: str, method: str = "pearson",
+ n_top: int = 500, compute_all: bool = False):
+ """function to compute correlation"""
+ target_trait_info = create_target_this_trait(start_vars)
+ (this_dataset, this_trait, target_dataset, sample_data) = (
+ target_trait_info)
+
+ corr_type_fns = {
+ "sample": __compute_sample_corr__,
+ "tissue": __compute_tissue_corr__,
+ "lit": __compute_lit_corr__
+ }
+ results = corr_type_fns[corr_type](
+ start_vars, corr_type, method, n_top, target_trait_info)
+ top_tissue_results = {}
+ top_lit_results = {}
+ if compute_all:
# example compute of compute both correlation
- top_tissue_results = compute_top_n_tissue(this_dataset,this_trait,results,method)
+ top_tissue_results = compute_top_n_tissue(
+ this_dataset,this_trait,results,method)
top_lit_results = compute_top_n_lit(results,this_dataset,this_trait)
- # merging the results
- results = merge_results(results, top_tissue_results, top_lit_results)
-
- if corr_type == "tissue":
-
- trait_symbol_dict = this_dataset.retrieve_genes("Symbol")
- corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
- symbol_list=list(trait_symbol_dict.values()))
-
- data = parse_tissue_corr_data(symbol_name=this_trait.symbol,
- symbol_dict=get_trait_symbol_and_tissue_values(
- symbol_list=[this_trait.symbol]
- ),
- dataset_symbols=trait_symbol_dict,
- dataset_vals=corr_result_tissue_vals_dict)
-
- if data:
- results = merge_results(
- run_correlation(data[1], data[0], method, ",", "tissue"),
- {}, {})
-
return {
- "correlation_results": results,
+ "correlation_results": merge_results(
+ results, top_tissue_results, top_lit_results),
"this_trait": this_trait.name,
"target_dataset": start_vars['corr_dataset'],
"return_results": n_top