diff options
Diffstat (limited to 'wqflask')
-rw-r--r-- | wqflask/wqflask/correlation/rust_correlation.py | 124 |
1 files changed, 119 insertions, 5 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py index 3628f549..94720f54 100644 --- a/wqflask/wqflask/correlation/rust_correlation.py +++ b/wqflask/wqflask/correlation/rust_correlation.py @@ -1,6 +1,9 @@ """module contains integration code for rust-gn3""" import json from functools import reduce +from flask import g +from utility.db_tools import mescape +from utility.db_tools import create_in_clause from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values from wqflask.correlation.correlation_gn3_api import create_target_this_trait from wqflask.correlation.correlation_gn3_api import lit_for_trait_list @@ -12,6 +15,106 @@ from gn3.computations.rust_correlation import parse_tissue_corr_data from gn3.db_utils import database_connector + + +def chunk_dataset(dataset,steps,name): + + results = [] + + query = """ + SELECT ProbeSetXRef.DataId,ProbeSet.Name + FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze + WHERE ProbeSetFreeze.Name = '{}' AND + ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND + ProbeSetXRef.ProbeSetId = ProbeSet.Id + """.format(name) + + traits_name_dict = dict(g.db.execute(query).fetchall()) + + + for i in range(0, len(dataset), steps): + matrix = list(dataset[i:i + steps]) + trait_name = traits_name_dict[matrix[0][0]] + + strains = [trait_name] + [str(value) for (trait_name, strain, value) in matrix] + results.append(",".join(strains)) + + breakpoint() + return results + + +def compute_top_n_sample(start_vars, dataset, trait_list): + """only if dataset is of type probeset""" + + + + + def __fetch_sample_ids__(samples_vals, samples_group): + + + all_samples = json.loads(samples_vals) + sample_data = get_sample_corr_data( + sample_type=samples_group, all_samples=all_samples, + dataset_samples=dataset.group.all_samples_ordered()) + + + with database_connector() as conn: + + curr = conn.cursor() + + curr.execute( + """ + SELECT Strain.Name, Strain.Id FROM Strain, Species + WHERE Strain.Name IN {} + and Strain.SpeciesId=Species.Id + and Species.name = '{}' + """.format(create_in_clause(list(sample_data.keys())), + *mescape(dataset.group.species)) + + ) + + return dict(curr.fetchall()) + + + + + + + + + + + + ty = __fetch_sample_ids__(start_vars["sample_vals"], start_vars["corr_samples_group"]) + + + + with database_connector() as conn: + + curr = conn.cursor() + + curr.execute( + + """ + SELECT * from ProbeSetData + where StrainID in {} + and id in (SELECT ProbeSetXRef.DataId + FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) + WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id + and ProbeSetFreeze.Name = '{}' + and ProbeSet.Name in {} + and ProbeSet.Id = ProbeSetXRef.ProbeSetId) + """.format(create_in_clause(list(ty.values())),dataset.name,create_in_clause(trait_list)) + + + ) + + + + + return chunk_dataset(list(curr.fetchall()),len(ty.values()),dataset.name) + + def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict: (this_trait_geneid, geneid_dict, species) = do_lit_correlation( this_trait, this_dataset) @@ -69,6 +172,7 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]: } return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()] + def __compute_sample_corr__( start_vars: dict, corr_type: str, method: str, n_top: int, target_trait_info: tuple): @@ -86,11 +190,11 @@ def __compute_sample_corr__( r = ",".join(lts) target_data.append(r) - return run_correlation( target_data, list(sample_data.values()), method, ",", corr_type, n_top) + def __compute_tissue_corr__( start_vars: dict, corr_type: str, method: str, n_top: int, target_trait_info: tuple): @@ -111,6 +215,7 @@ def __compute_tissue_corr__( return run_correlation(data[1], data[0], method, ",", "tissue") return {} + def __compute_lit_corr__( start_vars: dict, corr_type: str, method: str, n_top: int, target_trait_info: tuple): @@ -127,6 +232,7 @@ def __compute_lit_corr__( species=species, gene_id=this_trait_geneid) return {} + def compute_correlation_rust( start_vars: dict, corr_type: str, method: str = "pearson", n_top: int = 500, compute_all: bool = False): @@ -135,7 +241,7 @@ def compute_correlation_rust( (this_dataset, this_trait, target_dataset, sample_data) = ( target_trait_info) - ## Replace this with `match ...` once we hit Python 3.10 + # Replace this with `match ...` once we hit Python 3.10 corr_type_fns = { "sample": __compute_sample_corr__, "tissue": __compute_tissue_corr__, @@ -143,15 +249,23 @@ def compute_correlation_rust( } results = corr_type_fns[corr_type]( start_vars, corr_type, method, n_top, target_trait_info) - ## END: Replace this with `match ...` once we hit Python 3.10 + # END: Replace this with `match ...` once we hit Python 3.10 top_tissue_results = {} top_lit_results = {} + + + results = compute_top_n_sample(start_vars,target_dataset,list(results.keys())) + + + + breakpoint() + if compute_all: # example compute of compute both correlation top_tissue_results = compute_top_n_tissue( - this_dataset,this_trait,results,method) - top_lit_results = compute_top_n_lit(results,this_dataset,this_trait) + this_dataset, this_trait, results, method) + top_lit_results = compute_top_n_lit(results, this_dataset, this_trait) return { "correlation_results": merge_results( |