diff options
-rw-r--r-- | wqflask/wqflask/__init__.py | 13 | ||||
-rw-r--r-- | wqflask/wqflask/database.py | 33 |
2 files changed, 27 insertions, 19 deletions
diff --git a/wqflask/wqflask/__init__.py b/wqflask/wqflask/__init__.py index ab8b9e66..118a7ff3 100644 --- a/wqflask/wqflask/__init__.py +++ b/wqflask/wqflask/__init__.py @@ -11,6 +11,8 @@ from utility import formatting from gn3.authentication import DataRole, AdminRole +from wqflask.database import parse_db_url + from wqflask.group_manager import group_management from wqflask.resource_manager import resource_management from wqflask.metadata_edits import metadata_edit @@ -29,17 +31,6 @@ from wqflask.jupyter_notebooks import jupyter_notebooks app = Flask(__name__) -# Helper function for getting the SQL objects -def parse_db_url(sql_uri: str) -> Tuple: - """Parse SQL_URI env variable from an sql URI - e.g. 'mysql://user:pass@host_name/db_name' - - """ - parsed_db = urlparse(sql_uri) - return (parsed_db.hostname, parsed_db.username, - parsed_db.password, parsed_db.path[1:]) - - # See http://flask.pocoo.org/docs/config/#configuring-from-files # Note no longer use the badly named WQFLASK_OVERRIDES (nyi) app.config.from_envvar('GN2_SETTINGS') diff --git a/wqflask/wqflask/database.py b/wqflask/wqflask/database.py index e485bcf1..cbf01346 100644 --- a/wqflask/wqflask/database.py +++ b/wqflask/wqflask/database.py @@ -1,15 +1,27 @@ # Module to initialize sqlalchemy with flask +import os +import sys +import importlib + import MySQLdb from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.ext.declarative import declarative_base -from utility.tools import SQL_URI -from flask import current_app +def read_from_pyfile(pyfile, setting): + orig_sys_path = sys.path[:] + sys.path.insert(0, os.path.dirname(pyfile)) + module = importlib.import_module(os.path.basename(pyfile).strip(".py")) + sys.path = orig_sys_path[:] + return module.__dict__.get(setting) +def sql_uri(): + """Read the SQL_URI from the environment or settings file.""" + return os.environ.get( + "SQL_URI", read_from_pyfile(os.environ.get("GN2_SETTINGS"), "SQL_URI")) -engine = create_engine(SQL_URI, encoding="latin1") +engine = create_engine(sql_uri(), encoding="latin1") db_session = scoped_session(sessionmaker(autocommit=False, autoflush=False, @@ -20,11 +32,16 @@ Base.query = db_session.query_property() # Initialise the db Base.metadata.create_all(bind=engine) +def parse_db_url(sql_uri: str) -> Tuple: + """ + Parse SQL_URI env variable from an sql URI + e.g. 'mysql://user:pass@host_name/db_name' + """ + parsed_db = urlparse(sql_uri) + return (parsed_db.hostname, parsed_db.username, + parsed_db.password, parsed_db.path[1:]) def database_connection(): """Returns a database connection""" - return MySQLdb.Connect( - db=current_app.config.get("DB_NAME"), - user=current_app.config.get("DB_USER"), - passwd=current_app.config.get("DB_PASS"), - host=current_app.config.get("DB_HOST")) + host, user, passwd, db_name = parse_db_url(sql_uri()) + return MySQLdb.Connect(db=db_name, user=user, passwd=passwd, host=host) |