aboutsummaryrefslogtreecommitdiff
"""module contains code for correlations"""
import math
import multiprocessing
from contextlib import closing
from multiprocessing import Pool, cpu_count

from typing import List
from typing import Tuple
from typing import Optional
from typing import Callable
from typing import Generator

import scipy.stats
import pingouin as pg


def map_shared_keys_to_values(target_sample_keys: List,
                              target_sample_vals: dict) -> List:
    """Pair each trait's sample values with the shared sample keys.

    For every ``trait_id -> values`` entry in ``target_sample_vals`` the
    values are zipped positionally against ``target_sample_keys``; if the
    two sequences differ in length, the surplus items are dropped (``zip``
    semantics).

    Example::

        keys = ["BXD1", "BXD2"]
        vals = {"HCMA:_AT": [4.1, 5.6], "TXD_AT": [6.2, 5.7]}
        map_shared_keys_to_values(keys, vals)
        # -> [{"trait_id": "HCMA:_AT",
        #      "trait_sample_data": {"BXD1": 4.1, "BXD2": 5.6}},
        #     {"trait_id": "TXD_AT",
        #      "trait_sample_data": {"BXD1": 6.2, "BXD2": 5.7}}]

    :param target_sample_keys: sample names shared by the target traits
    :param target_sample_vals: mapping of trait id to its list of values,
        ordered like ``target_sample_keys``
    :return: one ``{"trait_id", "trait_sample_data"}`` dict per trait, in
        the iteration order of ``target_sample_vals``
    """
    return [
        {
            "trait_id": trait_id,
            "trait_sample_data": dict(zip(target_sample_keys, values)),
        }
        for trait_id, values in target_sample_vals.items()
    ]


def normalize_values(a_values: List, b_values: List) -> Generator:
    """Yield aligned ``(a, b)`` pairs in which neither value is ``None``.

    Both sequences are walked in lockstep (stopping at the shorter one) and
    any position where either side is missing is skipped.

    :param a_values: list of primary strain values
    :param b_values: list of target strain values
    :return: generator of ``(a_val, b_val)`` tuples
    """
    for first, second in zip(a_values, b_values):
        if first is None or second is None:
            continue
        yield first, second


def compute_corr_coeff_p_value(primary_values: List, target_values: List,
                               corr_method: str) -> Tuple[float, float]:
    """Correlate two aligned value sequences.

    Supported ``corr_method`` values are ``"pearson"``, ``"spearman"`` and
    ``"bicor"`` (biweight midcorrelation). Any other value falls back to
    Spearman.

    :param primary_values: values of the primary trait
    :param target_values: values of the target trait, aligned with
        ``primary_values``
    :param corr_method: name of the correlation method
    :return: ``(correlation_coefficient, p_value)``
    """
    if corr_method == "bicor":
        # Looked up only when requested, so the pingouin-backed helper is
        # not needed for the scipy methods.
        corr_function = do_bicor
    elif corr_method == "pearson":
        corr_function = scipy.stats.pearsonr
    else:
        # Fall back to the Spearman *function*. Falling back to the string
        # "spearman" would fail with TypeError when called.
        corr_function = scipy.stats.spearmanr
    corr_coefficient, p_val = corr_function(primary_values, target_values)
    return (corr_coefficient, p_val)


def compute_sample_r_correlation(trait_name, corr_method, trait_vals,
                                 target_samples_vals) -> Optional[
                                     Tuple[str, float, float, int]]:
    """Correlate one target trait's samples with the primary trait's samples.

    Pairs in which either value is ``None`` are discarded first (via
    ``normalize_values``). A result is returned only when more than five
    samples remain and the coefficient is neither ``None`` nor NaN.

    :param trait_name: identifier echoed back in the result
    :param corr_method: name of the correlation method
    :param trait_vals: primary-trait values
    :param target_samples_vals: target-trait values aligned with
        ``trait_vals``
    :return: ``(trait_name, corr_coefficient, p_value, num_overlap)`` or
        ``None``
    """
    try:
        primary_column, target_column = list(
            zip(*normalize_values(trait_vals, target_samples_vals)))
    except ValueError:
        # No usable pairs: zip(*[]) produces nothing to unpack.
        return None
    overlap_count = len(primary_column)

    if overlap_count <= 5:
        return None

    coefficient, p_value = compute_corr_coeff_p_value(
        primary_values=primary_column,
        target_values=target_column,
        corr_method=corr_method)

    if coefficient is None or math.isnan(coefficient):
        return None
    return (trait_name, coefficient, p_value, overlap_count)


