diff options
author | Alexander_Kabui | 2024-01-02 13:21:07 +0300 |
---|---|---|
committer | Alexander_Kabui | 2024-01-02 13:21:07 +0300 |
commit | 70c4201b332e0e2c0d958428086512f291469b87 (patch) | |
tree | aea4fac8782c110fc233c589c3f0f7bd34bada6c /gn2/wqflask/correlation/rust_correlation.py | |
parent | 5092eb42f062b1695c4e39619f0bd74a876cfac2 (diff) | |
parent | 965ce5114d585624d5edb082c710b83d83a3be40 (diff) | |
download | genenetwork2-70c4201b332e0e2c0d958428086512f291469b87.tar.gz |
merge changes
Diffstat (limited to 'gn2/wqflask/correlation/rust_correlation.py')
-rw-r--r-- | gn2/wqflask/correlation/rust_correlation.py | 408 |
1 files changed, 408 insertions, 0 deletions
# gn2/wqflask/correlation/rust_correlation.py
"""module contains integration code for rust-gn3"""
import json
from functools import reduce

from gn2.utility.tools import SQL_URI
from gn2.utility.db_tools import mescape
from gn2.utility.db_tools import create_in_clause
from gn2.wqflask.correlation.correlation_functions\
    import get_trait_symbol_and_tissue_values
from gn2.wqflask.correlation.correlation_gn3_api import create_target_this_trait
from gn2.wqflask.correlation.correlation_gn3_api import lit_for_trait_list
from gn2.wqflask.correlation.correlation_gn3_api import do_lit_correlation
from gn2.wqflask.correlation.pre_computes import fetch_text_file
from gn2.wqflask.correlation.pre_computes import read_text_file
from gn2.wqflask.correlation.pre_computes import write_db_to_textfile
from gn2.wqflask.correlation.pre_computes import read_trait_metadata
from gn2.wqflask.correlation.pre_computes import cache_trait_metadata
from gn3.computations.correlations import compute_all_lit_correlation
from gn3.computations.rust_correlation import run_correlation
from gn3.computations.rust_correlation import get_sample_corr_data
from gn3.computations.rust_correlation import parse_tissue_corr_data
from gn3.db_utils import database_connection

from gn2.wqflask.correlation.exceptions import WrongCorrelationType


def query_probes_metadata(dataset, trait_list):
    """Query traits metadata in bulk for a ProbeSet dataset.

    Returns an empty list when `trait_list` is empty or when `dataset` is
    not of type "ProbeSet"; otherwise returns the rows fetched by the cursor,
    one per matching trait, in the column order of the SELECT below.
    """
    if not trait_list or dataset.type != "ProbeSet":
        return []

    # One "%s" placeholder per trait name; the values themselves are passed
    # to `execute` so the driver does the quoting.
    trait_placeholders = ", ".join(["%s"] * len(trait_list))
    query = """
        SELECT ProbeSet.Name,ProbeSet.Chr,ProbeSet.Mb,
        ProbeSet.Symbol,ProbeSetXRef.mean,
        CONCAT_WS('; ', ProbeSet.description, ProbeSet.Probe_Target_Description) AS description,
        ProbeSetXRef.additive,ProbeSetXRef.LRS,Geno.Chr, Geno.Mb
        FROM ProbeSet INNER JOIN ProbeSetXRef
        ON ProbeSet.Id=ProbeSetXRef.ProbeSetId
        INNER JOIN Geno
        ON ProbeSetXRef.Locus = Geno.Name
        INNER JOIN Species
        ON Geno.SpeciesId = Species.Id
        WHERE ProbeSet.Name in ({}) AND
        Species.Name = %s AND
        ProbeSetXRef.ProbeSetFreezeId IN (
        SELECT ProbeSetFreeze.Id
        FROM ProbeSetFreeze WHERE ProbeSetFreeze.Name = %s)
        """.format(trait_placeholders)

    with database_connection(SQL_URI) as conn:
        with conn.cursor() as cursor:
            cursor.execute(
                query,
                tuple(trait_list) + (dataset.group.species, dataset.name))
            return cursor.fetchall()


