about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-02-23 14:51:47 +0300
committerFrederick Muriuki Muriithi2022-03-03 10:20:04 +0300
commit6d39c92fbc9a7b82cd8eef60c62cd5d83acb49a1 (patch)
tree7efab53cc8fc367f433ac01ece95b0fbecc858d9
parent8e0fcfa78fcdb5bdd5b49e2b1ac918ae9cc0fc53 (diff)
downloadgenenetwork3-6d39c92fbc9a7b82cd8eef60c62cd5d83acb49a1.tar.gz
Run partial correlations in an external process
Run the partial correlations code in an external python process decoupling it
from the server and making it asynchronous.

Summary of changes:
* gn3/api/correlation.py:
  - Remove response processing code
  - Queue partial corrs processing
  - Create new endpoint to get results
* gn3/commands.py:
  - Compose the pcorrs command to be run in an external process
  - Enable running of subprocess commands with list args
* gn3/responses/__init__.py: new module indicator file
* gn3/responses/pcorrs_responses.py: Hold response processing code extracted
  from ~gn3/api/correlation.py~ file
* scripts/partial_correlations.py: CLI script to process the pcorrs
* sheepdog/worker.py:
  - Add the *genenetwork3* path at the beginning of the ~sys.path~ list to
    override any GN3 in the site-packages
  - Add any environment variables to be set for the command to be run
-rw-r--r--gn3/api/correlation.py65
-rw-r--r--gn3/commands.py32
-rw-r--r--gn3/responses/__init__.py0
-rw-r--r--gn3/responses/pcorrs_responses.py24
-rwxr-xr-xscripts/partial_correlations.py59
-rw-r--r--sheepdog/worker.py16
6 files changed, 150 insertions, 46 deletions
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index 00b3ad5..14c029c 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -1,17 +1,22 @@
 """Endpoints for running correlations"""
+import sys
 import json
 from functools import reduce
 
+import redis
 from flask import jsonify
 from flask import Blueprint
 from flask import request
-from flask import make_response
+from flask import current_app
 
-from gn3.computations.correlations import compute_all_sample_correlation
-from gn3.computations.correlations import compute_all_lit_correlation
-from gn3.computations.correlations import compute_tissue_correlation
-from gn3.computations.correlations import map_shared_keys_to_values
+from gn3.settings import SQL_URI
+from gn3.commands import queue_cmd, compose_pcorrs_command
 from gn3.db_utils import database_connector
+from gn3.responses.pcorrs_responses import build_response
+from gn3.computations.correlations import map_shared_keys_to_values
+from gn3.computations.correlations import compute_tissue_correlation
+from gn3.computations.correlations import compute_all_lit_correlation
+from gn3.computations.correlations import compute_all_sample_correlation
 from gn3.computations.partial_correlations import partial_correlations_entry
 
 correlation = Blueprint("correlation", __name__)
@@ -93,7 +98,7 @@ def compute_tissue_corr(corr_method="pearson"):
 def partial_correlation():
     """API endpoint for partial correlations."""
     def trait_fullname(trait):
-        return f"{trait['dataset']}::{trait['name']}"
+        return f"{trait['dataset']}::{trait['trait_name']}"
 
     def __field_errors__(args):
         def __check__(acc, field):
@@ -107,37 +112,29 @@ def partial_correlation():
         if request_data is None:
             return ("No request data",)
 
-        return reduce(__field_errors__(args), fields, errors)
-
-    class OutputEncoder(json.JSONEncoder):
-        """
-        Class to encode output into JSON, for objects which the default
-        json.JSONEncoder class does not have default encoding for.
-        """
-        def default(self, o):
-            if isinstance(o, bytes):
-                return str(o, encoding="utf-8")
-            return json.JSONEncoder.default(self, o)
-
-    def __build_response__(data):
-        status_codes = {"error": 400, "not-found": 404, "success": 200}
-        response = make_response(
-            json.dumps(data, cls=OutputEncoder),
-            status_codes[data["status"]])
-        response.headers["Content-Type"] = "application/json"
-        return response
-
-    args = request.get_json()
+        return reduce(__field_errors__(request_data), fields, errors)
+
+    args = json.loads(request.get_json())
     request_errors = __errors__(
         args, ("primary_trait", "control_traits", "target_db", "method"))
     if request_errors:
