diff options
-rw-r--r-- | wqflask/wqflask/correlation/rust_correlation.py | 63 |
1 files changed, 30 insertions, 33 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py index 95354994..7d796e70 100644 --- a/wqflask/wqflask/correlation/rust_correlation.py +++ b/wqflask/wqflask/correlation/rust_correlation.py @@ -3,7 +3,8 @@ import json from functools import reduce from utility.db_tools import mescape from utility.db_tools import create_in_clause -from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values +from wqflask.correlation.correlation_functions\ + import get_trait_symbol_and_tissue_values from wqflask.correlation.correlation_gn3_api import create_target_this_trait from wqflask.correlation.correlation_gn3_api import lit_for_trait_list from wqflask.correlation.correlation_gn3_api import do_lit_correlation @@ -14,9 +15,7 @@ from gn3.computations.rust_correlation import parse_tissue_corr_data from gn3.db_utils import database_connector - - -def chunk_dataset(dataset,steps,name): +def chunk_dataset(dataset, steps, name): results = [] @@ -39,7 +38,8 @@ def chunk_dataset(dataset,steps,name): matrix = list(dataset[i:i + steps]) trait_name = traits_name_dict[matrix[0][0]] - strains = [trait_name] + [str(value) for (trait_name, strain, value) in matrix] + strains = [trait_name] + [str(value) + for (trait_name, strain, value) in matrix] results.append(",".join(strains)) return results @@ -48,18 +48,16 @@ def chunk_dataset(dataset,steps,name): def compute_top_n_sample(start_vars, dataset, trait_list): """check if dataset is of type probeset""" - if dataset.type.lower()!= "probeset": - return {} + if dataset.type.lower() != "probeset": + return {} def __fetch_sample_ids__(samples_vals, samples_group): - all_samples = json.loads(samples_vals) sample_data = get_sample_corr_data( sample_type=samples_group, all_samples=all_samples, dataset_samples=dataset.group.all_samples_ordered()) - with database_connector() as conn: curr = conn.cursor() @@ -75,21 +73,20 @@ def compute_top_n_sample(start_vars, dataset, trait_list): ) - return (sample_data,dict(curr.fetchall())) - - (sample_data,sample_ids) = __fetch_sample_ids__(start_vars["sample_vals"], start_vars["corr_samples_group"]) - + return (sample_data, dict(curr.fetchall())) + (sample_data, sample_ids) = __fetch_sample_ids__( + start_vars["sample_vals"], start_vars["corr_samples_group"]) with database_connector() as conn: curr = conn.cursor() - #fetching strain data in bulk + # fetching strain data in bulk curr.execute( - """ + """ SELECT * from ProbeSetData where StrainID in {} and id in (SELECT ProbeSetXRef.DataId @@ -98,21 +95,25 @@ def compute_top_n_sample(start_vars, dataset, trait_list): and ProbeSetFreeze.Name = '{}' and ProbeSet.Name in {} and ProbeSet.Id = ProbeSetXRef.ProbeSetId) - """.format(create_in_clause(list(sample_ids.values())),dataset.name,create_in_clause(trait_list)) + """.format(create_in_clause(list(sample_ids.values())), dataset.name, create_in_clause(trait_list)) ) - corr_data = chunk_dataset(list(curr.fetchall()),len(sample_ids.values()),dataset.name) + corr_data = chunk_dataset(list(curr.fetchall()), len( + sample_ids.values()), dataset.name) - return run_correlation(corr_data,list(sample_data.values()),"pearson",",") + return run_correlation(corr_data, + list(sample_data.values()), + "pearson", ",") def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict: (this_trait_geneid, geneid_dict, species) = do_lit_correlation( this_trait, this_dataset) - geneid_dict = {trait_name: geneid for (trait_name, geneid) in geneid_dict.items() if + geneid_dict = {trait_name: geneid for (trait_name, geneid) + in geneid_dict.items() if corr_results.get(trait_name)} with database_connector() as conn: return reduce( @@ -166,8 +167,6 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]: return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()] - - def __compute_sample_corr__( start_vars: dict, corr_type: str, method: str, n_top: int, target_trait_info: tuple): @@ -247,7 +246,6 @@ def compute_correlation_rust( # END: Replace this with `match ...` once we hit Python 3.10 - top_a = top_b = {} if compute_all: @@ -255,28 +253,27 @@ def compute_correlation_rust( if corr_type == "sample": top_a = compute_top_n_tissue( - this_dataset, this_trait, results, method) - - top_b = compute_top_n_lit(results, this_dataset, this_trait) + this_dataset, this_trait, results, method) + top_b = compute_top_n_lit(results, this_dataset, this_trait) elif corr_type == "lit": - #currently fails for lit + # currently fails for lit - top_a = compute_top_n_sample(start_vars,target_dataset,list(results.keys())) - top_b = compute_top_n_tissue( - this_dataset, this_trait, results, method) + top_a = compute_top_n_sample( + start_vars, target_dataset, list(results.keys())) + top_b = compute_top_n_tissue( + this_dataset, this_trait, results, method) else: - top_a = compute_top_n_sample(start_vars,target_dataset,list(results.keys())) + top_a = compute_top_n_sample( + start_vars, target_dataset, list(results.keys())) top_b = compute_top_n_lit(results, this_dataset, this_trait) - - - return { + return { "correlation_results": merge_results( results, top_a, top_b), "this_trait": this_trait.name, |