def do_bicor(x_val, y_val) -> Tuple[float, float]:
    """Compute the biweight midcorrelation of two samples with pingouin.

    Delegates to ``pingouin.corr(..., method="bicor")`` and reads the
    coefficient (``"r"``) and p-value (``"p-val"``) from the first row of
    the returned frame.

    :param x_val: first sample
    :param y_val: second sample, aligned with ``x_val``
    :return: ``(corr_coeff, p_val)``
    """
    bicor_frame = pg.corr(x_val, y_val, method="bicor")
    coefficient, probability = (
        bicor_frame[column].values[0] for column in ("r", "p-val"))
    return (coefficient, probability)


def filter_shared_sample_keys(this_samplelist,
                              target_samplelist) -> Generator:
    """Yield ``(primary_value, target_value)`` for every shared sample.

    Iteration follows the order of ``target_samplelist``. Samples missing
    from ``this_samplelist`` are skipped.

    :param this_samplelist: mapping of sample name to primary-trait value
    :param target_samplelist: mapping of sample name to target-trait value
    """
    yield from (
        (this_samplelist[sample], target_value)
        for sample, target_value in target_samplelist.items()
        if sample in this_samplelist
    )


def fast_compute_all_sample_correlation(this_trait,
                                        target_dataset,
                                        corr_method="pearson") -> List:
    """Correlate ``this_trait`` against every trait in ``target_dataset``
    using a process pool.

    Each target trait's samples are first intersected with the primary
    trait's samples; traits that share no samples are skipped. The remaining
    pairs are correlated in parallel with ``compute_sample_r_correlation``.

    :param this_trait: dict with a ``"trait_sample_data"`` mapping
    :param target_dataset: iterable of dicts with ``"trait_id"`` and
        ``"trait_sample_data"`` keys
    :param corr_method: name of the correlation method
    :return: list of ``{trait_name: {"corr_coefficient", "p_value",
        "num_overlap"}}`` dicts sorted by descending absolute coefficient
    """
    # xtodo fix trait_name currently returning single one
    # pylint: disable=too-many-locals
    this_trait_samples = this_trait["trait_sample_data"]
    processed_values = []
    for target_trait in target_dataset:
        trait_name = target_trait.get("trait_id")
        shared_pairs = list(filter_shared_sample_keys(
            this_trait_samples, target_trait["trait_sample_data"]))
        if not shared_pairs:
            # No samples in common with the primary trait.
            continue
        this_vals, target_vals = zip(*shared_pairs)
        processed_values.append(
            (trait_name, corr_method, this_vals, target_vals))

    if not processed_values:
        # Nothing to correlate, so don't spawn worker processes.
        return []

    # Pool.__exit__ calls terminate(), which also joins the workers.
    # starmap has already returned every result by then, so no work is lost.
    # (closing() only called close() and never joined the workers.)
    with multiprocessing.Pool() as pool:
        results = pool.starmap(compute_sample_r_correlation, processed_values)

    corr_results = []
    for sample_correlation in results:
        if sample_correlation is None:
            continue
        trait_name, corr_coefficient, p_value, num_overlap = sample_correlation
        corr_results.append({trait_name: {
            "corr_coefficient": corr_coefficient,
            "p_value": p_value,
            "num_overlap": num_overlap
        }})
    return sorted(
        corr_results,
        key=lambda corr: -abs(next(iter(corr.values()))["corr_coefficient"]))

