From 51a35646f1eb509b2f94204eb490715ce2abcdf2 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Thu, 11 Aug 2022 06:57:11 +0300 Subject: Update format to prevent tissue correlation from failing Update the data format of returned values so that it conforms with expectatitions. --- wqflask/wqflask/correlation/rust_correlation.py | 53 +++++++++++-------------- 1 file changed, 24 insertions(+), 29 deletions(-) (limited to 'wqflask') diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py index 8a5021cc..4a22af72 100644 --- a/wqflask/wqflask/correlation/rust_correlation.py +++ b/wqflask/wqflask/correlation/rust_correlation.py @@ -1,5 +1,6 @@ """module contains integration code for rust-gn3""" import json +from functools import reduce from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values from wqflask.correlation.correlation_gn3_api import create_target_this_trait from wqflask.correlation.correlation_gn3_api import lit_for_trait_list @@ -11,22 +12,21 @@ from gn3.computations.rust_correlation import parse_tissue_corr_data from gn3.db_utils import database_connector -def compute_top_n_lit(corr_results, this_dataset, this_trait): +def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict: (this_trait_geneid, geneid_dict, species) = do_lit_correlation( this_trait, this_dataset) geneid_dict = {trait_name: geneid for (trait_name, geneid) in geneid_dict.items() if corr_results.get(trait_name)} + with database_connector() as conn: + return reduce( + lambda acc, corr: {**acc, **corr}, + compute_all_lit_correlation( + conn=conn, trait_lists=list(geneid_dict.items()), + species=species, gene_id=this_trait_geneid), + {}) - conn = database_connector() - - with conn: - - correlation_results = compute_all_lit_correlation( - conn=conn, trait_lists=list(geneid_dict.items()), - species=species, gene_id=this_trait_geneid) - - return correlation_results + return {} def compute_top_n_tissue(this_dataset, this_trait, traits, method): @@ -55,25 +55,19 @@ def compute_top_n_tissue(this_dataset, this_trait, traits, method): return {} -def merge_results(dict_a, dict_b, dict_c): +def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]: """code to merge diff corr into individual dicts a""" - correlation_results = [] - - for (key, val) in dict_a.items(): - - if key in dict_b: - - dict_a[key].update(dict_b[key]) - - if key in dict_c: - - dict_a[key].update(dict_c[key]) - - correlation_results.append({key: dict_a[key]}) - - return correlation_results + def __merge__(trait_name, trait_corrs): + return { + trait_name: { + **trait_corrs, + **dict_b.get(trait_name, {}), + **dict_c.get(trait_name, {}) + } + } + return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()] def compute_correlation_rust( @@ -109,7 +103,7 @@ def compute_correlation_rust( top_lit_results = compute_top_n_lit(results,this_dataset,this_trait) # merging the results - results = merge_results(results,top_tissue_results,top_lit_results) + results = merge_results(results, top_tissue_results, top_lit_results) if corr_type == "tissue": @@ -125,8 +119,9 @@ def compute_correlation_rust( dataset_vals=corr_result_tissue_vals_dict) if data: - results = run_correlation( - data[1], data[0], method, ",", "tissue") + results = merge_results( + run_correlation(data[1], data[0], method, ",", "tissue"), + {}, {}) return { "correlation_results": results, -- cgit v1.2.3