aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/database.py26
1 files changed, 23 insertions, 3 deletions
diff --git a/wqflask/wqflask/database.py b/wqflask/wqflask/database.py
index d8cc5de9..2063190c 100644
--- a/wqflask/wqflask/database.py
+++ b/wqflask/wqflask/database.py
@@ -5,6 +5,7 @@ from string import Template
from typing import Tuple
from urllib.parse import urlparse
import importlib
+import contextlib
import MySQLdb
@@ -27,8 +28,27 @@ def parse_db_url(sql_uri: str) -> Tuple:
parsed_db.hostname, parsed_db.username, parsed_db.password,
parsed_db.path[1:], parsed_db.port)
+
+@contextlib.contextmanager
def database_connection():
- """Returns a database connection"""
+ """Provide a context manager for opening, closing, and rolling
+ back - if supported - a database connection. Should an error occur,
+ and if the table supports transactions, the connection will be
+ rolled back.
+
+ """
host, user, passwd, db_name, port = parse_db_url(sql_uri())
- return MySQLdb.connect(
- db=db_name, user=user, passwd=passwd, host=host, port=port)
+ connection = MySQLdb.connect(
+ db=db_name, user=user, passwd=passwd, host=host, port=port,
+ autocommit=False # Required for roll-backs
+ )
+ try:
+ yield connection
+ connection.close()
+ except Exception:
+ connection.rollback()
+ raise
+ else:
+ connection.commit()
+ finally:
+ connection.close()