about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/database.py26
1 files changed, 23 insertions, 3 deletions
diff --git a/wqflask/wqflask/database.py b/wqflask/wqflask/database.py
index d8cc5de9..2063190c 100644
--- a/wqflask/wqflask/database.py
+++ b/wqflask/wqflask/database.py
@@ -5,6 +5,7 @@ from string import Template
 from typing import Tuple
 from urllib.parse import urlparse
 import importlib
+import contextlib
 
 import MySQLdb
 
@@ -27,8 +28,27 @@ def parse_db_url(sql_uri: str) -> Tuple:
         parsed_db.hostname, parsed_db.username, parsed_db.password,
         parsed_db.path[1:], parsed_db.port)
 
+
+@contextlib.contextmanager
 def database_connection():
-    """Returns a database connection"""
+    """Provide a context manager for opening, closing, and rolling
+    back - if supported - a database connection.  Should an error occur,
+    and if the table supports transactions, the connection will be
+    rolled back.
+
+    """
     host, user, passwd, db_name, port = parse_db_url(sql_uri())
-    return MySQLdb.connect(
-        db=db_name, user=user, passwd=passwd, host=host, port=port)
+    connection = MySQLdb.connect(
+        db=db_name, user=user, passwd=passwd, host=host, port=port,
+        autocommit=False  # Required for roll-backs
+    )
+    try:
+        yield connection
+        connection.close()
+    except Exception:
+        connection.rollback()
+        raise
+    else:
+        connection.commit()
+    finally:
+        connection.close()