aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gn3/computations/rust_correlation.py52
1 files changed, 16 insertions, 36 deletions
diff --git a/gn3/computations/rust_correlation.py b/gn3/computations/rust_correlation.py
index 276013a..07e0e56 100644
--- a/gn3/computations/rust_correlation.py
+++ b/gn3/computations/rust_correlation.py
@@ -17,63 +17,43 @@ from gn3.settings import TMPDIR
def generate_input_files(dataset: list[str],
output_dir: str = TMPDIR) -> tuple[str, str]:
"""function generates outputfiles and inputfiles"""
-
tmp_dir = f"{output_dir}/correlation"
-
create_output_directory(tmp_dir)
-
tmp_file = os.path.join(tmp_dir, f"{random_string(10)}.txt")
-
with open(tmp_file, "w", encoding="utf-8") as file_writer:
-
file_writer.write("\n".join(dataset))
+
return (tmp_dir, tmp_file)
-def generate_json_file(tmp_dir, tmp_file,
- method, delimiter, x_vals) -> tuple[str, str]:
+def generate_json_file(
+ tmp_dir, tmp_file, method, delimiter, x_vals) -> tuple[str, str]:
"""generating json input file required by cargo"""
-
tmp_json_file = os.path.join(tmp_dir, f"{random_string(10)}.json")
-
output_file = os.path.join(tmp_dir, f"{random_string(10)}.txt")
- correlation_args = {
- "method": method,
- "file_path": tmp_file,
- "x_vals": x_vals,
- "sample_values": "bxd1",
- "output_file": output_file,
- "file_delimiter": delimiter
- }
-
with open(tmp_json_file, "w", encoding="utf-8") as outputfile:
- json.dump(correlation_args, outputfile)
+ json.dump({
+ "method": method,
+ "file_path": tmp_file,
+ "x_vals": x_vals,
+ "sample_values": "bxd1",
+ "output_file": output_file,
+ "file_delimiter": delimiter
+ }, outputfile)
return (output_file, tmp_json_file)
-def run_correlation(dataset, trait_vals:
- str,
- method: str,
- delimiter: str):
+def run_correlation(dataset, trait_vals: str, method: str, delimiter: str):
"""entry function to call rust correlation"""
-
(tmp_dir, tmp_file) = generate_input_files(dataset)
-
- (output_file, json_file) = generate_json_file(tmp_dir=tmp_dir,
- tmp_file=tmp_file,
- method=method,
- delimiter=delimiter,
- x_vals=trait_vals)
-
+ (output_file, json_file) = generate_json_file(
+ tmp_dir=tmp_dir, tmp_file=tmp_file, method=method, delimiter=delimiter,
+ x_vals=trait_vals)
command_list = [CORRELATION_COMMAND, json_file, TMPDIR]
-
subprocess.run(command_list, check=True)
-
- results = parse_correlation_output(output_file, 500)
-
- return results
+ return parse_correlation_output(output_file, 500)
def parse_correlation_output(result_file: str, top_n: int = 500) -> list[dict]: