"""module contains integration code for rust-gn3"""
import json
from functools import reduce
from gn2.utility.tools import SQL_URI
from gn2.utility.db_tools import mescape
from gn2.utility.db_tools import create_in_clause
from gn2.wqflask.correlation.correlation_functions import (
    get_trait_symbol_and_tissue_values)
from gn2.wqflask.correlation.correlation_gn3_api import create_target_this_trait
from gn2.wqflask.correlation.correlation_gn3_api import lit_for_trait_list
from gn2.wqflask.correlation.correlation_gn3_api import do_lit_correlation
from gn2.wqflask.correlation.pre_computes import fetch_text_file
from gn2.wqflask.correlation.pre_computes import read_text_file
from gn2.wqflask.correlation.pre_computes import write_db_to_textfile
from gn2.wqflask.correlation.pre_computes import read_trait_metadata
from gn2.wqflask.correlation.pre_computes import cache_trait_metadata
from gn3.computations.correlations import compute_all_lit_correlation
from gn3.computations.rust_correlation import run_correlation
from gn3.computations.rust_correlation import get_sample_corr_data
from gn3.computations.rust_correlation import parse_tissue_corr_data
from gn3.db_utils import database_connection
from gn2.wqflask.correlation.exceptions import WrongCorrelationType


def query_probes_metadata(dataset, trait_list):
    """Query metadata in bulk for the given traits; only ProbeSet
    datasets are supported."""
    if not trait_list or dataset.type != "ProbeSet":
        return []
with database_connection(SQL_URI) as conn:
with conn.cursor() as cursor:
query = """
SELECT ProbeSet.Name,ProbeSet.Chr,ProbeSet.Mb,
ProbeSet.Symbol,ProbeSetXRef.mean,
CONCAT_WS('; ', ProbeSet.description, ProbeSet.Probe_Target_Description) AS description,
ProbeSetXRef.additive,ProbeSetXRef.LRS,Geno.Chr, Geno.Mb
FROM ProbeSet INNER JOIN ProbeSetXRef
ON ProbeSet.Id=ProbeSetXRef.ProbeSetId
INNER JOIN Geno
ON ProbeSetXRef.Locus = Geno.Name
INNER JOIN Species
ON Geno.SpeciesId = Species.Id
WHERE ProbeSet.Name in ({}) AND
Species.Name = %s AND
ProbeSetXRef.ProbeSetFreezeId IN (
SELECT ProbeSetFreeze.Id
FROM ProbeSetFreeze WHERE ProbeSetFreeze.Name = %s)
""".format(", ".join(["%s"] * len(trait_list)))
cursor.execute(query,
(tuple(trait_list) +
(dataset.group.species,) + (dataset.name,))
)
            return cursor.fetchall()


def get_metadata(dataset, traits):
    """Retrieve metadata for the given traits, preferring cached
    metadata where it is available."""
def __location__(probe_chr, probe_mb):
if probe_mb:
return f"Chr{probe_chr}: {probe_mb:.6f}"
return f"Chr{probe_chr}: ???"
cached_metadata = read_trait_metadata(dataset.name)
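    # only query the database for traits missing from the cached metadata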
to_fetch_metadata = list(
set(traits).difference(list(cached_metadata.keys())))
if to_fetch_metadata:
results = {**({trait_name: {
"name": trait_name,
"view": True,
"symbol": symbol,
"dataset": dataset.name,
"dataset_name": dataset.shortname,
"mean": mean,
"description": description,
"additive": additive,
"lrs_score": f"{lrs:3.1f}" if lrs else "",
"location": __location__(probe_chr, probe_mb),
"chr": probe_chr,
"mb": probe_mb,
"lrs_location": f'Chr{chr_score}: {mb:{".6f" if mb else ""}}',
"lrs_chr": chr_score,
"lrs_mb": mb
} for trait_name, probe_chr, probe_mb, symbol, mean, description,
additive, lrs, chr_score, mb
in query_probes_metadata(dataset, to_fetch_metadata)}), **cached_metadata}
cache_trait_metadata(dataset.name, results)
return results
    return cached_metadata


def chunk_dataset(dataset, steps, name):
    """Group the fetched ProbeSetData rows into one row of sample values
    per trait."""
    results = []
    query = """
        SELECT ProbeSetXRef.DataId, ProbeSet.Name
        FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze
        WHERE ProbeSetFreeze.Name = %s AND
              ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND
              ProbeSetXRef.ProbeSetId = ProbeSet.Id
        """
with database_connection(SQL_URI) as conn:
with conn.cursor() as curr:
            curr.execute(query, (name,))
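            # map each trait's DataId to its trait name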
traits_name_dict = dict(curr.fetchall())
for i in range(0, len(dataset), steps):
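        # each chunk of `steps` consecutive rows holds the values for one trait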
matrix = list(dataset[i:i + steps])
        results.append(
            [traits_name_dict[matrix[0][0]]]
            + [str(value) for (_data_id, _strain_id, value) in matrix])
    return results


def compute_top_n_sample(start_vars, dataset, trait_list):
    """Compute sample correlations for the given traits; only ProbeSet
    datasets are supported."""
    if dataset.type.lower() != "probeset":
        return {}
def __fetch_sample_ids__(samples_vals, samples_group):
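        # fetch the strain ids for the samples shared between the
        # submitted sample data and the dataset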
sample_data = get_sample_corr_data(
sample_type=samples_group,
sample_data=json.loads(samples_vals),
dataset_samples=dataset.group.all_samples_ordered())
with database_connection(SQL_URI) as conn:
with conn.cursor() as curr:
curr.execute(
"""
SELECT Strain.Name, Strain.Id FROM Strain, Species
WHERE Strain.Name IN {}
and Strain.SpeciesId=Species.Id
and Species.name = '{}'
""".format(create_in_clause(list(sample_data.keys())),
*mescape(dataset.group.species)))
return (sample_data, dict(curr.fetchall()))
(sample_data, sample_ids) = __fetch_sample_ids__(
start_vars["sample_vals"], start_vars["corr_samples_group"])
    if not trait_list:
        return {}
with database_connection(SQL_URI) as conn:
with conn.cursor() as curr:
# fetching strain data in bulk
query = (
"SELECT * from ProbeSetData "
f"WHERE StrainID IN ({', '.join(['%s'] * len(sample_ids))}) "
"AND Id IN ("
" SELECT ProbeSetXRef.DataId "
" FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) "
" WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id "
" AND ProbeSetFreeze.Name = %s "
" AND ProbeSet.Name "
f" IN ({', '.join(['%s'] * len(trait_list))}) "
" AND ProbeSet.Id = ProbeSetXRef.ProbeSetId"
")")
curr.execute(
query,
tuple(sample_ids.values()) + (dataset.name,) + tuple(trait_list))
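            # regroup the flat rows into one row of sample values per trait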
corr_data = chunk_dataset(
list(curr.fetchall()), len(sample_ids.values()), dataset.name)
return run_correlation(
corr_data, list(sample_data.values()), "pearson", ",")
def compute_top_n_lit(corr_results, target_dataset, this_trait) -> dict:
if not __datasets_compatible_p__(this_trait.dataset, target_dataset, "lit"):
return {}
(this_trait_geneid, geneid_dict, species) = do_lit_correlation(
this_trait, target_dataset)
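    # keep only the gene ids of traits present in the sample correlation results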
geneid_dict = {trait_name: geneid for (trait_name, geneid)
in geneid_dict.items() if
corr_results.get(trait_name)}
with database_connection(SQL_URI) as conn:
return reduce(
lambda acc, corr: {**acc, **corr},
compute_all_lit_correlation(
conn=conn, trait_lists=list(geneid_dict.items()),
species=species, gene_id=this_trait_geneid),
            {})


def compute_top_n_tissue(target_dataset, this_trait, traits, method):
    """Compute tissue correlations for the traits present in the sample
    correlation results."""
    # TODO: refactor; this repeats much of __compute_tissue_corr__
    if not __datasets_compatible_p__(this_trait.dataset, target_dataset, "tissue"):
        return {}
    trait_symbol_dict = {
        trait_name: symbol
        for (trait_name, symbol)
        in target_dataset.retrieve_genes("Symbol").items()
        if traits.get(trait_name)}
corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
symbol_list=list(trait_symbol_dict.values()))
data = parse_tissue_corr_data(symbol_name=this_trait.symbol,
symbol_dict=get_trait_symbol_and_tissue_values(
symbol_list=[this_trait.symbol]),
dataset_symbols=trait_symbol_dict,
dataset_vals=corr_result_tissue_vals_dict)
if data and data[0]:
return run_correlation(
data[1], data[0], method, ",", "tissue")
    return {}


def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
    """Merge the sample, tissue and literature correlation results into a
    single dict per trait, keyed by trait name."""
def __merge__(trait_name, trait_corrs):
return {
trait_name: {
**trait_corrs,
**dict_b.get(trait_name, {}),
**dict_c.get(trait_name, {})
}
}
    return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]


def __compute_sample_corr__(
        start_vars: dict, corr_type: str, method: str, n_top: int,
        target_trait_info: tuple):
    """Compute the sample correlations"""
(this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
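    # include F1 and parent samples in the sample list so their values
    # are picked up below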
    if this_dataset.group.f1list is not None:
        this_dataset.group.samplelist += this_dataset.group.f1list
    if this_dataset.group.parlist is not None:
        this_dataset.group.samplelist += this_dataset.group.parlist
sample_data = get_sample_corr_data(
sample_type=start_vars["corr_samples_group"],
sample_data=json.loads(start_vars["sample_vals"]),
dataset_samples=this_dataset.group.all_samples_ordered())
    if not sample_data:
return {}
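    # for ProbeSet targets, prefer the cached text file of the dataset's
    # trait data; on a cache miss, generate the file and retry once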
if target_dataset.type == "ProbeSet" and start_vars.get("use_cache") == "true":
with database_connection(SQL_URI) as conn:
file_path = fetch_text_file(target_dataset.name, conn)
if file_path:
(sample_vals, target_data) = read_text_file(
sample_data, file_path)
return run_correlation(target_data, sample_vals,
method, ",", corr_type, n_top)
write_db_to_textfile(target_dataset.name, conn)
file_path = fetch_text_file(target_dataset.name, conn)
if file_path:
(sample_vals, target_data) = read_text_file(
sample_data, file_path)
return run_correlation(target_data, sample_vals,
method, ",", corr_type, n_top)
target_dataset.get_trait_data(list(sample_data.keys()))
    def __merge_key_and_values__(rows, current):
        # keep only traits that have at least one sample value
        values = list(current[1])
        if values:
            return rows + [[current[0]] + values]
        return rows
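    # build [trait_name, value, ...] rows as input for the rust correlation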
target_data = reduce(
__merge_key_and_values__, target_dataset.trait_data.items(), [])
    if not target_data:
return {}
return run_correlation(
target_data, list(sample_data.values()), method, ",", corr_type,
        n_top)


def __datasets_compatible_p__(trait_dataset, target_dataset, corr_method):
    """Check whether the datasets are compatible with the correlation
    method: tissue and literature correlations are not supported for
    Publish or Geno target datasets."""
    return not (
        corr_method in ("tissue", "Tissue r", "Literature r", "lit")
        and (trait_dataset.type == "ProbeSet" and
             target_dataset.type in ("Publish", "Geno")))


def __compute_tissue_corr__(
start_vars: dict, corr_type: str, method: str, n_top: int,
target_trait_info: tuple):
"""Compute the tissue correlations"""
(this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type):
raise WrongCorrelationType(this_trait, target_dataset, corr_type)
trait_symbol_dict = target_dataset.retrieve_genes("Symbol")
corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
symbol_list=list(trait_symbol_dict.values()))
data = parse_tissue_corr_data(
symbol_name=this_trait.symbol,
symbol_dict=get_trait_symbol_and_tissue_values(
symbol_list=[this_trait.symbol]),
dataset_symbols=trait_symbol_dict,
dataset_vals=corr_result_tissue_vals_dict)
    if data and data[0]:
return run_correlation(data[1], data[0], method, ",", "tissue")
    return {}


def __compute_lit_corr__(
start_vars: dict, corr_type: str, method: str, n_top: int,
target_trait_info: tuple):
"""Compute the literature correlations"""
(this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type):
raise WrongCorrelationType(this_trait, target_dataset, corr_type)
(this_trait_geneid, geneid_dict, species) = do_lit_correlation(
this_trait, target_dataset)
with database_connection(SQL_URI) as conn:
return reduce(
lambda acc, lit: {**acc, **lit},
compute_all_lit_correlation(
conn=conn, trait_lists=list(geneid_dict.items()),
species=species, gene_id=this_trait_geneid)[:n_top],
            {})


def compute_correlation_rust(
        start_vars: dict, corr_type: str, method: str = "pearson",
        n_top: int = 500, should_compute_all: bool = False):
    """Compute correlations of the given type, optionally also computing
    the other correlation types for the top results."""
target_trait_info = create_target_this_trait(start_vars)
    (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
if not __datasets_compatible_p__(this_dataset, target_dataset, corr_type):
raise WrongCorrelationType(this_trait, target_dataset, corr_type)
# Replace this with `match ...` once we hit Python 3.10
corr_type_fns = {
"sample": __compute_sample_corr__,
"tissue": __compute_tissue_corr__,
"lit": __compute_lit_corr__
}
results = corr_type_fns[corr_type](
start_vars, corr_type, method, n_top, target_trait_info)
# END: Replace this with `match ...` once we hit Python 3.10
top_a = top_b = {}
if should_compute_all:
if corr_type == "sample":
if this_dataset.type == "ProbeSet":
top_a = compute_top_n_tissue(
target_dataset, this_trait, results, method)
top_b = compute_top_n_lit(results, target_dataset, this_trait)
elif corr_type == "lit":
# currently fails for lit
top_a = compute_top_n_sample(
start_vars, target_dataset, list(results.keys()))
top_b = compute_top_n_tissue(
target_dataset, this_trait, results, method)
else:
top_a = compute_top_n_sample(
start_vars, target_dataset, list(results.keys()))
return {
"correlation_results": merge_results(
results, top_a, top_b),
"this_trait": this_trait.name,
"target_dataset": start_vars['corr_dataset'],
"traits_metadata": get_metadata(target_dataset, list(results.keys())),
"return_results": n_top
}