From 3c1cb6a94b64dae28c62f481e1f4499f8f5b89e7 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Thu, 11 Aug 2022 12:30:18 +0300 Subject: Refactor: separate the three correlation types Refactor the code such that each correlation type (sample, tissue, literature) is computed in its own function. This makes the code clearer, and helps reduce repetition. --- wqflask/wqflask/correlation/correlation_gn3_api.py | 39 +------ wqflask/wqflask/correlation/rust_correlation.py | 121 +++++++++++++-------- 2 files changed, 78 insertions(+), 82 deletions(-) (limited to 'wqflask') diff --git a/wqflask/wqflask/correlation/correlation_gn3_api.py b/wqflask/wqflask/correlation/correlation_gn3_api.py index 6df4eafe..1a375501 100644 --- a/wqflask/wqflask/correlation/correlation_gn3_api.py +++ b/wqflask/wqflask/correlation/correlation_gn3_api.py @@ -194,46 +194,13 @@ def compute_correlation(start_vars, method="pearson", compute_all=False): method -- Correlation method to be used (pearson, spearman, or bicor) compute_all -- Include sample, tissue, and literature correlations (when applicable) """ - # pylint: disable-msg=too-many-locals + from wqflask.correlation.rust_correlation import compute_correlation_rust corr_type = start_vars['corr_type'] - method = start_vars['corr_sample_method'] corr_return_results = int(start_vars.get("corr_return_results", 100)) - corr_input_data = {} - - from wqflask.correlation.rust_correlation import compute_correlation_rust - rust_correlation_results = compute_correlation_rust( - start_vars, corr_type, method, corr_return_results) - correlation_results = rust_correlation_results["correlation_results"] - - if corr_type == "lit":# elif corr_type == "lit": - (this_dataset, this_trait, target_dataset, - sample_data) = create_target_this_trait(start_vars) - target_dataset_type = target_dataset.type - this_dataset_type = this_dataset.type - (this_trait_geneid, geneid_dict, species) = do_lit_correlation( - this_trait, this_dataset) - - conn = database_connector() - with conn: - correlation_results = compute_all_lit_correlation( - conn=conn, trait_lists=list(geneid_dict.items()), - species=species, gene_id=this_trait_geneid) - - correlation_results = correlation_results[0:corr_return_results] - - if (compute_all): - correlation_results = compute_corr_for_top_results( - start_vars, correlation_results, this_trait, this_dataset, - target_dataset, corr_type) - - return { - "correlation_results": correlation_results, - "this_trait": this_trait.name, - "target_dataset": start_vars['corr_dataset'], - "return_results": corr_return_results - } + return compute_correlation_rust( + start_vars, corr_type, method, corr_return_results, compute_all) def compute_corr_for_top_results(start_vars, diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py index 4a22af72..b4435887 100644 --- a/wqflask/wqflask/correlation/rust_correlation.py +++ b/wqflask/wqflask/correlation/rust_correlation.py @@ -69,62 +69,91 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]: } return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()] +def __compute_sample_corr__( + start_vars: dict, corr_type: str, method: str, n_top: int, + target_trait_info: tuple): + """Compute the sample correlations""" + (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info + all_samples = json.loads(start_vars["sample_vals"]) + sample_data = get_sample_corr_data( + sample_type=start_vars["corr_samples_group"], all_samples=all_samples, + dataset_samples=this_dataset.group.all_samples_ordered()) + target_dataset.get_trait_data(list(sample_data.keys())) + + target_data = [] + for (key, val) in target_dataset.trait_data.items(): + lts = [key] + [str(x) for x in val] + r = ",".join(lts) + target_data.append(r) + + + return run_correlation( + target_data, list(sample_data.values()), method, ",", corr_type, + n_top) + +def __compute_tissue_corr__( + start_vars: dict, corr_type: str, method: str, n_top: int, + target_trait_info: tuple): + """Compute the tissue correlations""" + (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info + trait_symbol_dict = this_dataset.retrieve_genes("Symbol") + corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values( + symbol_list=list(trait_symbol_dict.values())) -def compute_correlation_rust( - start_vars: dict, corr_type: str, method: str = "pearson", - n_top: int = 500): - """function to compute correlation""" - - (this_dataset, this_trait, target_dataset, - sample_data) = create_target_this_trait(start_vars) - - if corr_type == "sample": - - all_samples = json.loads(start_vars["sample_vals"]) - sample_data = get_sample_corr_data(sample_type=start_vars["corr_samples_group"], - all_samples=all_samples, - dataset_samples=this_dataset.group.all_samples_ordered()) + data = parse_tissue_corr_data( + symbol_name=this_trait.symbol, + symbol_dict=get_trait_symbol_and_tissue_values( + symbol_list=[this_trait.symbol]), + dataset_symbols=trait_symbol_dict, + dataset_vals=corr_result_tissue_vals_dict) - target_dataset.get_trait_data(list(sample_data.keys())) + if data: + return run_correlation(data[1], data[0], method, ",", "tissue") + return {} - target_data = [] - for (key, val) in target_dataset.trait_data.items(): - lts = [key] + [str(x) for x in val] - r = ",".join(lts) - target_data.append(r) +def __compute_lit_corr__( + start_vars: dict, corr_type: str, method: str, n_top: int, + target_trait_info: tuple): + """Compute the literature correlations""" + (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info + target_dataset_type = target_dataset.type + this_dataset_type = this_dataset.type + (this_trait_geneid, geneid_dict, species) = do_lit_correlation( + this_trait, this_dataset) + with database_connector() as conn: + return compute_all_lit_correlation( + conn=conn, trait_lists=list(geneid_dict.items()), + species=species, gene_id=this_trait_geneid) + return {} - results = run_correlation( - target_data, list(sample_data.values()), method, ",", corr_type, - n_top) +def compute_correlation_rust( + start_vars: dict, corr_type: str, method: str = "pearson", + n_top: int = 500, compute_all: bool = False): + """function to compute correlation""" + target_trait_info = create_target_this_trait(start_vars) + (this_dataset, this_trait, target_dataset, sample_data) = ( + target_trait_info) + + corr_type_fns = { + "sample": __compute_sample_corr__, + "tissue": __compute_tissue_corr__, + "lit": __compute_lit_corr__ + } + results = corr_type_fns[corr_type]( + start_vars, corr_type, method, n_top, target_trait_info) + top_tissue_results = {} + top_lit_results = {} + if compute_all: # example compute of compute both correlation - top_tissue_results = compute_top_n_tissue(this_dataset,this_trait,results,method) + top_tissue_results = compute_top_n_tissue( + this_dataset,this_trait,results,method) top_lit_results = compute_top_n_lit(results,this_dataset,this_trait) - # merging the results - results = merge_results(results, top_tissue_results, top_lit_results) - - if corr_type == "tissue": - - trait_symbol_dict = this_dataset.retrieve_genes("Symbol") - corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values( - symbol_list=list(trait_symbol_dict.values())) - - data = parse_tissue_corr_data(symbol_name=this_trait.symbol, - symbol_dict=get_trait_symbol_and_tissue_values( - symbol_list=[this_trait.symbol] - ), - dataset_symbols=trait_symbol_dict, - dataset_vals=corr_result_tissue_vals_dict) - - if data: - results = merge_results( - run_correlation(data[1], data[0], method, ",", "tissue"), - {}, {}) - return { - "correlation_results": results, + "correlation_results": merge_results( + results, top_tissue_results, top_lit_results), "this_trait": this_trait.name, "target_dataset": start_vars['corr_dataset'], "return_results": n_top -- cgit v1.2.3