diff options
-rw-r--r-- | wqflask/wqflask/correlation/rust_correlation.py | 38 |
1 files changed, 14 insertions, 24 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py index 18f8c622..0661fa42 100644 --- a/wqflask/wqflask/correlation/rust_correlation.py +++ b/wqflask/wqflask/correlation/rust_correlation.py @@ -15,6 +15,9 @@ from wqflask.correlation.pre_computes import read_text_file from wqflask.correlation.pre_computes import write_db_to_textfile from wqflask.correlation.pre_computes import read_trait_metadata from wqflask.correlation.pre_computes import cache_trait_metadata +from wqflask.correlation.pre_computes import parse_lmdb_dataset + +from wqflask.correlation.pre_computes import read_lmdb_strain_files from gn3.computations.correlations import compute_all_lit_correlation from gn3.computations.rust_correlation import run_correlation from gn3.computations.rust_correlation import get_sample_corr_data @@ -30,7 +33,7 @@ def query_probes_metadata(dataset, trait_list): if not bool(trait_list) or dataset.type != "ProbeSet": return [] - with database_connection(SQL_URI) as conn: + with database_connection() as conn: with conn.cursor() as cursor: query = """ @@ -103,7 +106,7 @@ def chunk_dataset(dataset, steps, name): ProbeSetXRef.ProbeSetId = ProbeSet.Id """.format(name) - with database_connection(SQL_URI) as conn: + with database_connection() as conn: with conn.cursor() as curr: curr.execute(query) traits_name_dict = dict(curr.fetchall()) @@ -127,7 +130,7 @@ def compute_top_n_sample(start_vars, dataset, trait_list): sample_data=json.loads(samples_vals), dataset_samples=dataset.group.all_samples_ordered()) - with database_connection(SQL_URI) as conn: + with database_connection() as conn: with conn.cursor() as curr: curr.execute( """ @@ -145,7 +148,7 @@ def compute_top_n_sample(start_vars, dataset, trait_list): if len(trait_list) == 0: return {} - with database_connection(SQL_URI) as conn: + with database_connection() as conn: with conn.cursor() as curr: # fetching strain data in bulk query = ( @@ -181,7 +184,7 @@ def compute_top_n_lit(corr_results, target_dataset, this_trait) -> dict: geneid_dict = {trait_name: geneid for (trait_name, geneid) in geneid_dict.items() if corr_results.get(trait_name)} - with database_connection(SQL_URI) as conn: + with database_connection() as conn: return reduce( lambda acc, corr: {**acc, **corr}, compute_all_lit_correlation( @@ -253,26 +256,13 @@ def __compute_sample_corr__( if not bool(sample_data): return {} - if target_dataset.type == "ProbeSet" and start_vars.get("use_cache") == "true": - with database_connection(SQL_URI) as conn: - file_path = fetch_text_file(target_dataset.name, conn) - if file_path: - (sample_vals, target_data) = read_text_file( - sample_data, file_path) - + with database_connection() as conn: + results = read_lmdb_strain_files("ProbeSets",target_dataset.name) + if results: + (sample_vals,target_data) = parse_lmdb_dataset(results[0],sample_data,results[1]) return run_correlation(target_data, sample_vals, - method, ",", corr_type, n_top) - - write_db_to_textfile(target_dataset.name, conn) - file_path = fetch_text_file(target_dataset.name, conn) - if file_path: - (sample_vals, target_data) = read_text_file( - sample_data, file_path) - - return run_correlation(target_data, sample_vals, - method, ",", corr_type, n_top) - + method, ",", corr_type, n_top) target_dataset.get_trait_data(list(sample_data.keys())) def __merge_key_and_values__(rows, current): @@ -336,7 +326,7 @@ def __compute_lit_corr__( (this_trait_geneid, geneid_dict, species) = do_lit_correlation( this_trait, target_dataset) - with database_connection(SQL_URI) as conn: + with database_connection() as conn: return reduce( lambda acc, lit: {**acc, **lit}, compute_all_lit_correlation( |