diff options
-rw-r--r-- | gn3/computations/rust_correlation.py | 52 |
1 files changed, 16 insertions, 36 deletions
diff --git a/gn3/computations/rust_correlation.py b/gn3/computations/rust_correlation.py index 276013a..07e0e56 100644 --- a/gn3/computations/rust_correlation.py +++ b/gn3/computations/rust_correlation.py @@ -17,63 +17,43 @@ from gn3.settings import TMPDIR def generate_input_files(dataset: list[str], output_dir: str = TMPDIR) -> tuple[str, str]: """function generates outputfiles and inputfiles""" - tmp_dir = f"{output_dir}/correlation" - create_output_directory(tmp_dir) - tmp_file = os.path.join(tmp_dir, f"{random_string(10)}.txt") - with open(tmp_file, "w", encoding="utf-8") as file_writer: - file_writer.write("\n".join(dataset)) + return (tmp_dir, tmp_file) -def generate_json_file(tmp_dir, tmp_file, - method, delimiter, x_vals) -> tuple[str, str]: +def generate_json_file( + tmp_dir, tmp_file, method, delimiter, x_vals) -> tuple[str, str]: """generating json input file required by cargo""" - tmp_json_file = os.path.join(tmp_dir, f"{random_string(10)}.json") - output_file = os.path.join(tmp_dir, f"{random_string(10)}.txt") - correlation_args = { - "method": method, - "file_path": tmp_file, - "x_vals": x_vals, - "sample_values": "bxd1", - "output_file": output_file, - "file_delimiter": delimiter - } - with open(tmp_json_file, "w", encoding="utf-8") as outputfile: - json.dump(correlation_args, outputfile) + json.dump({ + "method": method, + "file_path": tmp_file, + "x_vals": x_vals, + "sample_values": "bxd1", + "output_file": output_file, + "file_delimiter": delimiter + }, outputfile) return (output_file, tmp_json_file) -def run_correlation(dataset, trait_vals: - str, - method: str, - delimiter: str): +def run_correlation(dataset, trait_vals: str, method: str, delimiter: str): """entry function to call rust correlation""" - (tmp_dir, tmp_file) = generate_input_files(dataset) - - (output_file, json_file) = generate_json_file(tmp_dir=tmp_dir, - tmp_file=tmp_file, - method=method, - delimiter=delimiter, - x_vals=trait_vals) - + (output_file, json_file) = generate_json_file( + tmp_dir=tmp_dir, tmp_file=tmp_file, method=method, delimiter=delimiter, + x_vals=trait_vals) command_list = [CORRELATION_COMMAND, json_file, TMPDIR] - subprocess.run(command_list, check=True) - - results = parse_correlation_output(output_file, 500) - - return results + return parse_correlation_output(output_file, 500) def parse_correlation_output(result_file: str, top_n: int = 500) -> list[dict]: |