diff options
Diffstat (limited to 'gn2/wqflask/database.py')
-rw-r--r-- | gn2/wqflask/database.py | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/gn2/wqflask/database.py b/gn2/wqflask/database.py new file mode 100644 index 00000000..331ad380 --- /dev/null +++ b/gn2/wqflask/database.py @@ -0,0 +1,52 @@ +# Module to initialize sqlalchemy with flask +import os +import sys +import logging +import traceback +from typing import Tuple, Protocol, Any, Iterator +from urllib.parse import urlparse +import importlib +import contextlib + +#: type: ignore +import MySQLdb + + +class Connection(Protocol): + def cursor(self) -> Any: + ... + +def parse_db_url(sql_uri: str) -> Tuple: + """ + Parse SQL_URI env variable from an sql URI + e.g. 'mysql://user:pass@host_name/db_name' + """ + parsed_db = urlparse(sql_uri) + return ( + parsed_db.hostname, parsed_db.username, parsed_db.password, + parsed_db.path[1:], parsed_db.port) + + +@contextlib.contextmanager +def database_connection(sql_uri: str) -> Iterator[Connection]: + """Provide a context manager for opening, closing, and rolling + back - if supported - a database connection. Should an error occur, + and if the table supports transactions, the connection will be + rolled back. + + """ + host, user, passwd, db_name, port = parse_db_url(sql_uri) + connection = MySQLdb.connect( + db=db_name, user=user, passwd=passwd or '', host=host, + port=(port or 3306), autocommit=False # Required for roll-backs + ) + try: + yield connection + connection.commit() + except Exception as _exc: + logging.error("===== Query Error =====\r\n%s\r\n===== END: Query Error", + traceback.format_exc()) + connection.rollback() + raise _exc + finally: + connection.close() |