def get_metadata(dataset, traits):
    """Retrieve the metadata for `traits`, using the per-dataset cache.

    Traits missing from the cache are fetched with `query_probes_metadata`,
    merged with the cached entries (cached entries win on key clashes) and
    the merged mapping is written back to the cache.  The returned mapping
    contains every cached entry, not only the requested traits.
    """
    def __location__(probe_chr, probe_mb):
        if probe_mb:
            return f"Chr{probe_chr}: {probe_mb:.6f}"
        return f"Chr{probe_chr}: ???"

    cached_metadata = read_trait_metadata(dataset.name)
    to_fetch_metadata = list(set(traits).difference(cached_metadata))
    if not to_fetch_metadata:
        return cached_metadata

    fetched_metadata = {
        trait_name: {
            "name": trait_name,
            "view": True,
            "symbol": symbol,
            "dataset": dataset.name,
            "dataset_name": dataset.shortname,
            "mean": mean,
            "description": description,
            "additive": additive,
            "lrs_score": f"{lrs:3.1f}" if lrs else "",
            "location": __location__(probe_chr, probe_mb),
            "chr": probe_chr,
            "mb": probe_mb,
            # Falls back to the default format when `mb` is falsy.
            "lrs_location": f'Chr{chr_score}: {mb:{".6f" if mb else ""}}',
            "lrs_chr": chr_score,
            "lrs_mb": mb
        }
        for (trait_name, probe_chr, probe_mb, symbol, mean, description,
             additive, lrs, chr_score, mb)
        in query_probes_metadata(dataset, to_fetch_metadata)}

    results = {**fetched_metadata, **cached_metadata}
    cache_trait_metadata(dataset.name, results)
    return results


def chunk_dataset(dataset, steps, name):
    """Group flat data rows into one `[trait_name, value, ...]` row per trait.

    `dataset` is a sequence of 3-item rows whose first item is a data id and
    whose last item is the value; it is cut into consecutive slices of
    `steps` rows, and the first row of each slice supplies the data id that
    is looked up to get the trait name.  `name` is the ProbeSetFreeze name
    used to build the data-id -> trait-name lookup.

    NOTE(review): this assumes the rows of each trait are contiguous and
    that every trait has exactly `steps` rows -- confirm against the caller.
    """
    if not dataset:
        return []

    # The dataset name is passed as a bound parameter rather than being
    # interpolated into the SQL text.
    query = """
        SELECT ProbeSetXRef.DataId,ProbeSet.Name
        FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze
        WHERE ProbeSetFreeze.Name = %s AND
        ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND
        ProbeSetXRef.ProbeSetId = ProbeSet.Id
        """

    with database_connection(SQL_URI) as conn:
        with conn.cursor() as curr:
            curr.execute(query, (name,))
            traits_name_dict = dict(curr.fetchall())

    results = []
    for start in range(0, len(dataset), steps):
        chunk = list(dataset[start:start + steps])
        results.append(
            [traits_name_dict[chunk[0][0]]]
            + [str(value) for (_data_id, _strain, value) in chunk])
    return results


def compute_top_n_sample(start_vars, dataset, trait_list):
    """Compute pearson sample correlations for `trait_list` only.

    Returns `{}` when `dataset` is not a ProbeSet dataset, when `trait_list`
    is empty, or when no sample values / strain ids are available; otherwise
    returns the result of `run_correlation`.

    `start_vars` must provide "sample_vals" (a JSON string) and
    "corr_samples_group".
    """
    if dataset.type.lower() != "probeset":
        return {}

    if not trait_list:
        return {}

    sample_data = get_sample_corr_data(
        sample_type=start_vars["corr_samples_group"],
        sample_data=json.loads(start_vars["sample_vals"]),
        dataset_samples=dataset.group.all_samples_ordered())
    if not sample_data:
        # An empty name list would render an invalid `IN ()` clause.
        return {}

    sample_names = list(sample_data.keys())

    with database_connection(SQL_URI) as conn:
        with conn.cursor() as curr:
            # Sample names and the species are bound parameters instead of
            # being escaped and interpolated into the SQL text.
            curr.execute(
                "SELECT Strain.Name, Strain.Id FROM Strain, Species "
                "WHERE Strain.Name IN "
                f"({', '.join(['%s'] * len(sample_names))}) "
                "AND Strain.SpeciesId = Species.Id "
                "AND Species.name = %s",
                tuple(sample_names) + (dataset.group.species,))
            sample_ids = dict(curr.fetchall())

            if not sample_ids:
                return {}

            # fetching strain data in bulk
            query = (
                "SELECT * from ProbeSetData "
                f"WHERE StrainID IN ({', '.join(['%s'] * len(sample_ids))}) "
                "AND Id IN ("
                "  SELECT ProbeSetXRef.DataId "
                "  FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) "
                "  WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id "
                "  AND ProbeSetFreeze.Name = %s "
                "  AND ProbeSet.Name "
                f" IN ({', '.join(['%s'] * len(trait_list))}) "
                "  AND ProbeSet.Id = ProbeSetXRef.ProbeSetId"
                ")")
            curr.execute(
                query,
                tuple(sample_ids.values()) + (dataset.name,)
                + tuple(trait_list))
            strain_rows = list(curr.fetchall())

    # NOTE(review): the y-values below come from `sample_data` while the
    # rows are limited to strains found in `sample_ids`; if a sample name has
    # no Strain row the lengths may differ -- confirm this cannot happen.
    corr_data = chunk_dataset(strain_rows, len(sample_ids), dataset.name)

    return run_correlation(
        corr_data, list(sample_data.values()), "pearson", ",")


