"""module contains code for correlations""" import math import multiprocessing from contextlib import closing from multiprocessing import Pool, cpu_count from typing import List from typing import Tuple from typing import Optional from typing import Callable from typing import Generator import scipy.stats import pingouin as pg def map_shared_keys_to_values(target_sample_keys: List, target_sample_vals: dict) -> List: """Function to construct target dataset data items given common shared keys and trait sample-list values for example given keys >>>>>>>>>> ["BXD1", "BXD2", "BXD5", "BXD6", "BXD8", "BXD9"] and value object as "HCMA:_AT": [4.1, 5.6, 3.2, 1.1, 4.4, 2.2],TXD_AT": [6.2, 5.7, 3.6, 1.5, 4.2, 2.3]} return results should be a list of dicts mapping the shared keys to the trait values """ target_dataset_data = [] for trait_id, sample_values in target_sample_vals.items(): target_trait_dict = dict(zip(target_sample_keys, sample_values)) target_trait = { "trait_id": trait_id, "trait_sample_data": target_trait_dict } target_dataset_data.append(target_trait) return target_dataset_data def normalize_values(a_values: List, b_values: List) -> Generator: """ :param a_values: list of primary strain values :param b_values: a list of target strain values :return: yield 2 values if none of them is none """ for a_val, b_val in zip(a_values, b_values): if (a_val is not None) and (b_val is not None): yield a_val, b_val def compute_corr_coeff_p_value(primary_values: List, target_values: List, corr_method: str) -> Tuple[float, float]: """Given array like inputs calculate the primary and target_value methods -> pearson,spearman and biweight mid correlation return value is rho and p_value """ corr_mapping = { "bicor": do_bicor, "pearson": scipy.stats.pearsonr, "spearman": scipy.stats.spearmanr } use_corr_method = corr_mapping.get(corr_method, "spearman") corr_coefficient, p_val = use_corr_method(primary_values, target_values) return (corr_coefficient, p_val) def compute_sample_r_correlation(trait_name, corr_method, trait_vals, target_samples_vals) -> Optional[ Tuple[str, float, float, int]]: """Given a primary trait values and target trait values calculate the correlation coeff and p value """ try: normalized_traits_vals, normalized_target_vals = list( zip(*list(normalize_values(trait_vals, target_samples_vals)))) num_overlap = len(normalized_traits_vals) except ValueError: return None if num_overlap > 5: (corr_coefficient, p_value) =\ compute_corr_coeff_p_value(primary_values=normalized_traits_vals, target_values=normalized_target_vals, corr_method=corr_method) if corr_coefficient is not None and not math.isnan(corr_coefficient): return (trait_name, corr_coefficient, p_value, num_overlap) return None def do_bicor(x_val, y_val) -> Tuple[float, float]: """Not implemented method for doing biweight mid correlation use astropy stats package :not packaged in guix """ results = pg.corr(x_val, y_val, method="bicor") corr_coeff = results["r"].values[0] p_val = results["p-val"].values[0] return (corr_coeff, p_val) def filter_shared_sample_keys(this_samplelist, target_samplelist) -> Generator: """Given primary and target sample-list for two base and target trait select filter the values using the shared keys """ for key, value in target_samplelist.items(): if key in this_samplelist: yield this_samplelist[key], value def fast_compute_all_sample_correlation(this_trait, target_dataset, corr_method="pearson") -> List: """Given a trait data sample-list and target__datasets compute all sample correlation this functions uses 


def fast_compute_all_sample_correlation(this_trait, target_dataset,
                                        corr_method="pearson") -> List:
    """Given a trait's sample data and the target dataset, compute all sample
    correlations.  This function uses multiprocessing; see
    compute_all_sample_correlation for the alternative entry point."""
    # xtodo fix trait_name currently returning single one
    # pylint: disable-msg=too-many-locals
    this_trait_samples = this_trait["trait_sample_data"]
    corr_results = []
    processed_values = []
    for target_trait in target_dataset:
        trait_name = target_trait.get("trait_id")
        target_trait_data = target_trait["trait_sample_data"]

        try:
            this_vals, target_vals = list(zip(*list(filter_shared_sample_keys(
                this_trait_samples, target_trait_data))))

            processed_values.append(
                (trait_name, corr_method, this_vals, target_vals))
        except ValueError:
            continue

    with closing(multiprocessing.Pool()) as pool:
        results = pool.starmap(compute_sample_r_correlation, processed_values)

        for sample_correlation in results:
            if sample_correlation is not None:
                (trait_name, corr_coefficient, p_value,
                 num_overlap) = sample_correlation
                corr_result = {
                    "corr_coefficient": corr_coefficient,
                    "p_value": p_value,
                    "num_overlap": num_overlap
                }

                corr_results.append({trait_name: corr_result})

    return sorted(
        corr_results,
        key=lambda trait_name: -abs(
            list(trait_name.values())[0]["corr_coefficient"]))


def compute_one_sample_correlation(trait_samples, target_trait, corr_method):
    """Compute the sample correlation against a single target trait."""
    trait_name = target_trait.get("trait_id")
    target_trait_data = target_trait["trait_sample_data"]

    try:
        this_vals, target_vals = list(zip(*list(filter_shared_sample_keys(
            trait_samples, target_trait_data))))

        sample_correlation = compute_sample_r_correlation(
            trait_name=trait_name,
            corr_method=corr_method,
            trait_vals=this_vals,
            target_samples_vals=target_vals)

        if sample_correlation is not None:
            (trait_name, corr_coefficient, p_value,
             num_overlap) = sample_correlation
            return {trait_name: {
                "corr_coefficient": corr_coefficient,
                "p_value": p_value,
                "num_overlap": num_overlap
            }}
    except ValueError:
        # case where there are no matching strain names
        return None
    return None


def compute_all_sample_correlation(this_trait, target_dataset,
                                   corr_method="pearson") -> List:
    """Alternative to fast_compute_all_sample_correlation, kept for
    benchmarking; the per-trait work is distributed over a multiprocessing
    Pool."""
    this_trait_samples = this_trait["trait_sample_data"]
    with Pool(processes=(cpu_count() - 1)) as pool:
        return sorted(
            (
                corr for corr in pool.starmap(
                    compute_one_sample_correlation,
                    ((this_trait_samples, trait, corr_method)
                     for trait in target_dataset))
                if corr is not None),
            key=lambda trait_name: -abs(
                list(trait_name.values())[0]["corr_coefficient"]))


def tissue_correlation_for_trait(
        primary_tissue_vals: List,
        target_tissues_values: List,
        corr_method: str,
        trait_id: str,
        compute_corr_p_value: Callable = compute_corr_coeff_p_value) -> dict:
    """Given the primary tissue values for a trait and the target tissue
    values, compute the correlation coefficient and p-value.  Both inputs are
    arrays with len(primary) == len(target); the output is a dict keyed by
    the trait id containing the correlation coefficient, the p-value and the
    number of tissues."""
    # ax :todo assert that len(target_tissues_values) == len(primary_tissue_vals)
    (tissue_corr_coefficient,
     p_value) = compute_corr_p_value(primary_values=primary_tissue_vals,
                                     target_values=target_tissues_values,
                                     corr_method=corr_method)

    tiss_corr_result = {trait_id: {
        "tissue_corr": tissue_corr_coefficient,
        "tissue_number": len(primary_tissue_vals),
        "tissue_p_val": p_value}}

    return tiss_corr_result
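

# Illustrative sketch (not part of the original module): the result shape of
# tissue_correlation_for_trait(), keyed by the trait id.  The trait id and
# tissue values used here are hypothetical.
def _example_tissue_correlation_sketch() -> dict:
    """Illustrative only: result shape of tissue_correlation_for_trait."""
    # Returns something like:
    # {"HYPOTHETICAL_1450_at": {"tissue_corr": ..., "tissue_number": 4,
    #                           "tissue_p_val": ...}}
    return tissue_correlation_for_trait(
        primary_tissue_vals=[7.4, 2.1, 4.5, 8.1],
        target_tissues_values=[6.2, 5.7, 3.6, 1.5],
        corr_method="spearman",
        trait_id="HYPOTHETICAL_1450_at")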


def fetch_lit_correlation_data(
        conn,
        input_mouse_gene_id: Optional[str],
        gene_id: str,
        mouse_gene_id: Optional[str] = None) -> Tuple[str, Optional[float]]:
    """Given the input trait's mouse gene id and the target mouse gene id,
    fetch the literature correlation data."""
    if mouse_gene_id is not None and ";" not in mouse_gene_id:
        query = """
        SELECT VALUE
        FROM LCorrRamin3
        WHERE GeneId1='%s' and
        GeneId2='%s'
        """

        query_values = (str(mouse_gene_id), str(input_mouse_gene_id))
        cursor = conn.cursor()
        cursor.execute(query_formatter(query, *query_values))
        results = cursor.fetchone()
        lit_corr_results = None
        if results is not None:
            lit_corr_results = results
        else:
            cursor = conn.cursor()
            cursor.execute(query_formatter(query,
                                           *tuple(reversed(query_values))))
            lit_corr_results = cursor.fetchone()
        lit_results = (gene_id, lit_corr_results[0])\
            if lit_corr_results else (gene_id, None)
        return lit_results
    return (gene_id, None)


def lit_correlation_for_trait(
        conn,
        target_trait_lists: List,
        species: Optional[str] = None,
        trait_gene_id: Optional[str] = None) -> List:
    """Given the species and the base trait's gene id, fetch the literature
    correlation results from the database; the output is a list of dicts, one
    per target trait, each containing the literature correlation value as a
    float."""
    fetched_lit_corr_results = []

    this_trait_mouse_gene_id = map_to_mouse_gene_id(conn=conn,
                                                    species=species,
                                                    gene_id=trait_gene_id)

    for (trait_name, target_trait_gene_id) in target_trait_lists:
        corr_results = {}
        if target_trait_gene_id:
            target_mouse_gene_id = map_to_mouse_gene_id(
                conn=conn,
                species=species,
                gene_id=target_trait_gene_id)

            fetched_corr_data = fetch_lit_correlation_data(
                conn=conn,
                input_mouse_gene_id=this_trait_mouse_gene_id,
                gene_id=target_trait_gene_id,
                mouse_gene_id=target_mouse_gene_id)

            dict_results = dict(zip(("gene_id", "lit_corr"),
                                    fetched_corr_data))
            corr_results[trait_name] = dict_results
            fetched_lit_corr_results.append(corr_results)

    return fetched_lit_corr_results


def query_formatter(query_string: str, *query_values):
    """Format a query string given the unformatted query string and its
    respective values; assumes the number of placeholders equals the number
    of query values."""
    # xtodo escape sql queries
    return query_string % (query_values)


def map_to_mouse_gene_id(conn, species: Optional[str],
                         gene_id: Optional[str]) -> Optional[str]:
    """Given a species other than mouse, map the gene_id to the corresponding
    mouse gene id."""
    if None in (species, gene_id):
        return None
    if species == "mouse":
        return gene_id
    cursor = conn.cursor()
    query = """SELECT mouse FROM GeneIDXRef WHERE '%s' = '%s'"""
    query_values = (species, gene_id)
    cursor.execute(query_formatter(query, *query_values))
    results = cursor.fetchone()
    mouse_gene_id = results.mouse if results is not None else None

    return mouse_gene_id


def compute_all_lit_correlation(conn, trait_lists: List,
                                species: str, gene_id):
    """Function that acts as an abstraction for lit_correlation_for_trait."""

    def __sorter__(trait_name):
        val = list(trait_name.values())[0]["lit_corr"]
        try:
            return (0, -abs(val))
        except TypeError:
            return (1, val)

    lit_results = lit_correlation_for_trait(
        conn=conn, target_trait_lists=trait_lists,
        species=species, trait_gene_id=gene_id)
    sorted_lit_results = sorted(lit_results, key=__sorter__)

    return sorted_lit_results
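

# Illustrative sketch (not part of the original module): the call pattern for
# the literature-correlation entry point.  The trait names and gene ids below
# are hypothetical; `conn` is any DB-API connection to a GeneNetwork database.
def _example_lit_correlation_sketch(conn) -> List:
    """Illustrative only: call pattern for compute_all_lit_correlation."""
    # Each target trait is paired with its gene id, as produced upstream.
    return compute_all_lit_correlation(
        conn=conn,
        trait_lists=[("HYPOTHETICAL_1426678_at", "233328"),
                     ("HYPOTHETICAL_1425404_at", "15368")],
        species="mouse",
        gene_id="17329")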


def compute_tissue_correlation(primary_tissue_dict: dict,
                               target_tissues_data: dict,
                               corr_method: str):
    """Function acts as an abstraction for tissue_correlation_for_trait; the
    required inputs are the primary tissue trait and the target tissues data,
    which contains the trait_symbol_dict and the symbol_tissue_vals_dict."""
    tissues_results = []

    primary_tissue_vals = primary_tissue_dict["tissue_values"]
    traits_symbol_dict = target_tissues_data["trait_symbol_dict"]
    symbol_tissue_vals_dict = target_tissues_data["symbol_tissue_vals_dict"]

    target_tissues_list = process_trait_symbol_dict(
        traits_symbol_dict, symbol_tissue_vals_dict)

    for target_tissue_obj in target_tissues_list:
        trait_id = target_tissue_obj.get("trait_id")

        target_tissue_vals = target_tissue_obj.get("tissue_values")

        tissue_result = tissue_correlation_for_trait(
            primary_tissue_vals=primary_tissue_vals,
            target_tissues_values=target_tissue_vals,
            trait_id=trait_id,
            corr_method=corr_method)

        tissues_results.append(tissue_result)

    return sorted(
        tissues_results,
        key=lambda trait_name: -abs(
            list(trait_name.values())[0]["tissue_corr"]))


def process_trait_symbol_dict(trait_symbol_dict,
                              symbol_tissue_vals_dict) -> List:
    """Process the trait symbol dict against the symbol tissue values."""
    traits_tissue_vals = []

    for (trait, symbol) in trait_symbol_dict.items():
        if symbol is not None:
            target_symbol = symbol.lower()
            if target_symbol in symbol_tissue_vals_dict:
                trait_tissue_val = symbol_tissue_vals_dict[target_symbol]
                target_tissue_dict = {"trait_id": trait,
                                      "symbol": target_symbol,
                                      "tissue_values": trait_tissue_val}

                traits_tissue_vals.append(target_tissue_dict)

    return traits_tissue_vals


def fast_compute_tissue_correlation(primary_tissue_dict: dict,
                                    target_tissues_data: dict,
                                    corr_method: str):
    """Experimental function that uses multiprocessing for computing the
    tissue correlations."""
    tissues_results = []

    primary_tissue_vals = primary_tissue_dict["tissue_values"]
    traits_symbol_dict = target_tissues_data["trait_symbol_dict"]
    symbol_tissue_vals_dict = target_tissues_data["symbol_tissue_vals_dict"]

    target_tissues_list = process_trait_symbol_dict(
        traits_symbol_dict, symbol_tissue_vals_dict)
    processed_values = []

    for target_tissue_obj in target_tissues_list:
        trait_id = target_tissue_obj.get("trait_id")

        target_tissue_vals = target_tissue_obj.get("tissue_values")

        processed_values.append(
            (primary_tissue_vals, target_tissue_vals, corr_method, trait_id))

    with multiprocessing.Pool(4) as pool:
        results = pool.starmap(
            tissue_correlation_for_trait, processed_values)

        for result in results:
            tissues_results.append(result)

    return sorted(
        tissues_results,
        key=lambda trait_name: -abs(
            list(trait_name.values())[0]["tissue_corr"]))
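

# Illustrative sketch (not part of the original module):
# compute_tissue_correlation() and fast_compute_tissue_correlation() accept
# the same inputs, so either can be exercised with a small in-memory payload.
# The probe ids, gene symbol and tissue values below are hypothetical.
def _example_compute_tissue_correlation_sketch() -> List:
    """Illustrative only: input shapes for the tissue-correlation wrappers."""
    primary_tissue_dict = {"tissue_values": [7.4, 2.1, 4.5, 8.1, 6.0]}
    target_tissues_data = {
        # trait_symbol_dict maps each trait id to its gene symbol (or None).
        "trait_symbol_dict": {"1450_at": "Shh", "1451_at": None},
        # symbol_tissue_vals_dict is keyed by the lower-cased gene symbol.
        "symbol_tissue_vals_dict": {"shh": [6.2, 5.7, 3.6, 1.5, 4.2]}}
    return compute_tissue_correlation(
        primary_tissue_dict=primary_tissue_dict,
        target_tissues_data=target_tissues_data,
        corr_method="pearson")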