about summary refs log tree commit diff
"""This module deals with connections to an SQLite3 database."""
import logging
import traceback
import contextlib
from typing import Callable, Iterator, Any

import sqlite3

from .protocols import DbCursor, DbConnection

# Module-level logger, named after this module; used below for SQLite error
# tracebacks so they can be configured separately from the root logger.
_logger_ = logging.getLogger(__name__)


@contextlib.contextmanager
def connection(db_path: str, row_factory: Callable = sqlite3.Row) -> Iterator[DbConnection]:
    """
    Open a connection to the SQLite3 database at `db_path` and yield it.

    The connection uses `row_factory` for result rows, traces every SQL
    statement to this module's logger at DEBUG level, and has foreign-key
    enforcement switched on.

    When the `with` block finishes normally, the transaction is committed.
    If the block (or the commit) raises, the transaction is rolled back
    instead, so partially-applied work is never persisted, and the exception
    propagates unchanged. SQLite errors are also logged with their traceback
    at DEBUG level. The connection is always closed, even when setup,
    commit or rollback fails.
    """
    _logger_.debug("SQLite3 DB Path: '%s'.", db_path)
    conn = sqlite3.connect(db_path)
    try:
        conn.row_factory = row_factory
        # Trace through the module logger: the module-level `logging.debug`
        # helper targets the root logger and implicitly calls `basicConfig()`.
        conn.set_trace_callback(_logger_.debug)
        # SQLite disables foreign-key enforcement by default, per connection.
        conn.execute("PRAGMA foreign_keys = ON")
        yield conn
        # Commit only on success; a failure must never persist partial work.
        conn.commit()
    except sqlite3.Error:
        conn.rollback()
        _logger_.debug(traceback.format_exc())
        raise
    except BaseException:
        # Any other exception from the `with` body: discard pending changes.
        conn.rollback()
        raise
    finally:
        conn.close()


@contextlib.contextmanager
def cursor(conn: DbConnection) -> Iterator[DbCursor]:
    """
    Yield a cursor on `conn`, committing on success and rolling back on error.

    When the `with` block finishes normally, `conn.commit()` is called. If the
    block (or the commit) raises anything, `conn.rollback()` is called so the
    failed work is not swept up by a later commit on the same connection, and
    the exception is re-raised. SQLite errors are additionally logged with
    their traceback at DEBUG level. The cursor is always closed.
    """
    cur = conn.cursor()
    try:
        yield cur
        conn.commit()
    except sqlite3.Error:
        conn.rollback()
        _logger_.debug(traceback.format_exc())
        raise
    except BaseException:
        # Non-SQLite failures in the `with` body must also discard the
        # pending changes, otherwise a later commit would persist them.
        conn.rollback()
        raise
    finally:
        cur.close()


def with_db_connection(db_uri: str, func: Callable[[DbConnection], Any]) -> Any:
    """
    Run `func` against a freshly opened SQLite3 connection and return its result.

    `db_uri` is handed to `connection` (and from there to `sqlite3.connect`
    as a plain path, not as a `file:` URI). The connection is opened before
    `func` is called with it as the sole argument, and closed once `func`
    returns or raises.
    """
    with connection(db_uri) as db_conn:
        result = func(db_conn)
    return result