diff options
Diffstat (limited to 'gn_libs')
-rw-r--r-- | gn_libs/sqlite3.py | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/gn_libs/sqlite3.py b/gn_libs/sqlite3.py new file mode 100644 index 0000000..1dcdf29 --- /dev/null +++ b/gn_libs/sqlite3.py @@ -0,0 +1,44 @@ +import logging +import traceback +import contextlib +from typing import Any, Protocol, Callable, Iterator + +import sqlite3 + +from .protocols import DbCursor, DbConnection + +_logger_ = logging.getLogger(__name__) + + +@contextlib.contextmanager +def connection(db_path: str, row_factory: Callable = sqlite3.Row) -> Iterator[DbConnection]: + """Create the connection to the auth database.""" + logging.debug("SQLite3 DB Path: '%s'.", db_path) + conn = sqlite3.connect(db_path) + conn.row_factory = row_factory + conn.set_trace_callback(logging.debug) + conn.execute("PRAGMA foreign_keys = ON") + try: + yield conn + except sqlite3.Error as exc: + conn.rollback() + _logger_.debug(traceback.format_exc()) + raise exc + finally: + conn.commit() + conn.close() + + +@contextlib.contextmanager +def cursor(conn: DbConnection) -> Iterator[DbCursor]: + """Get a cursor from the given connection to the auth database.""" + cur = conn.cursor() + try: + yield cur + conn.commit() + except sqlite3.Error as exc: + conn.rollback() + _logger_.debug(traceback.format_exc()) + raise exc + finally: + cur.close() |