def compute_one_sample_correlation(trait_samples, target_trait, corr_method):
    """Correlate the primary trait's samples against a single target trait.

    :param trait_samples: mapping of sample name to primary-trait value
    :param target_trait: dict with ``"trait_id"`` and ``"trait_sample_data"``
    :param corr_method: name of the correlation method
    :return: ``{trait_name: {"corr_coefficient", "p_value", "num_overlap"}}``,
        or ``None`` when the traits share no samples or no valid
        correlation could be computed
    """
    trait_name = target_trait.get("trait_id")
    target_samples = target_trait["trait_sample_data"]
    try:
        primary_vals, secondary_vals = list(zip(*list(
            filter_shared_sample_keys(trait_samples, target_samples))))
        outcome = compute_sample_r_correlation(
            trait_name=trait_name,
            corr_method=corr_method,
            trait_vals=primary_vals,
            target_samples_vals=secondary_vals)
    except ValueError:
        # Chiefly the case where the two traits share no sample names
        # (nothing to unpack).
        return None
    if outcome is None:
        return None
    name, coefficient, p_value, overlap = outcome
    return {name: {
        "corr_coefficient": coefficient,
        "p_value": p_value,
        "num_overlap": overlap
    }}

def compute_all_sample_correlation(this_trait,
                                   target_dataset,
                                   corr_method="pearson") -> List:
    """Correlate ``this_trait`` against every trait in ``target_dataset``
    in parallel, with one pool task per target trait.

    This is an alternative to ``fast_compute_all_sample_correlation`` (kept
    for benchmarking). Here the shared-sample filtering also runs inside the
    workers, via ``compute_one_sample_correlation``.

    :param this_trait: dict with a ``"trait_sample_data"`` mapping
    :param target_dataset: iterable of target-trait dicts
    :param corr_method: name of the correlation method
    :return: ``{trait_name: {...}}`` dicts for traits with a valid
        correlation, sorted by descending absolute coefficient
    """
    this_trait_samples = this_trait["trait_sample_data"]
    tasks = [(this_trait_samples, trait, corr_method)
             for trait in target_dataset]
    if not tasks:
        # Nothing to correlate, so don't spawn worker processes.
        return []
    # Leave one core free, but always request at least one worker. On a
    # single-core host cpu_count() - 1 == 0, and Pool() rejects that with
    # ValueError.
    with Pool(processes=max(1, cpu_count() - 1)) as pool:
        correlations = pool.starmap(compute_one_sample_correlation, tasks)
    return sorted(
        (corr for corr in correlations if corr is not None),
        key=lambda corr: -abs(
            list(corr.values())[0]["corr_coefficient"]))


def tissue_correlation_for_trait(
        primary_tissue_vals: List,
        target_tissues_values: List,
        corr_method: str,
        trait_id: str,
        compute_corr_p_value: Callable = compute_corr_coeff_p_value) -> dict:
    """Correlate the primary trait's tissue values with one target trait.

    :param primary_tissue_vals: tissue values of the primary trait
    :param target_tissues_values: tissue values of the target trait; expected
        to be aligned with ``primary_tissue_vals`` (not checked here)
    :param corr_method: name of the correlation method
    :param trait_id: identifier used as the key of the returned dict
    :param compute_corr_p_value: callable accepting ``primary_values``,
        ``target_values`` and ``corr_method`` keywords and returning
        ``(coefficient, p_value)``; defaults to
        ``compute_corr_coeff_p_value``
    :return: ``{trait_id: {"tissue_corr", "tissue_number", "tissue_p_val"}}``,
        where ``tissue_number`` is ``len(primary_tissue_vals)``
    """
    # TODO: assert that the target tissue values have the same length as
    # primary_tissue_vals.
    coefficient, p_value = compute_corr_p_value(
        primary_values=primary_tissue_vals,
        target_values=target_tissues_values,
        corr_method=corr_method)
    return {
        trait_id: {
            "tissue_corr": coefficient,
            "tissue_number": len(primary_tissue_vals),
            "tissue_p_val": p_value,
        }
    }


def fetch_lit_correlation_data(
        conn,
        input_mouse_gene_id: Optional[str],
        gene_id: str,
        mouse_gene_id: Optional[str] = None) -> Tuple[str, Optional[float]]:
    """Look up the stored literature correlation between two mouse genes.

    ``LCorrRamin3`` is queried with ``(mouse_gene_id, input_mouse_gene_id)``.
    If no row is found, the pair is retried in reverse order.

    :param conn: connection object providing ``cursor()``
    :param input_mouse_gene_id: mouse gene id of the primary trait
    :param gene_id: gene id of the target trait, echoed back in the result
    :param mouse_gene_id: mouse gene id of the target trait
    :return: ``(gene_id, value)``, where ``value`` is ``None`` when either
        mouse gene id is missing, ``mouse_gene_id`` contains several ids
        (``;``), or no row exists in either order
    """
    if (mouse_gene_id is None or ";" in mouse_gene_id
            or input_mouse_gene_id is None):
        # Nothing meaningful to look up. Also avoids querying for the
        # literal string 'None'.
        return (gene_id, None)

    # xtodo escape sql queries (values are interpolated by query_formatter)
    query = """
        SELECT VALUE
        FROM  LCorrRamin3
        WHERE GeneId1='%s' and
        GeneId2='%s'
        """
    query_values = (str(mouse_gene_id), str(input_mouse_gene_id))

    # One cursor serves both lookups and is always closed.
    cursor = conn.cursor()
    try:
        cursor.execute(query_formatter(query, *query_values))
        row = cursor.fetchone()
        if row is None:
            # The pair may be stored in the opposite order.
            cursor.execute(query_formatter(query, *reversed(query_values)))
            row = cursor.fetchone()
    finally:
        cursor.close()
    return (gene_id, row[0] if row else None)


def lit_correlation_for_trait(
        conn,
        target_trait_lists: List,
        species: Optional[str] = None,
        trait_gene_id: Optional[str] = None) -> List:
    """Fetch literature correlations between a base trait and target traits.

    The base trait's gene id is mapped to a mouse gene id once. Then, for
    every ``(trait_name, gene_id)`` in ``target_trait_lists`` whose gene id
    is truthy, the target gene is mapped too and its stored correlation is
    looked up. Targets without a gene id are skipped.

    :return: list of ``{trait_name: {"gene_id": ..., "lit_corr": ...}}``
        dicts, where ``lit_corr`` is the looked-up value or ``None``
    """
    base_mouse_gene_id = map_to_mouse_gene_id(conn=conn,
                                              species=species,
                                              gene_id=trait_gene_id)
    collected = []
    for trait_name, target_gene_id in target_trait_lists:
        if not target_gene_id:
            continue
        target_mouse_gene_id = map_to_mouse_gene_id(
            conn=conn, species=species, gene_id=target_gene_id)
        fetched = fetch_lit_correlation_data(
            conn=conn,
            input_mouse_gene_id=base_mouse_gene_id,
            gene_id=target_gene_id,
            mouse_gene_id=target_mouse_gene_id)
        collected.append(
            {trait_name: dict(zip(("gene_id", "lit_corr"), fetched))})
    return collected


def query_formatter(query_string: str, *query_values):
    """Interpolate ``query_values`` into ``query_string`` with ``%``-formatting.

    The number of ``%`` placeholders must equal the number of values. Values
    are inserted verbatim, with no escaping or quoting, so the result must
    not be built from untrusted input.
    """
    # xtodo escape sql queries
    return query_string % tuple(query_values)


def map_to_mouse_gene_id(conn, species: Optional[str],
                         gene_id: Optional[str]) -> Optional[str]:
    """Map a non-mouse ``gene_id`` to the corresponding mouse gene id.

    Mouse gene ids are returned unchanged. For other species, ``GeneIDXRef``
    is queried, using ``species`` as the column to match ``gene_id`` against.

    :param conn: connection object providing ``cursor()``
    :param species: species name; ``"mouse"`` short-circuits, anything else
        is used as a ``GeneIDXRef`` column name (presumably e.g. ``"rat"``
        or ``"human"`` -- confirm against the schema)
    :param gene_id: gene id in that species
    :return: the mouse gene id, or ``None`` when either argument is missing,
        ``species`` is not a plain identifier, or no mapping exists
    """
    if None in (species, gene_id):
        return None
    if species == "mouse":
        return gene_id
    # species is a column name. Quoting it ('%s') would compare two string
    # constants, which never matches. Because it is interpolated unquoted,
    # only plain identifiers are accepted, so it cannot inject SQL.
    if not species.isidentifier():
        return None
    query = """SELECT mouse
                FROM GeneIDXRef
                WHERE %s = '%s'"""
    cursor = conn.cursor()
    try:
        cursor.execute(query_formatter(query, species, gene_id))
        results = cursor.fetchone()
    finally:
        cursor.close()
    # NOTE(review): attribute access assumes the cursor returns rows with
    # named fields (e.g. namedtuple-style rows) -- confirm.
    return results.mouse if results is not None else None


