diff options
-rw-r--r-- | wqflask/wqflask/database.py | 26 |
1 files changed, 23 insertions, 3 deletions
diff --git a/wqflask/wqflask/database.py b/wqflask/wqflask/database.py index d8cc5de9..2063190c 100644 --- a/wqflask/wqflask/database.py +++ b/wqflask/wqflask/database.py @@ -5,6 +5,7 @@ from string import Template from typing import Tuple from urllib.parse import urlparse import importlib +import contextlib import MySQLdb @@ -27,8 +28,27 @@ def parse_db_url(sql_uri: str) -> Tuple: parsed_db.hostname, parsed_db.username, parsed_db.password, parsed_db.path[1:], parsed_db.port) + +@contextlib.contextmanager def database_connection(): - """Returns a database connection""" + """Provide a context manager for opening, closing, and rolling + back - if supported - a database connection. Should an error occur, + and if the table supports transactions, the connection will be + rolled back. + + """ host, user, passwd, db_name, port = parse_db_url(sql_uri()) - return MySQLdb.connect( - db=db_name, user=user, passwd=passwd, host=host, port=port) + connection = MySQLdb.connect( + db=db_name, user=user, passwd=passwd, host=host, port=port, + autocommit=False # Required for roll-backs + ) + try: + yield connection + connection.close() + except Exception: + connection.rollback() + raise + else: + connection.commit() + finally: + connection.close() |