diff options
author | Frederick Muriuki Muriithi | 2022-08-11 12:30:18 +0300 |
---|---|---|
committer | Frederick Muriuki Muriithi | 2022-08-11 12:30:18 +0300 |
commit | 3c1cb6a94b64dae28c62f481e1f4499f8f5b89e7 (patch) | |
tree | 8207d0212a3562d2b375667a11770cb1978e0717 | |
parent | 309785e95696567a35b42690b032808eaa59c86d (diff) | |
download | genenetwork2-3c1cb6a94b64dae28c62f481e1f4499f8f5b89e7.tar.gz |
Refactor: separate the three correlation types
Refactor the code such that each correlation type (sample, tissue,
literature) is computed in its own function. This makes the code
clearer, and helps reduce repetition.
-rw-r--r-- | wqflask/wqflask/correlation/correlation_gn3_api.py | 39 | ||||
-rw-r--r-- | wqflask/wqflask/correlation/rust_correlation.py | 121 |
2 files changed, 78 insertions, 82 deletions
diff --git a/wqflask/wqflask/correlation/correlation_gn3_api.py b/wqflask/wqflask/correlation/correlation_gn3_api.py index 6df4eafe..1a375501 100644 --- a/wqflask/wqflask/correlation/correlation_gn3_api.py +++ b/wqflask/wqflask/correlation/correlation_gn3_api.py @@ -194,46 +194,13 @@ def compute_correlation(start_vars, method="pearson", compute_all=False): method -- Correlation method to be used (pearson, spearman, or bicor) compute_all -- Include sample, tissue, and literature correlations (when applicable) """ - # pylint: disable-msg=too-many-locals + from wqflask.correlation.rust_correlation import compute_correlation_rust corr_type = start_vars['corr_type'] - method = start_vars['corr_sample_method'] corr_return_results = int(start_vars.get("corr_return_results", 100)) - corr_input_data = {} - - from wqflask.correlation.rust_correlation import compute_correlation_rust - rust_correlation_results = compute_correlation_rust( - start_vars, corr_type, method, corr_return_results) - correlation_results = rust_correlation_results["correlation_results"] - - if corr_type == "lit":# elif corr_type == "lit": - (this_dataset, this_trait, target_dataset, - sample_data) = create_target_this_trait(start_vars) - target_dataset_type = target_dataset.type - this_dataset_type = this_dataset.type - (this_trait_geneid, geneid_dict, species) = do_lit_correlation( - this_trait, this_dataset) - - conn = database_connector() - with conn: - correlation_results = compute_all_lit_correlation( - conn=conn, trait_lists=list(geneid_dict.items()), - species=species, gene_id=this_trait_geneid) - - correlation_results = correlation_results[0:corr_return_results] - - if (compute_all): - correlation_results = compute_corr_for_top_results( - start_vars, correlation_results, this_trait, this_dataset, - target_dataset, corr_type) - - return { - "correlation_results": correlation_results, - "this_trait": this_trait.name, - "target_dataset": start_vars['corr_dataset'], - "return_results": corr_return_results - } + return compute_correlation_rust( + start_vars, corr_type, method, corr_return_results, compute_all) def compute_corr_for_top_results(start_vars, diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py index 4a22af72..b4435887 100644 --- a/wqflask/wqflask/correlation/rust_correlation.py +++ b/wqflask/wqflask/correlation/rust_correlation.py @@ -69,62 +69,91 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]: } return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()] +def __compute_sample_corr__( + start_vars: dict, corr_type: str, method: str, n_top: int, + target_trait_info: tuple): + """Compute the sample correlations""" + (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info + all_samples = json.loads(start_vars["sample_vals"]) + sample_data = get_sample_corr_data( + sample_type=start_vars["corr_samples_group"], all_samples=all_samples, + dataset_samples=this_dataset.group.all_samples_ordered()) + target_dataset.get_trait_data(list(sample_data.keys())) + + target_data = [] + for (key, val) in target_dataset.trait_data.items(): + lts = [key] + [str(x) for x in val] + r = ",".join(lts) + target_data.append(r) + + + return run_correlation( + target_data, list(sample_data.values()), method, ",", corr_type, + n_top) + +def __compute_tissue_corr__( + start_vars: dict, corr_type: str, method: str, n_top: int, + target_trait_info: tuple): + """Compute the tissue correlations""" + (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info + trait_symbol_dict = this_dataset.retrieve_genes("Symbol") + corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values( + symbol_list=list(trait_symbol_dict.values())) -def compute_correlation_rust( - start_vars: dict, corr_type: str, method: str = "pearson", - n_top: int = 500): - """function to compute correlation""" - - (this_dataset, this_trait, target_dataset, - sample_data) = create_target_this_trait(start_vars) - - if corr_type == "sample": - - all_samples = json.loads(start_vars["sample_vals"]) - sample_data = get_sample_corr_data(sample_type=start_vars["corr_samples_group"], - all_samples=all_samples, - dataset_samples=this_dataset.group.all_samples_ordered()) + data = parse_tissue_corr_data( + symbol_name=this_trait.symbol, + symbol_dict=get_trait_symbol_and_tissue_values( + symbol_list=[this_trait.symbol]), + dataset_symbols=trait_symbol_dict, + dataset_vals=corr_result_tissue_vals_dict) - target_dataset.get_trait_data(list(sample_data.keys())) + if data: + return run_correlation(data[1], data[0], method, ",", "tissue") + return {} - target_data = [] - for (key, val) in target_dataset.trait_data.items(): - lts = [key] + [str(x) for x in val] - r = ",".join(lts) - target_data.append(r) +def __compute_lit_corr__( + start_vars: dict, corr_type: str, method: str, n_top: int, + target_trait_info: tuple): + """Compute the literature correlations""" + (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info + target_dataset_type = target_dataset.type + this_dataset_type = this_dataset.type + (this_trait_geneid, geneid_dict, species) = do_lit_correlation( + this_trait, this_dataset) + with database_connector() as conn: + return compute_all_lit_correlation( + conn=conn, trait_lists=list(geneid_dict.items()), + species=species, gene_id=this_trait_geneid) + return {} - results = run_correlation( - target_data, list(sample_data.values()), method, ",", corr_type, - n_top) +def compute_correlation_rust( + start_vars: dict, corr_type: str, method: str = "pearson", + n_top: int = 500, compute_all: bool = False): + """function to compute correlation""" + target_trait_info = create_target_this_trait(start_vars) + (this_dataset, this_trait, target_dataset, sample_data) = ( + target_trait_info) + + corr_type_fns = { + "sample": __compute_sample_corr__, + "tissue": __compute_tissue_corr__, + "lit": __compute_lit_corr__ + } + results = corr_type_fns[corr_type]( + start_vars, corr_type, method, n_top, target_trait_info) + top_tissue_results = {} + top_lit_results = {} + if compute_all: # example compute of compute both correlation - top_tissue_results = compute_top_n_tissue(this_dataset,this_trait,results,method) + top_tissue_results = compute_top_n_tissue( + this_dataset,this_trait,results,method) top_lit_results = compute_top_n_lit(results,this_dataset,this_trait) - # merging the results - results = merge_results(results, top_tissue_results, top_lit_results) - - if corr_type == "tissue": - - trait_symbol_dict = this_dataset.retrieve_genes("Symbol") - corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values( - symbol_list=list(trait_symbol_dict.values())) - - data = parse_tissue_corr_data(symbol_name=this_trait.symbol, - symbol_dict=get_trait_symbol_and_tissue_values( - symbol_list=[this_trait.symbol] - ), - dataset_symbols=trait_symbol_dict, - dataset_vals=corr_result_tissue_vals_dict) - - if data: - results = merge_results( - run_correlation(data[1], data[0], method, ",", "tissue"), - {}, {}) - return { - "correlation_results": results, + "correlation_results": merge_results( + results, top_tissue_results, top_lit_results), "this_trait": this_trait.name, "target_dataset": start_vars['corr_dataset'], "return_results": n_top |