diff options
-rw-r--r-- | wqflask/wqflask/correlation/rust_correlation.py | 71 |
1 files changed, 29 insertions, 42 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py index 37e2ba76..5c22efbf 100644 --- a/wqflask/wqflask/correlation/rust_correlation.py +++ b/wqflask/wqflask/correlation/rust_correlation.py @@ -28,11 +28,9 @@ def chunk_dataset(dataset, steps, name): """.format(name) with database_connector() as conn: - curr = conn.cursor() - - curr.execute(query) - - traits_name_dict = dict(curr.fetchall()) + with conn.cursor() as curr: + curr.execute(query) + traits_name_dict = dict(curr.fetchall()) for i in range(0, len(dataset), steps): matrix = list(dataset[i:i + steps]) @@ -52,60 +50,49 @@ def compute_top_n_sample(start_vars, dataset, trait_list): return {} def __fetch_sample_ids__(samples_vals, samples_group): - all_samples = json.loads(samples_vals) sample_data = get_sample_corr_data( sample_type=samples_group, all_samples=all_samples, dataset_samples=dataset.group.all_samples_ordered()) with database_connector() as conn: - - curr = conn.cursor() - - curr.execute( - """ + with conn.cursor() as curr: + curr.execute( + """ SELECT Strain.Name, Strain.Id FROM Strain, Species WHERE Strain.Name IN {} and Strain.SpeciesId=Species.Id and Species.name = '{}' """.format(create_in_clause(list(sample_data.keys())), - *mescape(dataset.group.species)) - - ) - - return (sample_data, dict(curr.fetchall())) + *mescape(dataset.group.species))) + return (sample_data, dict(curr.fetchall())) (sample_data, sample_ids) = __fetch_sample_ids__( start_vars["sample_vals"], start_vars["corr_samples_group"]) with database_connector() as conn: + with conn.cursor() as curr: + # fetching strain data in bulk + curr.execute( + """ + SELECT * from ProbeSetData + where StrainID in {} + and id in (SELECT ProbeSetXRef.DataId + FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) + WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id + and ProbeSetFreeze.Name = '{}' + and ProbeSet.Name in {} + and ProbeSet.Id = ProbeSetXRef.ProbeSetId) + """.format( + create_in_clause(list(sample_ids.values())), + dataset.name, + create_in_clause(trait_list))) + + corr_data = chunk_dataset( + list(curr.fetchall()), len(sample_ids.values()), dataset.name) - curr = conn.cursor() - - # fetching strain data in bulk - - curr.execute( - - """ - SELECT * from ProbeSetData - where StrainID in {} - and id in (SELECT ProbeSetXRef.DataId - FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) - WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id - and ProbeSetFreeze.Name = '{}' - and ProbeSet.Name in {} - and ProbeSet.Id = ProbeSetXRef.ProbeSetId) - """.format(create_in_clause(list(sample_ids.values())), dataset.name, create_in_clause(trait_list)) - - - ) - - corr_data = chunk_dataset(list(curr.fetchall()), len( - sample_ids.values()), dataset.name) - - return run_correlation(corr_data, - list(sample_data.values()), - "pearson", ",") + return run_correlation( + corr_data, list(sample_data.values()), "pearson", ",") def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict: |