"""gn3.computations.rust_correlation unittests"""
import json
import os
import pytest
from gn3.computations.rust_correlation import generate_json_file
from gn3.computations.rust_correlation import generate_input_files
from gn3.computations.rust_correlation import get_samples
from gn3.computations.rust_correlation import parse_correlation_output
@pytest.mark.unit_test
def test_generate_input():
    """Check that dataset rows are written out as comma-joined text lines.

    ``None`` entries are expected to appear as empty fields in the output.
    """
    rows = [
        ["14_at", 12.1, 14.1, None],
        ["15_at", 12.2, 14.1, None],
        ["17_at", 12.1, 14.1, 11.4],
    ]
    _dir, written_path = generate_input_files(rows, output_dir="/tmp")
    with open(written_path, "r", encoding="utf-8") as handle:
        lines_read = list(map(str.rstrip, handle))
    # Remove the generated file before asserting so a failure doesn't leak it.
    os.remove(written_path)
    assert lines_read == [
        "14_at,12.1,14.1,",
        "15_at,12.2,14.1,",
        "17_at,12.1,14.1,11.4",
    ]
# @pytest.mark.unit_test
# NOTE(review): the marker above is commented out, so runs filtered with
# ``-m unit_test`` deselect this test; confirm whether that is intentional.
def test_json_file():
    """Test that generate_json_file writes the correlation parameters as JSON.

    The first path returned by generate_json_file is read back and parsed as
    JSON; it must contain the method, input file path, x values and delimiter
    exactly as passed in.
    """
    json_path, _tmp_json_file = generate_json_file(
        tmp_dir="/tmp/correlation",
        tmp_file="/data.txt",
        method="pearson",
        x_vals="12.1,11.3,16.5,7.5,3.2",
        delimiter=",")
    # Only read access is needed; "r+" would also demand write permission.
    with open(json_path, "r", encoding="utf-8") as file_reader:
        results = json.load(file_reader)
    # Remove the generated file so repeated runs don't leave it behind.
    os.remove(json_path)
    assert results == {
        "method": "pearson",
        "file_path": "/data.txt",
        "x_vals": "12.1,11.3,16.5,7.5,3.2",
        "file_delimiter": ","}
@pytest.mark.unit_test
def test_parse_results():
    """Check that parsed output maps each trait to overlap, coefficient and p-value.

    Each fixture row is (trait, corr_coefficient, p_value, num_overlap), all
    kept as strings.
    """
    rows = (
        ("63.62", "0.97", "0.00", "12"),
        ("19", "-0.96", "0.22", "12"),
        ("77.92", "-0.94", "0.31", "12"),
        ("84.04", "0.94", "0.11", "12"),
        ("23", "-0.91", "0.11", "12"),
    )
    expected = {}
    for trait_name, coeff, p_value, overlap in rows:
        expected[trait_name] = {
            "num_overlap": overlap,
            "corr_coefficient": coeff,
            "p_value": p_value,
        }
    parsed = parse_correlation_output(
        "tests/unit/computations/data/correlation/sorted_results.txt",
        "sample",
        len(rows))
    assert parsed == expected
@pytest.mark.unit_test
def test_get_samples_no_excluded():
    """Check that get_samples keeps numeric base samples when nothing is excluded.

    BXD7 is absent from ``all_samples`` and BXD4 has a non-numeric value, so
    both are expected to be dropped; surviving values come back as floats.
    """
    sample_values = {
        "BXD": "12.1",
        "BXD3": "16.1",
        "BXD4": " x",
        "BXD6": "1.1",
        "BXD5": "1.37",
        "BXD11": "1.91",
        "BXD31": "1.1",
    }
    wanted = ["BXD", "BXD4", "BXD7", "BXD31"]
    selected = get_samples(all_samples=sample_values,
                           base_samples=wanted,
                           excluded=[])
    assert selected == {"BXD": 12.1, "BXD31": 1.1}
@pytest.mark.unit_test
def test_get_samples():
    """Test that get_samples drops excluded samples.

    BXD and BXD11 are excluded explicitly and BXD4 has a non-numeric value,
    so only BXD5 and BXD6 should remain, converted to floats.
    """
    al_samples = {
        "BXD": "12.1",
        "BXD3": "16.1",
        "BXD4": " x",
        "BXD5": "1.1",
        "BXD6": "1.37",
        "BXD11": "1.91",
        "BXD31": "1.1"
    }
    results = get_samples(all_samples=al_samples,
                          base_samples=["BXD", "BXD4", "BXD5", "BXD6",
                                        "BXD11"],
                          excluded=["BXD", "BXD11"])
    # Compare with ``==``: ``assert result, expected`` would treat
    # ``expected`` as the failure message and only check truthiness.
    assert results == {
        "BXD5": 1.1,
        "BXD6": 1.37
    }