aboutsummaryrefslogtreecommitdiff
"""gn3.computations.rust_correlation unittests"""

import json
import os
import pytest

from gn3.computations.rust_correlation import generate_json_file
from gn3.computations.rust_correlation import generate_input_files
from gn3.computations.rust_correlation import get_samples
from gn3.computations.rust_correlation import parse_correlation_output


@pytest.mark.unit_test
def test_generate_input():
    """Test that generate_input_files writes one comma-separated row per trait.

    Rows containing ``None`` are expected to be written with an empty field in
    that position (see ``expected`` below).
    """

    test_dataset = [
        ["14_at", 12.1, 14.1, None],
        ["15_at", 12.2, 14.1, None],
        ["17_at", 12.1, 14.1, 11.4],
    ]
    expected = ["14_at,12.1,14.1,", "15_at,12.2,14.1,", "17_at,12.1,14.1,11.4"]

    (_tmp_dir, tmp_file) = generate_input_files(test_dataset,
                                                output_dir="/tmp")

    # Remove the generated file even if opening/reading it fails, so a failed
    # run does not leave the file behind.
    try:
        with open(tmp_file, "r", encoding="utf-8") as file_reader:
            test_results = [line.rstrip() for line in file_reader]
    finally:
        os.remove(tmp_file)

    assert test_results == expected


# NOTE(review): the unit_test marker below is commented out, so this test is
# not selected by ``-m unit_test``; presumably because it depends on
# /tmp/correlation existing — confirm intent before re-enabling.
# @pytest.mark.unit_test
def test_json_file():
    """Test that generate_json_file writes the correlation arguments as JSON.

    The first value returned by generate_json_file is read back and must
    decode to the method, data-file path, x values and delimiter passed in.
    """

    tmp_file, _tmp_json_file = generate_json_file(
        tmp_dir="/tmp/correlation",
        tmp_file="/data.txt",
        method="pearson",
        x_vals="12.1,11.3,16.5,7.5,3.2",
        delimiter=",")

    # Read-only mode is sufficient; "r+" would needlessly require write access.
    with open(tmp_file, "r", encoding="utf-8") as file_reader:
        results = json.load(file_reader)

    assert results == {
        "method": "pearson",
        "file_path": "/data.txt",
        "x_vals": "12.1,11.3,16.5,7.5,3.2",
        "file_delimiter": ","}


@pytest.mark.unit_test
def test_parse_results():
    """Check parse_correlation_output against the sorted_results.txt fixture.

    Each row below is ``(trait, corr_coefficient, p_value, num_overlap)``.
    Every field is a string, so the parser is expected to return the values
    unconverted, keyed by trait.
    """

    rows = [
        ["63.62", "0.97", "0.00", "12"],
        ["19", "-0.96", "0.22", "12"],
        ["77.92", "-0.94", "0.31", "12"],
        ["84.04", "0.94", "0.11", "12"],
        ["23", "-0.91", "0.11", "12"],
    ]

    expected = {}
    for trait, coefficient, p_value, overlap in rows:
        expected[trait] = {
            "num_overlap": overlap,
            "corr_coefficient": coefficient,
            "p_value": p_value,
        }

    parsed = parse_correlation_output(
        "tests/unit/computations/data/correlation/sorted_results.txt",
        "sample", len(rows))
    assert parsed == expected


@pytest.mark.unit_test
def test_get_samples_no_excluded():
    """get_samples keeps only usable base samples when nothing is excluded.

    BXD4 holds a non-numeric value and BXD7 is absent from all_samples, so
    only BXD and BXD31 are expected back, converted to floats.
    """

    sample_values = {
        "BXD": "12.1",
        "BXD3": "16.1",
        "BXD4": " x",
        "BXD6": "1.1",
        "BXD5": "1.37",
        "BXD11": "1.91",
        "BXD31": "1.1",
    }
    wanted = ["BXD", "BXD4", "BXD7", "BXD31"]
    expected = {"BXD": 12.1, "BXD31": 1.1}

    result = get_samples(all_samples=sample_values,
                         base_samples=wanted,
                         excluded=[])
    assert result == expected


@pytest.mark.unit_test
def test_get_samples():
    """Test getting samples when some base samples are excluded.

    BXD and BXD11 are excluded explicitly, and BXD4 holds a non-numeric value.
    Only BXD5 and BXD6 are expected back, converted to floats.
    """

    all_samples = {
        "BXD": "12.1",
        "BXD3": "16.1",
        "BXD4": " x",
        "BXD5": "1.1",
        "BXD6": "1.37",
        "BXD11": "1.91",
        "BXD31": "1.1"
    }

    # Compare with ``==``: the previous ``assert result, {...}`` treated the
    # dict as the assertion message and only checked that result was truthy.
    assert get_samples(all_samples=all_samples,
                       base_samples=["BXD", "BXD4", "BXD5", "BXD6", "BXD11"],
                       excluded=["BXD", "BXD11"]) == {
        "BXD5": 1.1,
        "BXD6": 1.37
    }