aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/__init__.py13
-rw-r--r--wqflask/wqflask/database.py33
2 files changed, 27 insertions, 19 deletions
diff --git a/wqflask/wqflask/__init__.py b/wqflask/wqflask/__init__.py
index ab8b9e66..118a7ff3 100644
--- a/wqflask/wqflask/__init__.py
+++ b/wqflask/wqflask/__init__.py
@@ -11,6 +11,8 @@ from utility import formatting
from gn3.authentication import DataRole, AdminRole
+from wqflask.database import parse_db_url
+
from wqflask.group_manager import group_management
from wqflask.resource_manager import resource_management
from wqflask.metadata_edits import metadata_edit
@@ -29,17 +31,6 @@ from wqflask.jupyter_notebooks import jupyter_notebooks
app = Flask(__name__)
-# Helper function for getting the SQL objects
-def parse_db_url(sql_uri: str) -> Tuple:
- """Parse SQL_URI env variable from an sql URI
- e.g. 'mysql://user:pass@host_name/db_name'
-
- """
- parsed_db = urlparse(sql_uri)
- return (parsed_db.hostname, parsed_db.username,
- parsed_db.password, parsed_db.path[1:])
-
-
# See http://flask.pocoo.org/docs/config/#configuring-from-files
# Note no longer use the badly named WQFLASK_OVERRIDES (nyi)
app.config.from_envvar('GN2_SETTINGS')
diff --git a/wqflask/wqflask/database.py b/wqflask/wqflask/database.py
index e485bcf1..cbf01346 100644
--- a/wqflask/wqflask/database.py
+++ b/wqflask/wqflask/database.py
@@ -1,15 +1,27 @@
# Module to initialize sqlalchemy with flask
+import os
+import sys
+import importlib
+
import MySQLdb
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.ext.declarative import declarative_base
-from utility.tools import SQL_URI
-from flask import current_app
+def read_from_pyfile(pyfile, setting):
+ orig_sys_path = sys.path[:]
+ sys.path.insert(0, os.path.dirname(pyfile))
+ module = importlib.import_module(os.path.basename(pyfile).strip(".py"))
+ sys.path = orig_sys_path[:]
+ return module.__dict__.get(setting)
+def sql_uri():
+ """Read the SQL_URI from the environment or settings file."""
+ return os.environ.get(
+ "SQL_URI", read_from_pyfile(os.environ.get("GN2_SETTINGS"), "SQL_URI"))
-engine = create_engine(SQL_URI, encoding="latin1")
+engine = create_engine(sql_uri(), encoding="latin1")
db_session = scoped_session(sessionmaker(autocommit=False,
autoflush=False,
@@ -20,11 +32,16 @@ Base.query = db_session.query_property()
# Initialise the db
Base.metadata.create_all(bind=engine)
+def parse_db_url(sql_uri: str) -> Tuple:
+ """
+ Parse SQL_URI env variable from an sql URI
+ e.g. 'mysql://user:pass@host_name/db_name'
+ """
+ parsed_db = urlparse(sql_uri)
+ return (parsed_db.hostname, parsed_db.username,
+ parsed_db.password, parsed_db.path[1:])
def database_connection():
"""Returns a database connection"""
- return MySQLdb.Connect(
- db=current_app.config.get("DB_NAME"),
- user=current_app.config.get("DB_USER"),
- passwd=current_app.config.get("DB_PASS"),
- host=current_app.config.get("DB_HOST"))
+ host, user, passwd, db_name = parse_db_url(sql_uri())
+ return MySQLdb.Connect(db=db_name, user=user, passwd=passwd, host=host)