diff options
-rw-r--r-- | gn_auth/auth/db.py | 78 | ||||
-rw-r--r-- | gn_auth/auth/db/__init__.py | 1 | ||||
-rw-r--r-- | gn_auth/auth/db/mariadb.py | 26 | ||||
-rw-r--r-- | gn_auth/auth/db/protocols.py | 41 | ||||
-rw-r--r-- | gn_auth/auth/db/sqlite3.py | 50 | ||||
-rw-r--r-- | gn_auth/auth/db_utils.py | 14 |
6 files changed, 118 insertions, 92 deletions
diff --git a/gn_auth/auth/db.py b/gn_auth/auth/db.py deleted file mode 100644 index 2ba6619..0000000 --- a/gn_auth/auth/db.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Handle connection to auth database.""" -import sqlite3 -import logging -import contextlib -from typing import Any, Callable, Iterator, Protocol - -import traceback - -class DbConnection(Protocol): - """Type annotation for a generic database connection object.""" - def cursor(self) -> Any: - """A cursor object""" - ... - - def commit(self) -> Any: - """Commit the transaction.""" - ... - - def rollback(self) -> Any: - """Rollback the transaction.""" - ... - -class DbCursor(Protocol): - """Type annotation for a generic database cursor object.""" - def execute(self, *args, **kwargs) -> Any: - """Execute a single query""" - ... - - def executemany(self, *args, **kwargs) -> Any: - """ - Execute parameterized SQL statement sql against all parameter sequences - or mappings found in the sequence parameters. - """ - ... - - def fetchone(self, *args, **kwargs): - """Fetch single result if present, or `None`.""" - ... - - def fetchmany(self, *args, **kwargs): - """Fetch many results if present or `None`.""" - ... - - def fetchall(self, *args, **kwargs): - """Fetch all results if present or `None`.""" - ... - -@contextlib.contextmanager -def connection(db_path: str, row_factory: Callable = sqlite3.Row) -> Iterator[DbConnection]: - """Create the connection to the auth database.""" - logging.debug("SQLite3 DB Path: '%s'.", db_path) - conn = sqlite3.connect(db_path) - conn.row_factory = row_factory - conn.set_trace_callback(logging.debug) - conn.execute("PRAGMA foreign_keys = ON") - try: - yield conn - except sqlite3.Error as exc: - conn.rollback() - logging.debug(traceback.format_exc()) - raise exc - finally: - conn.commit() - conn.close() - -@contextlib.contextmanager -def cursor(conn: DbConnection) -> Iterator[DbCursor]: - """Get a cursor from the given connection to the auth database.""" - cur = conn.cursor() - try: - yield cur - except sqlite3.Error as exc: - conn.rollback() - logging.debug(traceback.format_exc()) - raise exc - finally: - conn.commit() - cur.close() diff --git a/gn_auth/auth/db/__init__.py b/gn_auth/auth/db/__init__.py new file mode 100644 index 0000000..eab58ef --- /dev/null +++ b/gn_auth/auth/db/__init__.py @@ -0,0 +1 @@ +from .protocols import DbCursor, DbConnection diff --git a/gn_auth/auth/db/mariadb.py b/gn_auth/auth/db/mariadb.py new file mode 100644 index 0000000..a934fd9 --- /dev/null +++ b/gn_auth/auth/db/mariadb.py @@ -0,0 +1,26 @@ +"""Connections to MariaDB""" +import traceback +import contextlib +from typing import Iterator + +import MySQLdb as mdb + +from .protocols import DbConnection + +@contextlib.contextmanager +def database_connection(sql_uri) -> Iterator[DbConnection]: + """Connect to MySQL database.""" + host, user, passwd, db_name, port = parse_db_url(sql_uri) + connection = mdb.connect(db=db_name, + user=user, + passwd=passwd or '', + host=host, + port=port or 3306) + try: + yield connection + except Exception as _exc: # TODO: Make the Exception class less general + logging.debug(traceback.format_exc()) + connection.rollback() + finally: + connection.commit() + connection.close() diff --git a/gn_auth/auth/db/protocols.py b/gn_auth/auth/db/protocols.py new file mode 100644 index 0000000..c089cfe --- /dev/null +++ b/gn_auth/auth/db/protocols.py @@ -0,0 +1,41 @@ +"""Common Database connection protocols.""" +from typing import Any, Protocol + +class DbConnection(Protocol): + """Type annotation for a generic database connection object.""" + def cursor(self, *args, **kwargs) -> Any: + """A cursor object""" + ... + + def commit(self) -> Any: + """Commit the transaction.""" + ... + + def rollback(self) -> Any: + """Rollback the transaction.""" + ... + +class DbCursor(Protocol): + """Type annotation for a generic database cursor object.""" + def execute(self, *args, **kwargs) -> Any: + """Execute a single query""" + ... + + def executemany(self, *args, **kwargs) -> Any: + """ + Execute parameterized SQL statement sql against all parameter sequences + or mappings found in the sequence parameters. + """ + ... + + def fetchone(self, *args, **kwargs): + """Fetch single result if present, or `None`.""" + ... + + def fetchmany(self, *args, **kwargs): + """Fetch many results if present or `None`.""" + ... + + def fetchall(self, *args, **kwargs): + """Fetch all results if present or `None`.""" + ... diff --git a/gn_auth/auth/db/sqlite3.py b/gn_auth/auth/db/sqlite3.py new file mode 100644 index 0000000..3d94832 --- /dev/null +++ b/gn_auth/auth/db/sqlite3.py @@ -0,0 +1,50 @@ +"""Handle connection to auth database.""" +import sqlite3 +import logging +import contextlib +from typing import Any, Callable, Iterator + +import traceback + +from .protocols import DbCursor, DbConnection + +@contextlib.contextmanager +def connection(db_path: str, row_factory: Callable = sqlite3.Row) -> Iterator[DbConnection]: + """Create the connection to the auth database.""" + logging.debug("SQLite3 DB Path: '%s'.", db_path) + conn = sqlite3.connect(db_path) + conn.row_factory = row_factory + conn.set_trace_callback(logging.debug) + conn.execute("PRAGMA foreign_keys = ON") + try: + yield conn + except sqlite3.Error as exc: + conn.rollback() + logging.debug(traceback.format_exc()) + raise exc + finally: + conn.commit() + conn.close() + +@contextlib.contextmanager +def cursor(conn: DbConnection) -> Iterator[DbCursor]: + """Get a cursor from the given connection to the auth database.""" + cur = conn.cursor() + try: + yield cur + except sqlite3.Error as exc: + conn.rollback() + logging.debug(traceback.format_exc()) + raise exc + finally: + conn.commit() + cur.close() + +def with_db_connection(func: Callable[[DbConnection], Any]) -> Any: + """ + Takes a function of one argument `func`, whose one argument is a database + connection. + """ + db_uri = current_app.config["AUTH_DB"] + with connection(db_uri) as conn: + return func(conn) diff --git a/gn_auth/auth/db_utils.py b/gn_auth/auth/db_utils.py deleted file mode 100644 index c06b026..0000000 --- a/gn_auth/auth/db_utils.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Some common auth db utilities""" -from typing import Any, Callable -from flask import current_app - -from . import db - -def with_db_connection(func: Callable[[db.DbConnection], Any]) -> Any: - """ - Takes a function of one argument `func`, whose one argument is a database - connection. - """ - db_uri = current_app.config["AUTH_DB"] - with db.connection(db_uri) as conn: - return func(conn) |