about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- gn3/computations/rust_correlation.py 32
1 file changed, 23 insertions, 9 deletions
diff --git a/gn3/computations/rust_correlation.py b/gn3/computations/rust_correlation.py
index 07e0e56..db357fe 100644
--- a/gn3/computations/rust_correlation.py
+++ b/gn3/computations/rust_correlation.py
@@ -45,7 +45,9 @@ def generate_json_file(
     return (output_file, tmp_json_file)
 
 
-def run_correlation(dataset, trait_vals: str, method: str, delimiter: str):
+def run_correlation(
+        dataset, trait_vals: str, method: str, delimiter: str,
+        corr_type: str = "sample", top_n: int = 500):
     """entry function to call rust correlation"""
     (tmp_dir, tmp_file) = generate_input_files(dataset)
     (output_file, json_file) = generate_json_file(
@@ -53,19 +55,31 @@ def run_correlation(dataset, trait_vals: str, method: str, delimiter: str):
         x_vals=trait_vals)
     command_list = [CORRELATION_COMMAND, json_file, TMPDIR]
     subprocess.run(command_list, check=True)
-    return parse_correlation_output(output_file, 500)
 
+    return parse_correlation_output(output_file, corr_type, top_n)
 
-def parse_correlation_output(result_file: str, top_n: int = 500) -> list[dict]:
+
+def parse_correlation_output(result_file: str,
+                             corr_type: str, top_n: int = 500) -> dict:
     """parse file output """
     def __parse_line__(line):
         (trait_name, corr_coeff, p_val, num_overlap) = line.rstrip().split(",")
-        return {
-            trait_name: {
-                "num_overlap": num_overlap,
-                "corr_coefficient": corr_coeff,
-                "p_value": p_val
-            }}
+        if corr_type == "sample":
+            return {
+                trait_name: {
+                    "num_overlap": num_overlap,
+                    "corr_coefficient": corr_coeff,
+                    "p_value": p_val
+                }
+            }
+        if corr_type == "tissue":
+            return {
+                trait_name: {
+                    "tissue_corr": corr_coeff,
+                    "tissue_number": num_overlap,
+                    "tissue_p_val": p_val
+                }
+            }
 
     with open(result_file, "r", encoding="utf-8") as file_reader:
         return [