-        return __build_response__({
+        return build_response({
             "status": "error",
             "messages": request_errors,
             "error_type": "Client Error"})
-    conn, _cursor_object = database_connector()
-    corr_results = partial_correlations_entry(
-        conn, trait_fullname(args["primary_trait"]),
-        tuple(trait_fullname(trait) for trait in args["control_traits"]),
-        args["method"], int(args.get("criteria", 500)), args["target_db"])
-    return __build_response__(corr_results)
+    return build_response({
+        "status": "success",
+        "results": queue_cmd(
+            conn=redis.Redis(),
+            cmd=compose_pcorrs_command(
+                trait_fullname(args["primary_trait"]),
+                tuple(
+                    trait_fullname(trait) for trait in args["control_traits"]),
+                args["method"], args["target_db"],
+                int(args.get("criteria", 500))),
+            job_queue=current_app.config.get("REDIS_JOB_QUEUE"),
+            env = {"PYTHONPATH": ":".join(sys.path), "SQL_URI": SQL_URI})})
+
+@correlation.route("/partial/<job_id>", methods=["GET"])
+def partial_correlation_results():
+    raise Exception("Not implemented!!")
diff --git a/gn3/commands.py b/gn3/commands.py
index 7d42ced..29e3df2 100644
--- a/gn3/commands.py
+++ b/gn3/commands.py
@@ -1,5 +1,8 @@
 """Procedures used to work with the various bio-informatics cli
 commands"""
+import os
+import sys
+import json
 import subprocess
 
 from datetime import datetime
@@ -7,6 +10,8 @@ from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple
+from typing import Union
+from typing import Sequence
 from uuid import uuid4
 from redis.client import Redis  # Used only in type hinting
 
@@ -46,10 +51,20 @@ def compose_rqtl_cmd(rqtl_wrapper_cmd: str,
 
     return cmd
 
+def compose_pcorrs_command(
+        primary_trait: str, control_traits: Tuple[str, ...], method: str,
+        target_database: str, criteria: int = 500):
+    rundir = os.path.abspath(".")
+    return (
+        f"{sys.executable}", f"{rundir}/scripts/partial_correlations.py",
+        primary_trait, ",".join(control_traits), f'"{method}"',
+        f"{target_database}", f"--criteria={criteria}")
+
 def queue_cmd(conn: Redis,
               job_queue: str,
-              cmd: str,
-              email: Optional[str] = None) -> str:
+              cmd: Union[str, Sequence[str]],
+              email: Optional[str] = None,
+              env: Optional[dict] = None) -> str:
     """Given a command CMD; (optional) EMAIL; and a redis connection CONN, queue
 it in Redis with an initial status of 'queued'.  The following status codes
 are supported:
@@ -68,17 +83,22 @@ Returns the name of the specific redis hash for the specific task.
                  f"{datetime.now().strftime('%Y-%m-%d%H-%M%S-%M%S-')}"
                  f"{str(uuid4())}")
     conn.rpush(job_queue, unique_id)
-    for key, value in {"cmd": cmd, "result": "", "status": "queued"}.items():
+    for key, value in {
+            "cmd": json.dumps(cmd), "result": "", "status": "queued",
+            "env": json.dumps(env)}.items():
         conn.hset(name=unique_id, key=key, value=value)
     if email:
         conn.hset(name=unique_id, key="email", value=email)
     return unique_id
 
 
-def run_cmd(cmd: str, success_codes: Tuple = (0,)) -> Dict:
+def run_cmd(cmd: str, success_codes: Tuple = (0,), env: str = None) -> Dict:
     """Run CMD and return the CMD's status code and output as a dict"""
-    results = subprocess.run(cmd, capture_output=True, shell=True,
-                             check=False)
+    parsed_cmd = json.loads(cmd)
+    parsed_env = (json.loads(env) if env is not None else None)
+    results = subprocess.run(
+        parsed_cmd, capture_output=True, shell=isinstance(parsed_cmd, str),
+        check=False, env=parsed_env)
     out = str(results.stdout, 'utf-8')
     if results.returncode not in success_codes:  # Error!
         out = str(results.stderr, 'utf-8')
diff --git a/gn3/responses/__init__.py b/gn3/responses/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/gn3/responses/__init__.py
diff --git a/gn3/responses/pcorrs_responses.py b/gn3/responses/pcorrs_responses.py
new file mode 100644
index 0000000..d6fd9d7
--- /dev/null
+++ b/gn3/responses/pcorrs_responses.py
@@ -0,0 +1,24 @@
+"""Functions and classes that deal with responses and conversion to JSON."""
+import json
+
+from flask import make_response
+
+class OutputEncoder(json.JSONEncoder):
+    """
+    Class to encode output into JSON, for objects which the default
+    json.JSONEncoder class does not have default encoding for.
+    """
+    def default(self, o):
+        if isinstance(o, bytes):
+            return str(o, encoding="utf-8")
+        return json.JSONEncoder.default(self, o)
+
+def build_response(data):
+    """Build the responses for the API"""
+    status_codes = {
+        "error": 400, "not-found": 404, "success": 200, "exception": 500}
+    response = make_response(
+            json.dumps(data, cls=OutputEncoder),
+            status_codes[data["status"]])
+    response.headers["Content-Type"] = "application/json"
+    return response
diff --git a/scripts/partial_correlations.py b/scripts/partial_correlations.py
new file mode 100755
index 0000000..ee442df
--- /dev/null
+++ b/scripts/partial_correlations.py
@@ -0,0 +1,59 @@
+import sys
+import json
+import traceback
+from argparse import ArgumentParser
+
+from gn3.db_utils import database_connector
+from gn3.responses.pcorrs_responses import OutputEncoder
+from gn3.computations.partial_correlations import partial_correlations_entry
+
+def process_cli_arguments():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "primary_trait",
+        help="The primary trait's full name",
+        type=str)
+    parser.add_argument(
+        "control_traits",
+        help="A comma-separated list of traits' full names",
+        type=str)
+    parser.add_argument(
+        "method",
+        help="The correlation method to use",
+        type=str)
+    parser.add_argument(
+        "target_database",
+        help="The target database to run the partial correlations against",
+        type=str)
+    parser.add_argument(
+        "--criteria",
+        help="Number of results to return",
+        type=int, default=500)
+    return parser.parse_args()
+
+def cleanup_string(the_str):
+    return the_str.strip('"\t\n\r ')
+
+def run_partial_corrs(args):
+    try:
+        conn, _cursor_object = database_connector()
+        return partial_correlations_entry(
+            conn, cleanup_string(args.primary_trait),
+            tuple(cleanup_string(args.control_traits).split(",")),
+            cleanup_string(args.method), args.criteria,
+            cleanup_string(args.target_database))
+    except Exception as exc:
+        print(traceback.format_exc(), file=sys.stderr)
+        return {
+            "status": "exception",
+            "message": traceback.format_exc()
+        }
+
+def enter():
+    args = process_cli_arguments()
+    print(json.dumps(
+        run_partial_corrs(process_cli_arguments()),
+        cls = OutputEncoder))
+
+if __name__ == "__main__":
+    enter()
diff --git a/sheepdog/worker.py b/sheepdog/worker.py
index 4e3610e..4e7f9e7 100644
--- a/sheepdog/worker.py
+++ b/sheepdog/worker.py
@@ -5,9 +5,12 @@ import time
 import redis
 import redis.connection
 
-# Enable importing from one dir up since gn3 isn't installed as a globally
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+# Enable importing from one dir up: put as first to override any other globally
+# accessible GN3
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 
+def update_status(conn, cmd_id, status):
+    conn.hset(name=f"{cmd_id}", key="status", value=f"{status}")
 
 def run_jobs(conn):
     """Process the redis using a redis connection, CONN"""
@@ -17,13 +20,14 @@ def run_jobs(conn):
     if bool(cmd_id):
         cmd = conn.hget(name=cmd_id, key="cmd")
         if cmd and (conn.hget(cmd_id, "status") == b"queued"):
-            result = run_cmd(cmd.decode("utf-8"))
+            update_status(conn, cmd_id, "running")
+            result = run_cmd(
+                cmd.decode("utf-8"), env=conn.hget(name=cmd_id, key="env"))
             conn.hset(name=cmd_id, key="result", value=result.get("output"))
             if result.get("code") == 0:  # Success
-                conn.hset(name=cmd_id, key="status", value="success")
+                update_status(conn, cmd_id, "success")
             else:
-                conn.hset(name=cmd_id, key="status", value="error")
-
+                update_status(conn, cmd_id, "error")
 
 if __name__ == "__main__":
     redis_conn = redis.Redis()