"""This module deals with connections to an SQLite3 database."""
import logging
import traceback
import contextlib
from typing import Callable, Iterator, Any

import sqlite3

from .protocols import DbCursor, DbConnection

_logger_ = logging.getLogger(__name__)


@contextlib.contextmanager
def connection(db_path: str, row_factory: Callable = sqlite3.Row) -> Iterator[DbConnection]:
    """Create the connection to the auth database.

    Opens a connection with ``sqlite3.connect(db_path)``, sets its
    ``row_factory``, sends every executed SQL statement to this module's
    logger at DEBUG level, and enables foreign-key enforcement.

    Transaction handling:
      * If the ``with`` body completes normally, the transaction is committed.
      * If the body (or setup, or the commit) raises any ``Exception``, the
        transaction is rolled back, the traceback is logged at DEBUG level,
        and the exception is re-raised unchanged.
      * The connection is always closed.

    Args:
        db_path: Database path passed to ``sqlite3.connect``.
        row_factory: Assigned to ``conn.row_factory``; defaults to
            ``sqlite3.Row``.

    Yields:
        The open connection.
    """
    _logger_.debug("SQLite3 DB Path: '%s'.", db_path)
    conn = sqlite3.connect(db_path)
    try:
        conn.row_factory = row_factory
        # Use the module logger rather than the module-level `logging.debug`,
        # which implicitly calls `logging.basicConfig()` on the root logger
        # when it has no handlers.
        conn.set_trace_callback(_logger_.debug)
        # Enforce foreign-key constraints on this connection.
        conn.execute("PRAGMA foreign_keys = ON")
        yield conn
        # Only reached when the `with` body finished without raising, so
        # partial work from a failed body is never committed.
        conn.commit()
    except Exception:
        # Roll back on *any* error, not only sqlite3.Error.
        conn.rollback()
        _logger_.debug(traceback.format_exc())
        raise
    finally:
        # Runs even if commit/rollback fails, so the connection never leaks.
        conn.close()


@contextlib.contextmanager
def cursor(conn: DbConnection) -> Iterator[DbCursor]:
    """Get a cursor from the given connection to the auth database.

    On normal exit of the ``with`` body, ``conn.commit()`` is called. If the
    body or the commit raises any ``Exception``, ``conn.rollback()`` is
    called, the traceback is logged at DEBUG level, and the exception is
    re-raised. The cursor is always closed; the connection is left open.

    Args:
        conn: An open connection to obtain the cursor from.

    Yields:
        A cursor created by ``conn.cursor()``.
    """
    cur = conn.cursor()
    try:
        yield cur
        conn.commit()
    except Exception:
        # Roll back on any error so the open transaction is not committed
        # later by whoever next calls commit() on this connection.
        conn.rollback()
        _logger_.debug(traceback.format_exc())
        raise
    finally:
        cur.close()


def with_db_connection(db_uri: str, func: Callable[[DbConnection], Any]) -> Any:
    """
    Call `func`, a function of one argument with the SQLite3 connection
    created from the connection string `db_uri`.

    The connection is managed by `connection`: it is committed if `func`
    returns normally, rolled back if `func` raises, and closed either way.
    Returns whatever `func` returns.
    """
    with connection(db_uri) as conn:
        return func(conn)