diff options
-rw-r--r-- | gn3/computations/rust_correlation.py | 21 | ||||
-rw-r--r-- | tests/unit/computations/test_rust_correlation.py | 19 |
2 files changed, 17 insertions, 23 deletions
diff --git a/gn3/computations/rust_correlation.py b/gn3/computations/rust_correlation.py index 0539527..380cff1 100644 --- a/gn3/computations/rust_correlation.py +++ b/gn3/computations/rust_correlation.py @@ -16,7 +16,7 @@ from gn3.settings import TMPDIR def generate_input_files(dataset: list[str], - output_dir: str = TMPDIR) ->(str, str): + output_dir: str = TMPDIR) -> tuple[str, str]: """function generates outputfiles and inputfiles""" tmp_dir = f"{output_dir}/correlation" @@ -31,18 +31,16 @@ def generate_input_files(dataset: list[str], return (tmp_dir, tmp_file) -def generate_json_file(**kwargs): +def generate_json_file(tmp_dir, tmp_file, method, delimiter, x_vals) -> str: """generating json input file required by cargo""" - (tmp_dir, tmp_file) = (kwargs.get("tmp_dir"), kwargs.get("tmp_file")) - tmp_json_file = os.path.join(tmp_dir, f"{random_string(10)}.json") correlation_args = { - "method": kwargs.get("method", "pearson"), + "method": method, "file_path": tmp_file, - "x_vals": kwargs.get("x_vals"), - "file_delimiter": kwargs.get("delimiter", ",") + "x_vals": x_vals, + "file_delimiter": delimiter } with open(tmp_json_file, "w", encoding="utf-8") as outputfile: @@ -59,17 +57,16 @@ def run_correlation(dataset, trait_vals: (tmp_dir, tmp_file) = generate_input_files(dataset) - json_file = generate_json_file(** - {"tmp_dir": tmp_dir, "tmp_file": tmp_file, - "method": method, "delimiter": delimiter, - "x_vals": trait_vals}) + json_file = generate_json_file(tmp_dir=tmp_dir, tmp_file=tmp_file, + method=method, delimiter=delimiter, + x_vals=trait_vals) command_list = [CORRELATION_COMMAND, json_file, TMPDIR] return subprocess.run(command_list, check=True) -def parse_correlation_output(result_file: str): +def parse_correlation_output(result_file: str) -> list[dict]: """parse file output """ corr_results = [] diff --git a/tests/unit/computations/test_rust_correlation.py b/tests/unit/computations/test_rust_correlation.py index b402621..ac27dea 100644 --- a/tests/unit/computations/test_rust_correlation.py +++ b/tests/unit/computations/test_rust_correlation.py @@ -9,7 +9,6 @@ from gn3.computations.rust_correlation import generate_input_files from gn3.computations.rust_correlation import parse_correlation_output - @pytest.mark.unit_test def test_generate_input(): """test generating text files""" @@ -21,7 +20,8 @@ def test_generate_input(): ] - (_tmp_dir, tmp_file) = generate_input_files(test_dataset, output_dir="/tmp") + (_tmp_dir, tmp_file) = generate_input_files(test_dataset, + output_dir="/tmp") with open(tmp_file, "r", encoding="utf-8") as file_reader: test_results = [line.rstrip() for line in file_reader] @@ -35,16 +35,13 @@ def test_generate_input(): def test_json_file(): """test for generating json files """ - json_dict = {"tmp_dir": "/tmp/correlation", - - "tmp_file": "/data.txt", - "method": "pearson", - "file_path": "/data.txt", - "x_vals": "12.1,11.3,16.5,7.5,3.2", - "file_delimiter": ","} - tmp_file = generate_json_file(**json_dict) + tmp_file = generate_json_file(tmp_dir="/tmp/correlation", + tmp_file="/data.txt", + method="pearson", + x_vals="12.1,11.3,16.5,7.5,3.2", + delimiter=",") - with open(tmp_file, "r+",encoding="utf-8") as file_reader: + with open(tmp_file, "r+", encoding="utf-8") as file_reader: results = json.load(file_reader) assert results == { |