"""Integration with the correlation code implemented in Rust.

The Rust implementation lives at:
https://github.com/Alexanderlacuna/correlation_rust

"""

import subprocess
import json
import csv
import os

from gn3.computations.qtlreaper import create_output_directory
from gn3.chancy import random_string
from gn3.settings import CORRELATION_COMMAND
from gn3.settings import TMPDIR


def generate_input_files(dataset: list[str],
                         output_dir: str = TMPDIR) -> tuple[str, str]:
    """Write `dataset` to a freshly named text file for the correlation tool.

    The file is created inside `<output_dir>/correlation`, which is created
    first via `create_output_directory`.

    Args:
        dataset: rows handed unchanged to `csv.writer.writerows`
            (presumably each row is a sequence of fields -- verify against
            callers).
        output_dir: parent directory of the `correlation` working directory.

    Returns:
        A `(tmp_dir, tmp_file)` tuple: the working directory and the path of
        the file that was written.
    """
    tmp_dir = f"{output_dir}/correlation"
    create_output_directory(tmp_dir)
    tmp_file = os.path.join(tmp_dir, f"{random_string(10)}.txt")
    # newline="" lets the csv module control line endings itself, as the
    # csv documentation requires for files passed to csv.writer.
    with open(tmp_file, "w", encoding="utf-8", newline="") as op_file:
        # quotechar must be None (not ""): since Python 3.11 an empty string
        # is rejected with a TypeError.  With QUOTE_NONE nothing is quoted;
        # special characters are escaped with a backslash instead.
        writer = csv.writer(
            op_file, delimiter=",", dialect="unix", quotechar=None,
            quoting=csv.QUOTE_NONE, escapechar="\\")
        writer.writerows(dataset)

    return (tmp_dir, tmp_file)


def generate_json_file(
        tmp_dir, tmp_file, method, delimiter, x_vals) -> tuple[str, str]:
    """Write the JSON job description consumed by the correlation command.

    Two randomly named paths are created inside `tmp_dir`: a `.json` file
    holding the job description and a `.txt` path recorded in it under
    `output_file`.

    Returns:
        A `(output_file, tmp_json_file)` tuple -- note the results path comes
        first, the JSON path second.
    """
    json_path = os.path.join(tmp_dir, f"{random_string(10)}.json")
    results_path = os.path.join(tmp_dir, f"{random_string(10)}.txt")

    job_spec = {
        "method": method,
        "file_path": tmp_file,
        "x_vals": x_vals,
        "sample_values": "bxd1",
        "output_file": results_path,
        "file_delimiter": delimiter
    }
    with open(json_path, "w", encoding="utf-8") as json_handle:
        json.dump(job_spec, json_handle)

    return (results_path, json_path)


def run_correlation(
        dataset, trait_vals: str, method: str, delimiter: str,
        corr_type: str = "sample", top_n: int = 500):
    """Run the external correlation command on `dataset` and parse its output.

    The dataset is written to a temporary file, a JSON job description is
    generated next to it, and `CORRELATION_COMMAND` is invoked with the JSON
    path and `TMPDIR` as arguments.

    Raises:
        Exception: when the command exits with a non-zero status; the
            arguments are the command list, the resolved command (the link
            target if `CORRELATION_COMMAND` is a symlink) and the captured
            stdout.
    """

    # pylint: disable=too-many-arguments
    tmp_dir, tmp_file = generate_input_files(dataset)
    output_file, json_file = generate_json_file(
        tmp_dir, tmp_file, method, delimiter, trait_vals)
    command_list = [CORRELATION_COMMAND, json_file, TMPDIR]

    try:
        subprocess.run(command_list, check=True, capture_output=True)
    except subprocess.CalledProcessError as cpe:
        # Report the real binary when the configured command is a symlink.
        if os.path.islink(CORRELATION_COMMAND):
            actual_command = os.readlink(CORRELATION_COMMAND)
        else:
            actual_command = CORRELATION_COMMAND
        raise Exception(command_list, actual_command, cpe.stdout) from cpe

    return parse_correlation_output(output_file, corr_type, top_n)


def parse_correlation_output(result_file: str,
                             corr_type: str, top_n: int = 500) -> dict:
    """Parse the first `top_n` lines of a correlation results file.

    Each line is expected to hold four comma-separated fields:
    `trait_name,corr_coeff,p_val,num_overlap`.  Field values are kept as the
    strings read from the file.

    Args:
        result_file: path of the results file to read.
        corr_type: `"sample"` or `"tissue"`; selects the key names used in
            each per-trait dict.  Any other value maps every trait to `{}`.
        top_n: maximum number of lines (counted from the top of the file)
            to consider.

    Returns:
        A dict mapping trait name to a dict of its correlation values.
    """
    def _parse_line(line):
        (trait_name, corr_coeff, p_val, num_overlap) = line.rstrip().split(",")
        if corr_type == "sample":
            return (trait_name, {
                "num_overlap": num_overlap,
                "corr_coefficient": corr_coeff,
                "p_value": p_val
            })

        if corr_type == "tissue":
            return (
                trait_name,
                {
                    "tissue_corr": corr_coeff,
                    "tissue_number": num_overlap,
                    "tissue_p_val": p_val
                })

        # Unknown correlation type: keep the trait but with no values.
        return (trait_name, {})

    results = {}
    with open(result_file, "r", encoding="utf-8") as file_reader:
        for idx, line in enumerate(file_reader):
            if idx >= top_n:
                # Stop reading instead of scanning the rest of the file.
                break
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline).
                continue
            trait_name, values = _parse_line(line)
            results[trait_name] = values

    return results