def compute_top_n_lit(corr_results, target_dataset, this_trait) -> dict:
    """Compute literature correlations for the traits in `corr_results`.

    Returns `{}` when the trait's dataset and `target_dataset` are not
    compatible for literature correlation.  Only traits with a truthy entry
    in `corr_results` are considered.
    """
    if not __datasets_compatible_p__(this_trait.dataset, target_dataset, "lit"):
        return {}

    (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
        this_trait, target_dataset)

    geneid_dict = {
        trait_name: geneid
        for (trait_name, geneid) in geneid_dict.items()
        if corr_results.get(trait_name)}

    with database_connection(SQL_URI) as conn:
        # Fold the sequence of per-trait dicts into a single dict.
        return reduce(
            lambda acc, corr: {**acc, **corr},
            compute_all_lit_correlation(
                conn=conn, trait_lists=list(geneid_dict.items()),
                species=species, gene_id=this_trait_geneid),
            {})


def __tissue_corr_inputs__(this_trait, trait_symbol_dict):
    """Fetch tissue values and parse them into tissue-correlation inputs.

    Returns whatever `parse_tissue_corr_data` returns for `this_trait`'s
    symbol and the trait -> symbol mapping `trait_symbol_dict`.
    """
    corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
        symbol_list=list(trait_symbol_dict.values()))

    return parse_tissue_corr_data(
        symbol_name=this_trait.symbol,
        symbol_dict=get_trait_symbol_and_tissue_values(
            symbol_list=[this_trait.symbol]),
        dataset_symbols=trait_symbol_dict,
        dataset_vals=corr_result_tissue_vals_dict)


def compute_top_n_tissue(target_dataset, this_trait, traits, method):
    """Compute tissue correlations for the traits in `traits` only.

    Returns `{}` when the datasets are not compatible for tissue
    correlation or when no usable tissue data is found.
    """
    if not __datasets_compatible_p__(this_trait.dataset, target_dataset, "tissue"):
        return {}

    trait_symbol_dict = {
        trait_name: symbol
        for (trait_name, symbol)
        in target_dataset.retrieve_genes("Symbol").items()
        if traits.get(trait_name)}

    data = __tissue_corr_inputs__(this_trait, trait_symbol_dict)

    if data and data[0]:
        return run_correlation(
            data[1], data[0], method, ",", "tissue")

    return {}


def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
    """Merge the different correlation results into one dict per trait.

    For every trait in `dict_a`, produce `{trait_name: merged}` where
    `merged` is the `dict_a` entry updated with the matching entries (if
    any) from `dict_b` and then `dict_c`.  Traits that appear only in
    `dict_b` or `dict_c` are ignored.
    """
    def __merge__(trait_name, trait_corrs):
        return {
            trait_name: {
                **trait_corrs,
                **dict_b.get(trait_name, {}),
                **dict_c.get(trait_name, {})
            }
        }
    return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]


def __compute_sample_corr__(
        start_vars: dict, corr_type: str, method: str, n_top: int,
        target_trait_info: tuple):
    """Compute the sample correlations.

    `target_trait_info` is the 4-tuple
    `(this_dataset, this_trait, target_dataset, sample_data)`; the sample
    data in the tuple is ignored and recomputed from `start_vars`.
    Returns `{}` when there is no sample data or no target data.
    """
    (this_dataset, _this_trait, target_dataset, _sample_data) = (
        target_trait_info)

    # NOTE: `+=` extends the group's sample list in place, exactly as the
    # original code did.
    if this_dataset.group.f1list is not None:
        this_dataset.group.samplelist += this_dataset.group.f1list

    if this_dataset.group.parlist is not None:
        this_dataset.group.samplelist += this_dataset.group.parlist

    sample_data = get_sample_corr_data(
        sample_type=start_vars["corr_samples_group"],
        sample_data=json.loads(start_vars["sample_vals"]),
        dataset_samples=this_dataset.group.all_samples_ordered())

    if not sample_data:
        return {}

    if target_dataset.type == "ProbeSet" and start_vars.get("use_cache") == "true":
        with database_connection(SQL_URI) as conn:
            file_path = fetch_text_file(target_dataset.name, conn)
            if not file_path:
                # No cached text file yet: generate it, then look again.
                write_db_to_textfile(target_dataset.name, conn)
                file_path = fetch_text_file(target_dataset.name, conn)
            if file_path:
                (sample_vals, target_data) = read_text_file(
                    sample_data, file_path)

                return run_correlation(target_data, sample_vals,
                                       method, ",", corr_type, n_top)

    # Fall back to loading the trait data through the dataset object.
    target_dataset.get_trait_data(list(sample_data.keys()))

    target_data = []
    for trait_name, values in target_dataset.trait_data.items():
        row_values = list(values)
        if row_values:
            target_data.append([trait_name] + row_values)

    if not target_data:
        return {}

    return run_correlation(
        target_data, list(sample_data.values()), method, ",", corr_type,
        n_top)


