aboutsummaryrefslogtreecommitdiff
path: root/gn3/computations
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/computations')
-rw-r--r--gn3/computations/rust_correlation.py21
1 files changed, 9 insertions, 12 deletions
diff --git a/gn3/computations/rust_correlation.py b/gn3/computations/rust_correlation.py
index 0539527..380cff1 100644
--- a/gn3/computations/rust_correlation.py
+++ b/gn3/computations/rust_correlation.py
@@ -16,7 +16,7 @@ from gn3.settings import TMPDIR
def generate_input_files(dataset: list[str],
- output_dir: str = TMPDIR) ->(str, str):
+ output_dir: str = TMPDIR) -> tuple[str, str]:
"""function generates outputfiles and inputfiles"""
tmp_dir = f"{output_dir}/correlation"
@@ -31,18 +31,16 @@ def generate_input_files(dataset: list[str],
return (tmp_dir, tmp_file)
-def generate_json_file(**kwargs):
+def generate_json_file(tmp_dir, tmp_file, method, delimiter, x_vals) -> str:
"""generating json input file required by cargo"""
- (tmp_dir, tmp_file) = (kwargs.get("tmp_dir"), kwargs.get("tmp_file"))
-
tmp_json_file = os.path.join(tmp_dir, f"{random_string(10)}.json")
correlation_args = {
- "method": kwargs.get("method", "pearson"),
+ "method": method,
"file_path": tmp_file,
- "x_vals": kwargs.get("x_vals"),
- "file_delimiter": kwargs.get("delimiter", ",")
+ "x_vals": x_vals,
+ "file_delimiter": delimiter
}
with open(tmp_json_file, "w", encoding="utf-8") as outputfile:
@@ -59,17 +57,16 @@ def run_correlation(dataset, trait_vals:
(tmp_dir, tmp_file) = generate_input_files(dataset)
- json_file = generate_json_file(**
- {"tmp_dir": tmp_dir, "tmp_file": tmp_file,
- "method": method, "delimiter": delimiter,
- "x_vals": trait_vals})
+ json_file = generate_json_file(tmp_dir=tmp_dir, tmp_file=tmp_file,
+ method=method, delimiter=delimiter,
+ x_vals=trait_vals)
command_list = [CORRELATION_COMMAND, json_file, TMPDIR]
return subprocess.run(command_list, check=True)
-def parse_correlation_output(result_file: str):
+def parse_correlation_output(result_file: str) -> list[dict]:
"""parse file output """
corr_results = []