def compute_all_lit_correlation(conn, trait_lists: List,
                                species: str, gene_id):
    """Compute literature correlations for ``trait_lists`` and sort them.

    Wraps ``lit_correlation_for_trait``. Results are ordered by descending
    absolute ``lit_corr``. Entries whose ``lit_corr`` cannot be passed to
    ``abs`` (e.g. ``None``) are placed after all numeric ones.
    """
    def _lit_sort_key(entry):
        lit_value = list(entry.values())[0]["lit_corr"]
        try:
            return (0, -abs(lit_value))
        except TypeError:
            # Non-numeric values (such as None) sort to the end.
            return (1, lit_value)

    return sorted(
        lit_correlation_for_trait(
            conn=conn,
            target_trait_lists=trait_lists,
            species=species,
            trait_gene_id=gene_id),
        key=_lit_sort_key)


def compute_tissue_correlation(primary_tissue_dict: dict,
                               target_tissues_data: dict,
                               corr_method: str):
    """Correlate the primary trait's tissue values with every target trait.

    :param primary_tissue_dict: dict holding the primary trait's
        ``"tissue_values"``
    :param target_tissues_data: dict with ``"trait_symbol_dict"`` (trait id
        to gene symbol) and ``"symbol_tissue_vals_dict"`` (lower-cased
        symbol to tissue values)
    :param corr_method: name of the correlation method
    :return: one ``{trait_id: {...}}`` dict per matched target trait, sorted
        by descending absolute ``tissue_corr``
    """
    primary_values = primary_tissue_dict["tissue_values"]
    target_entries = process_trait_symbol_dict(
        target_tissues_data["trait_symbol_dict"],
        target_tissues_data["symbol_tissue_vals_dict"])
    correlations = [
        tissue_correlation_for_trait(
            primary_tissue_vals=primary_values,
            target_tissues_values=entry.get("tissue_values"),
            trait_id=entry.get("trait_id"),
            corr_method=corr_method)
        for entry in target_entries
    ]
    correlations.sort(
        key=lambda result: -abs(list(result.values())[0]["tissue_corr"]))
    return correlations


def process_trait_symbol_dict(trait_symbol_dict, symbol_tissue_vals_dict) -> List:
    """Attach tissue values to each trait via its gene symbol.

    Symbols are lower-cased before lookup. Traits with no symbol (``None``)
    or whose symbol has no tissue values are left out.

    :return: list of ``{"trait_id", "symbol", "tissue_values"}`` dicts, in
        the iteration order of ``trait_symbol_dict``
    """
    lowered = (
        (trait, symbol.lower())
        for trait, symbol in trait_symbol_dict.items()
        if symbol is not None
    )
    return [
        {"trait_id": trait,
         "symbol": symbol,
         "tissue_values": symbol_tissue_vals_dict[symbol]}
        for trait, symbol in lowered
        if symbol in symbol_tissue_vals_dict
    ]


def fast_compute_tissue_correlation(primary_tissue_dict: dict,
                                    target_tissues_data: dict,
                                    corr_method: str):
    """Multiprocessing variant of ``compute_tissue_correlation`` (experimental).

    Takes the same inputs and produces the same sorted output as
    ``compute_tissue_correlation``. The per-trait correlations run in a pool
    of four worker processes.
    """
    primary_values = primary_tissue_dict["tissue_values"]
    target_entries = process_trait_symbol_dict(
        target_tissues_data["trait_symbol_dict"],
        target_tissues_data["symbol_tissue_vals_dict"])
    # Positional order matches tissue_correlation_for_trait:
    # (primary values, target values, corr_method, trait_id).
    task_args = [
        (primary_values, entry.get("tissue_values"), corr_method,
         entry.get("trait_id"))
        for entry in target_entries
    ]
    with multiprocessing.Pool(4) as pool:
        tissue_results = list(
            pool.starmap(tissue_correlation_for_trait, task_args))
    return sorted(
        tissue_results,
        key=lambda result: -abs(list(result.values())[0]["tissue_corr"]))