def get_samples(all_samples: dict[str, str],
                base_samples: list[str],
                excluded: list[str]):
    """Return the usable sample values converted to floats.

    A sample is dropped when it is listed in `excluded` or when its value,
    stripped and lower-cased, is `"x"`.

    When `base_samples` is non-empty, only samples named there (and present
    in `all_samples`) are considered, in `base_samples` order.  Otherwise
    every entry of `all_samples` is considered.
    """
    if not base_samples:
        cleaned = {}
        for name, raw in all_samples.items():
            if name in excluded:
                continue
            if raw.lower().strip() == "x":
                continue
            cleaned[name] = float(raw.strip())
        return cleaned

    return {
        name: float(all_samples[name].strip())
        for name in base_samples
        if name not in excluded
        and name in all_samples
        and all_samples[name].strip().lower() != "x"
    }


def get_sample_corr_data(sample_type: str,
                         sample_data: dict[str, str],
                         dataset_samples: list[str]) -> dict[str, str]:
    """Select the sample data matching `sample_type` via `get_samples`.

    * `"samples_primary"`: restrict to `dataset_samples`.
    * `"samples_other"`: use everything except `dataset_samples`.
    * anything else: use all of `sample_data`.
    """
    if sample_type == "samples_primary":
        base, skipped = dataset_samples, []
    elif sample_type == "samples_other":
        base, skipped = [], dataset_samples
    else:
        base, skipped = [], []

    return get_samples(
        all_samples=sample_data, base_samples=base, excluded=skipped)


def parse_tissue_corr_data(symbol_name: str,
                           symbol_dict: dict,
                           dataset_symbols: dict,
                           dataset_vals: dict):
    """Build the tissue correlation input for `symbol_name`.

    Returns `None` when `symbol_name` is empty or its lower-cased form is
    not a key of `symbol_dict`.  Otherwise returns `(x_vals, data)` where
    `x_vals` are the symbol's values as floats and `data` holds one row
    `[trait, *values]` for every trait in `dataset_symbols` whose
    lower-cased symbol has a non-empty entry in `dataset_vals`.
    """
    if not symbol_name:
        return None

    lookup_key = symbol_name.lower()
    if lookup_key not in symbol_dict:
        return None

    x_vals = list(map(float, symbol_dict[lookup_key]))

    rows = []
    for trait, symbol in dataset_symbols.items():
        try:
            trait_vals = dataset_vals.get(symbol.lower())
            if not trait_vals:
                continue
            rows.append([str(trait)] + trait_vals)
        except AttributeError:
            # e.g. the symbol is None and has no .lower(); skip this trait.
            continue

    return (x_vals, rows)


def run_lmdb_correlation(lmdb_info: dict,
                         corr_type: str = "sample", top_n: int = 500):
    """Run the external correlation command on an lmdb-backed dataset.

    How this differs from `run_correlation`:
    1) all the pre-parsing is done in Rust, which is considerably faster;
    2) the input files are read directly by the Rust tool (lmdb format); it
       can also compute correlations on unprocessed csv files.

    Args:
        lmdb_info: job settings merged into the JSON job description passed
            to the command; its keys take precedence over the generated
            `outputfile` entry.
        corr_type: forwarded to `parse_correlation_output`.
        top_n: forwarded to `parse_correlation_output`.

    Returns:
        The parsed correlation results, as returned by
        `parse_correlation_output`.

    Raises:
        Exception: when the command exits with a non-zero status; the
            arguments are the command list, the resolved command (the link
            target if `CORRELATION_COMMAND` is a symlink) and the captured
            stdout.
    """
    # Same working directory layout as generate_input_files.
    tmp_dir = f"{TMPDIR}/correlation"
    create_output_directory(tmp_dir)
    tmp_json_file = os.path.join(tmp_dir, f"{random_string(10)}.json")
    output_file = os.path.join(tmp_dir, f"{random_string(10)}.txt")
    with open(tmp_json_file, "w", encoding="utf-8") as outputfile:
        # NOTE(review): this key is "outputfile" while generate_json_file
        # writes "output_file" -- confirm which one the Rust tool expects.
        json.dump({"outputfile": output_file, **lmdb_info}, outputfile)
    command_list = [CORRELATION_COMMAND, tmp_json_file, TMPDIR]
    try:
        subprocess.run(command_list, check=True, capture_output=True)
    except subprocess.CalledProcessError as cpe:
        actual_command = (
            os.readlink(CORRELATION_COMMAND)
            if os.path.islink(CORRELATION_COMMAND)
            else CORRELATION_COMMAND)
        raise Exception(command_list, actual_command, cpe.stdout) from cpe
    return parse_correlation_output(output_file, corr_type, top_n)