aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/auth/db/sqlite3.py
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2023-08-07 07:47:01 +0300
committerFrederick Muriuki Muriithi2023-08-07 09:26:12 +0300
commit6ab6d46ab4b1611ed72bdbce85cf9324ce69b305 (patch)
tree3d10ba6514e594cf3add2086c6668c891b8cedae /gn_auth/auth/db/sqlite3.py
parente5cf3178743260e5003f3a9becf025c154204ccd (diff)
downloadgn-auth-6ab6d46ab4b1611ed72bdbce85cf9324ce69b305.tar.gz
Collect db-connections function in single module.
Diffstat (limited to 'gn_auth/auth/db/sqlite3.py')
-rw-r--r--gn_auth/auth/db/sqlite3.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/gn_auth/auth/db/sqlite3.py b/gn_auth/auth/db/sqlite3.py
new file mode 100644
index 0000000..3d94832
--- /dev/null
+++ b/gn_auth/auth/db/sqlite3.py
@@ -0,0 +1,50 @@
+"""Handle connection to auth database."""
+import sqlite3
+import logging
+import contextlib
+from typing import Any, Callable, Iterator
+
+import traceback
+
+from .protocols import DbCursor, DbConnection
+
+@contextlib.contextmanager
+def connection(db_path: str, row_factory: Callable = sqlite3.Row) -> Iterator[DbConnection]:
+ """Create the connection to the auth database."""
+ logging.debug("SQLite3 DB Path: '%s'.", db_path)
+ conn = sqlite3.connect(db_path)
+ conn.row_factory = row_factory
+ conn.set_trace_callback(logging.debug)
+ conn.execute("PRAGMA foreign_keys = ON")
+ try:
+ yield conn
+ except sqlite3.Error as exc:
+ conn.rollback()
+ logging.debug(traceback.format_exc())
+ raise exc
+ finally:
+ conn.commit()
+ conn.close()
+
+@contextlib.contextmanager
+def cursor(conn: DbConnection) -> Iterator[DbCursor]:
+ """Get a cursor from the given connection to the auth database."""
+ cur = conn.cursor()
+ try:
+ yield cur
+ except sqlite3.Error as exc:
+ conn.rollback()
+ logging.debug(traceback.format_exc())
+ raise exc
+ finally:
+ conn.commit()
+ cur.close()
+
+def with_db_connection(func: Callable[[DbConnection], Any]) -> Any:
+ """
+ Takes a function of one argument `func`, whose one argument is a database
+ connection.
+ """
+ db_uri = current_app.config["AUTH_DB"]
+ with connection(db_uri) as conn:
+ return func(conn)