about summary refs log tree commit diff
path: root/gn3
diff options
context:
space:
mode:
Diffstat (limited to 'gn3')
-rw-r--r--gn3/api/correlation.py44
-rw-r--r--gn3/computations/partial_correlations.py20
2 files changed, 44 insertions, 20 deletions
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index 7eb7cd6..aeb7f8c 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -16,6 +16,8 @@ from gn3.computations.correlations import map_shared_keys_to_values
 from gn3.computations.correlations import compute_tissue_correlation
 from gn3.computations.correlations import compute_all_lit_correlation
 from gn3.computations.correlations import compute_all_sample_correlation
+from gn3.computations.partial_correlations import (
+    partial_correlations_with_target_traits)
 
 correlation = Blueprint("correlation", __name__)
 
@@ -111,22 +113,42 @@ def partial_correlation():
         return reduce(__field_errors__(request_data), fields, errors)
 
     args = request.get_json()
-    request_errors = __errors__(
-        args, ("primary_trait", "control_traits", "target_db", "method"))
+    with_target_db = args.get("with_target_db", True)
+    request_errors = None
+    if with_target_db:
+        request_errors = __errors__(
+            args, ("primary_trait", "control_traits", "target_db", "method"))
+    else:
+        request_errors = __errors__(
+            args, ("primary_trait", "control_traits", "target_traits", "method"))
     if request_errors:
         return build_response({
             "status": "error",
             "messages": request_errors,
             "error_type": "Client Error"})
-    return build_response({
-        "status": "success",
-        "results": queue_cmd(
-            conn=redis.Redis(),
-            cmd=compose_pcorrs_command(
+
+    if with_target_db:
+        return build_response({
+            "status": "queued",
+            "results": queue_cmd(
+                conn=redis.Redis(),
+                cmd=compose_pcorrs_command(
+                    trait_fullname(args["primary_trait"]),
+                    tuple(
+                        trait_fullname(trait) for trait in args["control_traits"]),
+                    args["method"], args["target_db"],
+                    int(args.get("criteria", 500))),
+                job_queue=current_app.config.get("REDIS_JOB_QUEUE"),
+                env = {"PYTHONPATH": ":".join(sys.path), "SQL_URI": SQL_URI})})
+    else:
+        with database_connector() as conn:
+            results = partial_correlations_with_target_traits(
+                conn,
                 trait_fullname(args["primary_trait"]),
                 tuple(
                     trait_fullname(trait) for trait in args["control_traits"]),
-                args["method"], args["target_db"],
-                int(args.get("criteria", 500))),
-            job_queue=current_app.config.get("REDIS_JOB_QUEUE"),
-            env = {"PYTHONPATH": ":".join(sys.path), "SQL_URI": SQL_URI})})
+                tuple(
+                    trait_fullname(trait) for trait in args["target_traits"]),
+                args["method"])
+
+        return build_response({"status": "success", "results": results})
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 0041684..07c73db 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -555,9 +555,9 @@ def trait_for_output(trait):
     }
     return {key: val for key, val in trait.items() if val is not None}
 
-def check_for_common_errors(conn, primary_trait_name, control_trait_names):
+def check_for_common_errors(
+        conn, primary_trait_name, control_trait_names, threshold):
     """Check for common errors"""
-    threshold = 0
     corr_min_informative = 4
     non_error_result = {"status": "success"}
 
@@ -671,8 +671,10 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911]
     functionality into smaller functions that do fewer things.
     """
 
+    threshold = 0
+
     check_res = check_for_common_errors(
-        conn, primary_trait_name, control_trait_names)
+        conn, primary_trait_name, control_trait_names, threshold)
     if check_res.get("status") == "error":
         return error_check_results
 
@@ -680,7 +682,7 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911]
     input_trait_geneid = primary_trait.get("geneid", 0)
     input_trait_symbol = primary_trait.get("symbol", "")
     input_trait_mouse_geneid = translate_to_mouse_gene_id(
-        species, input_trait_geneid, conn)
+        check_res["species"], input_trait_geneid, conn)
 
     tissue_probeset_freeze_id = 1
     db_type = primary_trait["db"]["dataset_type"]
@@ -744,8 +746,8 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911]
     database_filename = get_filename(conn, target_db_name, TEXTDIR)
     _total_traits, all_correlations = partial_corrs(
         conn, check_res["common_primary_control_samples"],
-        check_res["fixed_primary_values"], check_res["fixed_control_vals"],
-        len(check_res["fixed_primary_vals"]), check_res["species"],
+        check_res["fixed_primary_values"], check_res["fixed_control_values"],
+        len(check_res["fixed_primary_values"]), check_res["species"],
         input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
         method, {**target_dataset, "dataset_type": target_dataset["type"]},
         database_filename)
@@ -764,7 +766,7 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911]
         return __by_partial_corr_p_value__
 
     add_lit_corr_and_tiss_corr = compose(
-        partial(literature_correlation_by_list, conn, species),
+        partial(literature_correlation_by_list, conn, check_res["species"]),
         partial(
             tissue_correlation_by_list, conn, input_trait_symbol,
             tissue_probeset_freeze_id, method))
@@ -801,7 +803,7 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911]
         "results": {
             "primary_trait": trait_for_output(primary_trait),
             "control_traits": tuple(
-                trait_for_output(trait) for trait in cntrl_traits),
+                trait_for_output(trait) for trait in check_res["control_traits"]),
             "correlations": tuple(
                 trait_for_output(trait) for trait in trait_list),
             "dataset_type": target_dataset["type"],
@@ -818,7 +820,7 @@ def partial_correlations_with_target_traits(
     """
     threshold = 0
     check_res = check_for_common_errors(
-        conn, primary_trait_name, control_trait_names)
+        conn, primary_trait_name, control_trait_names, threshold)
     if check_res.get("status") == "error":
         return error_check_results