diff options
-rw-r--r-- | gn3/api/correlation.py | 44 | ||||
-rw-r--r-- | gn3/computations/partial_correlations.py | 20 |
2 files changed, 44 insertions, 20 deletions
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py index 7eb7cd6..aeb7f8c 100644 --- a/gn3/api/correlation.py +++ b/gn3/api/correlation.py @@ -16,6 +16,8 @@ from gn3.computations.correlations import map_shared_keys_to_values from gn3.computations.correlations import compute_tissue_correlation from gn3.computations.correlations import compute_all_lit_correlation from gn3.computations.correlations import compute_all_sample_correlation +from gn3.computations.partial_correlations import ( + partial_correlations_with_target_traits) correlation = Blueprint("correlation", __name__) @@ -111,22 +113,42 @@ def partial_correlation(): return reduce(__field_errors__(request_data), fields, errors) args = request.get_json() - request_errors = __errors__( - args, ("primary_trait", "control_traits", "target_db", "method")) + with_target_db = args.get("with_target_db", True) + request_errors = None + if with_target_db: + request_errors = __errors__( + args, ("primary_trait", "control_traits", "target_db", "method")) + else: + request_errors = __errors__( + args, ("primary_trait", "control_traits", "target_traits", "method")) if request_errors: return build_response({ "status": "error", "messages": request_errors, "error_type": "Client Error"}) - return build_response({ - "status": "success", - "results": queue_cmd( - conn=redis.Redis(), - cmd=compose_pcorrs_command( + + if with_target_db: + return build_response({ + "status": "queued", + "results": queue_cmd( + conn=redis.Redis(), + cmd=compose_pcorrs_command( + trait_fullname(args["primary_trait"]), + tuple( + trait_fullname(trait) for trait in args["control_traits"]), + args["method"], args["target_db"], + int(args.get("criteria", 500))), + job_queue=current_app.config.get("REDIS_JOB_QUEUE"), + env = {"PYTHONPATH": ":".join(sys.path), "SQL_URI": SQL_URI})}) + else: + with database_connector() as conn: + results = partial_correlations_with_target_traits( + conn, trait_fullname(args["primary_trait"]), tuple( trait_fullname(trait) for trait in args["control_traits"]), - args["method"], args["target_db"], - int(args.get("criteria", 500))), - job_queue=current_app.config.get("REDIS_JOB_QUEUE"), - env = {"PYTHONPATH": ":".join(sys.path), "SQL_URI": SQL_URI})}) + tuple( + trait_fullname(trait) for trait in args["target_traits"]), + args["method"]) + + return build_response({"status": "success", "results": results}) diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py index 0041684..07c73db 100644 --- a/gn3/computations/partial_correlations.py +++ b/gn3/computations/partial_correlations.py @@ -555,9 +555,9 @@ def trait_for_output(trait): } return {key: val for key, val in trait.items() if val is not None} -def check_for_common_errors(conn, primary_trait_name, control_trait_names): +def check_for_common_errors( + conn, primary_trait_name, control_trait_names, threshold): """Check for common errors""" - threshold = 0 corr_min_informative = 4 non_error_result = {"status": "success"} @@ -671,8 +671,10 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911] functionality into smaller functions that do fewer things. """ + threshold = 0 + check_res = check_for_common_errors( - conn, primary_trait_name, control_trait_names) + conn, primary_trait_name, control_trait_names, threshold) if check_res.get("status") == "error": return error_check_results @@ -680,7 +682,7 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911] input_trait_geneid = primary_trait.get("geneid", 0) input_trait_symbol = primary_trait.get("symbol", "") input_trait_mouse_geneid = translate_to_mouse_gene_id( - species, input_trait_geneid, conn) + check_res["species"], input_trait_geneid, conn) tissue_probeset_freeze_id = 1 db_type = primary_trait["db"]["dataset_type"] @@ -744,8 +746,8 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911] database_filename = get_filename(conn, target_db_name, TEXTDIR) _total_traits, all_correlations = partial_corrs( conn, check_res["common_primary_control_samples"], - check_res["fixed_primary_values"], check_res["fixed_control_vals"], - len(check_res["fixed_primary_vals"]), check_res["species"], + check_res["fixed_primary_values"], check_res["fixed_control_values"], + len(check_res["fixed_primary_values"]), check_res["species"], input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id, method, {**target_dataset, "dataset_type": target_dataset["type"]}, database_filename) @@ -764,7 +766,7 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911] return __by_partial_corr_p_value__ add_lit_corr_and_tiss_corr = compose( - partial(literature_correlation_by_list, conn, species), + partial(literature_correlation_by_list, conn, check_res["species"]), partial( tissue_correlation_by_list, conn, input_trait_symbol, tissue_probeset_freeze_id, method)) @@ -801,7 +803,7 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911] "results": { "primary_trait": trait_for_output(primary_trait), "control_traits": tuple( - trait_for_output(trait) for trait in cntrl_traits), + trait_for_output(trait) for trait in check_res["control_traits"]), "correlations": tuple( trait_for_output(trait) for trait in trait_list), "dataset_type": target_dataset["type"], @@ -818,7 +820,7 @@ def partial_correlations_with_target_traits( """ threshold = 0 check_res = check_for_common_errors( - conn, primary_trait_name, control_trait_names) + conn, primary_trait_name, control_trait_names, threshold) if check_res.get("status") == "error": return error_check_results |