about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--qc_app/__init__.py5
-rw-r--r--qc_app/check_connections.py28
-rw-r--r--qc_app/db_utils.py10
-rw-r--r--scripts/worker.py7
4 files changed, 40 insertions, 10 deletions
diff --git a/qc_app/__init__.py b/qc_app/__init__.py
index 6b760b9..eefe236 100644
--- a/qc_app/__init__.py
+++ b/qc_app/__init__.py
@@ -7,6 +7,7 @@ from flask import Flask
 from .entry import entrybp
 from .parse import parsebp
 from .dbinsert import dbinsertbp
+from .check_connections import check_db, check_redis
 
 def instance_path():
     """Retrieve the `instance_path`. Raise an exception if not defined."""
@@ -25,6 +26,10 @@ def create_app(instance_dir):
     app.config.from_pyfile(os.path.join(os.getcwd(), "etc/default_config.py"))
     app.config.from_pyfile("config.py") # Override defaults with instance path
 
+    # Check the connection
+    check_db(app.config["SQL_URI"])
+    check_redis(app.config["REDIS_URL"])
+
     # setup blueprints
     app.register_blueprint(entrybp, url_prefix="/")
     app.register_blueprint(parsebp, url_prefix="/parse")
diff --git a/qc_app/check_connections.py b/qc_app/check_connections.py
new file mode 100644
index 0000000..ceccc32
--- /dev/null
+++ b/qc_app/check_connections.py
@@ -0,0 +1,28 @@
+"""Check the various connection used in the application"""
+import sys
+import traceback
+
+import redis
+import MySQLdb
+
+from qc_app.db_utils import database_connection
+
+def check_redis(uri: str):
+    "Check the redis connection"
+    try:
+        with redis.Redis.from_url(uri) as rconn:
+            rconn.ping()
+    except redis.exceptions.ConnectionError as conn_err:
+        print(conn_err, file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+        sys.exit(1)
+
+def check_db(uri: str):
+    "Check the mysql connection"
+    try:
+        with database_connection(uri) as dbconn: # pylint: disable=[unused-variable]
+            pass
+    except MySQLdb.OperationalError as op_err:
+        print(op_err, file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+        sys.exit(1)
diff --git a/qc_app/db_utils.py b/qc_app/db_utils.py
index 95b0057..239f45c 100644
--- a/qc_app/db_utils.py
+++ b/qc_app/db_utils.py
@@ -1,20 +1,20 @@
 """module contains all db related stuff"""
-from typing import Tuple
+from typing import Tuple, Optional
 
 from urllib.parse import urlparse
 import MySQLdb as mdb
 from flask import current_app as app
 
-def parse_db_url() -> Tuple:
+def parse_db_url(db_url) -> Tuple:
     """
     Parse SQL_URI configuration variable.
     """
-    parsed_db = urlparse(app.config["SQL_URI"])
+    parsed_db = urlparse(db_url)
     return (parsed_db.hostname, parsed_db.username,
             parsed_db.password, parsed_db.path[1:])
 
 
-def database_connection() -> mdb.Connection:
+def database_connection(db_url: Optional[str] = None) -> mdb.Connection:
     """function to create db connector"""
-    host, user, passwd, db_name = parse_db_url()
+    host, user, passwd, db_name = parse_db_url(db_url or app.config["SQL_URI"])
     return mdb.connect(host, user, passwd, db_name)
diff --git a/scripts/worker.py b/scripts/worker.py
index 4077ad1..fee4ec8 100644
--- a/scripts/worker.py
+++ b/scripts/worker.py
@@ -10,6 +10,7 @@ from tempfile import TemporaryFile
 from redis import Redis
 
 from qc_app import jobs
+from qc_app.check_connections import check_redis
 
 def parse_args():
     "Parse the command-line arguments"
@@ -21,11 +22,7 @@ def parse_args():
     parser.add_argument("job_id", help="The id of the job being processed")
 
     args = parser.parse_args()
-    try:
-        conn = Redis.from_url(args.redisurl) # pylint: disable=[unused-variable]
-    except ConnectionError as conn_err: # pylint: disable=[unused-variable]
-        print(traceback.format_exc(), file=sys.stderr)
-        sys.exit(1)
+    check_redis(args.redisurl)
 
     return args