def __datasets_compatible_p__(trait_dataset, target_dataset, corr_method):
    """Return False only for tissue/literature correlation of a ProbeSet
    trait against a "Publish" or "Geno" target dataset; True otherwise."""
    return not (
        corr_method in ("tissue", "Tissue r", "Literature r", "lit")
        and (trait_dataset.type == "ProbeSet" and
             target_dataset.type in ("Publish", "Geno")))


def __compute_tissue_corr__(
        start_vars: dict, corr_type: str, method: str, n_top: int,
        target_trait_info: tuple):
    """Compute the tissue correlations.

    Raises `WrongCorrelationType` when the datasets are not compatible.
    Returns `{}` when no usable tissue data is found.  `start_vars` and
    `n_top` are accepted for signature parity with the other
    `__compute_*_corr__` functions and are not used here.
    """
    (this_dataset, this_trait, target_dataset, _sample_data) = (
        target_trait_info)
    if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type):
        raise WrongCorrelationType(this_trait, target_dataset, corr_type)

    data = __tissue_corr_inputs__(
        this_trait, target_dataset.retrieve_genes("Symbol"))

    # Same guard as `compute_top_n_tissue`: skip the run when the primary
    # trait has no tissue values.
    if data and data[0]:
        return run_correlation(data[1], data[0], method, ",", "tissue")
    return {}


def __compute_lit_corr__(
        start_vars: dict, corr_type: str, method: str, n_top: int,
        target_trait_info: tuple):
    """Compute the literature correlations.

    Raises `WrongCorrelationType` when the datasets are not compatible.
    Only the first `n_top` entries returned by
    `compute_all_lit_correlation` are merged into the result dict.
    `start_vars` and `method` are not used here.
    """
    (this_dataset, this_trait, target_dataset, _sample_data) = (
        target_trait_info)
    if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type):
        raise WrongCorrelationType(this_trait, target_dataset, corr_type)

    (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
        this_trait, target_dataset)

    with database_connection(SQL_URI) as conn:
        return reduce(
            lambda acc, lit: {**acc, **lit},
            compute_all_lit_correlation(
                conn=conn, trait_lists=list(geneid_dict.items()),
                species=species, gene_id=this_trait_geneid)[:n_top],
            {})


def compute_correlation_rust(
        start_vars: dict, corr_type: str, method: str = "pearson",
        n_top: int = 500, should_compute_all: bool = False):
    """Compute the requested correlation and, optionally, the other kinds.

    `corr_type` must be one of "sample", "tissue" or "lit" (any other value
    raises `KeyError`).  When `should_compute_all` is true, the remaining
    correlation kinds are computed for the traits in the primary results
    and merged in.  Raises `WrongCorrelationType` when the datasets are not
    compatible with `corr_type`.

    Returns a dict with keys "correlation_results", "this_trait",
    "target_dataset", "traits_metadata" and "return_results".
    """
    target_trait_info = create_target_this_trait(start_vars)
    (this_dataset, this_trait, target_dataset, _sample_data) = (
        target_trait_info)
    if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type):
        raise WrongCorrelationType(this_trait, target_dataset, corr_type)

    # Replace this with `match ...` once we hit Python 3.10
    corr_type_fns = {
        "sample": __compute_sample_corr__,
        "tissue": __compute_tissue_corr__,
        "lit": __compute_lit_corr__
    }

    results = corr_type_fns[corr_type](
        start_vars, corr_type, method, n_top, target_trait_info)

    # END: Replace this with `match ...` once we hit Python 3.10

    top_a = {}
    top_b = {}

    if should_compute_all:

        if corr_type == "sample":
            if this_dataset.type == "ProbeSet":
                top_a = compute_top_n_tissue(
                    target_dataset, this_trait, results, method)

                top_b = compute_top_n_lit(results, target_dataset, this_trait)

        elif corr_type == "lit":

            # currently fails for lit

            top_a = compute_top_n_sample(
                start_vars, target_dataset, list(results.keys()))
            top_b = compute_top_n_tissue(
                target_dataset, this_trait, results, method)

        else:

            top_a = compute_top_n_sample(
                start_vars, target_dataset, list(results.keys()))

    return {
        "correlation_results": merge_results(
            results, top_a, top_b),
        "this_trait": this_trait.name,
        "target_dataset": start_vars['corr_dataset'],
        "traits_metadata": get_metadata(target_dataset, list(results.keys())),
        "return_results": n_top
    }