aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-08-11 12:30:18 +0300
committerFrederick Muriuki Muriithi2022-08-12 13:13:27 +0300
commit17ca7e77631f1fc6a247accfb39cf8da23dc4a92 (patch)
treebaf4efcde863d3a38d85c498b59036992ada5ba0
parent15a1ff0d5cf8562c16e14491edd24b0f3171e083 (diff)
downloadgenenetwork2-17ca7e77631f1fc6a247accfb39cf8da23dc4a92.tar.gz
Refactor: separate the three correlation types
Refactor the code such that each correlation type (sample, tissue, literature) is computed in its own function. This makes the code clearer, and helps reduce repetition.
-rw-r--r--wqflask/wqflask/correlation/correlation_gn3_api.py39
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py121
2 files changed, 78 insertions, 82 deletions
diff --git a/wqflask/wqflask/correlation/correlation_gn3_api.py b/wqflask/wqflask/correlation/correlation_gn3_api.py
index 6df4eafe..1a375501 100644
--- a/wqflask/wqflask/correlation/correlation_gn3_api.py
+++ b/wqflask/wqflask/correlation/correlation_gn3_api.py
@@ -194,46 +194,13 @@ def compute_correlation(start_vars, method="pearson", compute_all=False):
method -- Correlation method to be used (pearson, spearman, or bicor)
compute_all -- Include sample, tissue, and literature correlations (when applicable)
"""
- # pylint: disable-msg=too-many-locals
+ from wqflask.correlation.rust_correlation import compute_correlation_rust
corr_type = start_vars['corr_type']
-
method = start_vars['corr_sample_method']
corr_return_results = int(start_vars.get("corr_return_results", 100))
- corr_input_data = {}
-
- from wqflask.correlation.rust_correlation import compute_correlation_rust
- rust_correlation_results = compute_correlation_rust(
- start_vars, corr_type, method, corr_return_results)
- correlation_results = rust_correlation_results["correlation_results"]
-
- if corr_type == "lit":# elif corr_type == "lit":
- (this_dataset, this_trait, target_dataset,
- sample_data) = create_target_this_trait(start_vars)
- target_dataset_type = target_dataset.type
- this_dataset_type = this_dataset.type
- (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
- this_trait, this_dataset)
-
- conn = database_connector()
- with conn:
- correlation_results = compute_all_lit_correlation(
- conn=conn, trait_lists=list(geneid_dict.items()),
- species=species, gene_id=this_trait_geneid)
-
- correlation_results = correlation_results[0:corr_return_results]
-
- if (compute_all):
- correlation_results = compute_corr_for_top_results(
- start_vars, correlation_results, this_trait, this_dataset,
- target_dataset, corr_type)
-
- return {
- "correlation_results": correlation_results,
- "this_trait": this_trait.name,
- "target_dataset": start_vars['corr_dataset'],
- "return_results": corr_return_results
- }
+ return compute_correlation_rust(
+ start_vars, corr_type, method, corr_return_results, compute_all)
def compute_corr_for_top_results(start_vars,
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 4a22af72..b4435887 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -69,62 +69,91 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
}
return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
+def __compute_sample_corr__(
+ start_vars: dict, corr_type: str, method: str, n_top: int,
+ target_trait_info: tuple):
+ """Compute the sample correlations"""
+ (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+ all_samples = json.loads(start_vars["sample_vals"])
+ sample_data = get_sample_corr_data(
+ sample_type=start_vars["corr_samples_group"], all_samples=all_samples,
+ dataset_samples=this_dataset.group.all_samples_ordered())
+ target_dataset.get_trait_data(list(sample_data.keys()))
+
+ target_data = []
+ for (key, val) in target_dataset.trait_data.items():
+ lts = [key] + [str(x) for x in val]
+ r = ",".join(lts)
+ target_data.append(r)
+
+
+ return run_correlation(
+ target_data, list(sample_data.values()), method, ",", corr_type,
+ n_top)
+
+def __compute_tissue_corr__(
+ start_vars: dict, corr_type: str, method: str, n_top: int,
+ target_trait_info: tuple):
+ """Compute the tissue correlations"""
+ (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+ trait_symbol_dict = this_dataset.retrieve_genes("Symbol")
+ corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
+ symbol_list=list(trait_symbol_dict.values()))
-def compute_correlation_rust(
- start_vars: dict, corr_type: str, method: str = "pearson",
- n_top: int = 500):
- """function to compute correlation"""
-
- (this_dataset, this_trait, target_dataset,
- sample_data) = create_target_this_trait(start_vars)
-
- if corr_type == "sample":
-
- all_samples = json.loads(start_vars["sample_vals"])
- sample_data = get_sample_corr_data(sample_type=start_vars["corr_samples_group"],
- all_samples=all_samples,
- dataset_samples=this_dataset.group.all_samples_ordered())
+ data = parse_tissue_corr_data(
+ symbol_name=this_trait.symbol,
+ symbol_dict=get_trait_symbol_and_tissue_values(
+ symbol_list=[this_trait.symbol]),
+ dataset_symbols=trait_symbol_dict,
+ dataset_vals=corr_result_tissue_vals_dict)
- target_dataset.get_trait_data(list(sample_data.keys()))
+ if data:
+ return run_correlation(data[1], data[0], method, ",", "tissue")
+ return {}
- target_data = []
- for (key, val) in target_dataset.trait_data.items():
- lts = [key] + [str(x) for x in val]
- r = ",".join(lts)
- target_data.append(r)
+def __compute_lit_corr__(
+ start_vars: dict, corr_type: str, method: str, n_top: int,
+ target_trait_info: tuple):
+ """Compute the literature correlations"""
+ (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+ target_dataset_type = target_dataset.type
+ this_dataset_type = this_dataset.type
+ (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
+ this_trait, this_dataset)
+ with database_connector() as conn:
+ return compute_all_lit_correlation(
+ conn=conn, trait_lists=list(geneid_dict.items()),
+ species=species, gene_id=this_trait_geneid)
+ return {}
- results = run_correlation(
- target_data, list(sample_data.values()), method, ",", corr_type,
- n_top)
+def compute_correlation_rust(
+ start_vars: dict, corr_type: str, method: str = "pearson",
+ n_top: int = 500, compute_all: bool = False):
+ """function to compute correlation"""
+ target_trait_info = create_target_this_trait(start_vars)
+ (this_dataset, this_trait, target_dataset, sample_data) = (
+ target_trait_info)
+
+ corr_type_fns = {
+ "sample": __compute_sample_corr__,
+ "tissue": __compute_tissue_corr__,
+ "lit": __compute_lit_corr__
+ }
+ results = corr_type_fns[corr_type](
+ start_vars, corr_type, method, n_top, target_trait_info)
+ top_tissue_results = {}
+ top_lit_results = {}
+ if compute_all:
# example compute of compute both correlation
- top_tissue_results = compute_top_n_tissue(this_dataset,this_trait,results,method)
+ top_tissue_results = compute_top_n_tissue(
+ this_dataset,this_trait,results,method)
top_lit_results = compute_top_n_lit(results,this_dataset,this_trait)
- # merging the results
- results = merge_results(results, top_tissue_results, top_lit_results)
-
- if corr_type == "tissue":
-
- trait_symbol_dict = this_dataset.retrieve_genes("Symbol")
- corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
- symbol_list=list(trait_symbol_dict.values()))
-
- data = parse_tissue_corr_data(symbol_name=this_trait.symbol,
- symbol_dict=get_trait_symbol_and_tissue_values(
- symbol_list=[this_trait.symbol]
- ),
- dataset_symbols=trait_symbol_dict,
- dataset_vals=corr_result_tissue_vals_dict)
-
- if data:
- results = merge_results(
- run_correlation(data[1], data[0], method, ",", "tissue"),
- {}, {})
-
return {
- "correlation_results": results,
+ "correlation_results": merge_results(
+ results, top_tissue_results, top_lit_results),
"this_trait": this_trait.name,
"target_dataset": start_vars['corr_dataset'],
"return_results": n_top