about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2023-08-07 07:47:01 +0300
committerFrederick Muriuki Muriithi2023-08-07 09:26:12 +0300
commit6ab6d46ab4b1611ed72bdbce85cf9324ce69b305 (patch)
tree3d10ba6514e594cf3add2086c6668c891b8cedae
parente5cf3178743260e5003f3a9becf025c154204ccd (diff)
downloadgn-auth-6ab6d46ab4b1611ed72bdbce85cf9324ce69b305.tar.gz
Collect db-connections function in single module.
-rw-r--r--gn_auth/auth/db.py78
-rw-r--r--gn_auth/auth/db/__init__.py1
-rw-r--r--gn_auth/auth/db/mariadb.py26
-rw-r--r--gn_auth/auth/db/protocols.py41
-rw-r--r--gn_auth/auth/db/sqlite3.py50
-rw-r--r--gn_auth/auth/db_utils.py14
6 files changed, 118 insertions, 92 deletions
diff --git a/gn_auth/auth/db.py b/gn_auth/auth/db.py
deleted file mode 100644
index 2ba6619..0000000
--- a/gn_auth/auth/db.py
+++ /dev/null
@@ -1,78 +0,0 @@
-"""Handle connection to auth database."""
-import sqlite3
-import logging
-import contextlib
-from typing import Any, Callable, Iterator, Protocol
-
-import traceback
-
-class DbConnection(Protocol):
-    """Type annotation for a generic database connection object."""
-    def cursor(self) -> Any:
-        """A cursor object"""
-        ...
-
-    def commit(self) -> Any:
-        """Commit the transaction."""
-        ...
-
-    def rollback(self) -> Any:
-        """Rollback the transaction."""
-        ...
-
-class DbCursor(Protocol):
-    """Type annotation for a generic database cursor object."""
-    def execute(self, *args, **kwargs) -> Any:
-        """Execute a single query"""
-        ...
-
-    def executemany(self, *args, **kwargs) -> Any:
-        """
-        Execute parameterized SQL statement sql against all parameter sequences
-        or mappings found in the sequence parameters.
-        """
-        ...
-
-    def fetchone(self, *args, **kwargs):
-        """Fetch single result if present, or `None`."""
-        ...
-
-    def fetchmany(self, *args, **kwargs):
-        """Fetch many results if present or `None`."""
-        ...
-
-    def fetchall(self, *args, **kwargs):
-        """Fetch all results if present or `None`."""
-        ...
-
-@contextlib.contextmanager
-def connection(db_path: str, row_factory: Callable = sqlite3.Row) -> Iterator[DbConnection]:
-    """Create the connection to the auth database."""
-    logging.debug("SQLite3 DB Path: '%s'.", db_path)
-    conn = sqlite3.connect(db_path)
-    conn.row_factory = row_factory
-    conn.set_trace_callback(logging.debug)
-    conn.execute("PRAGMA foreign_keys = ON")
-    try:
-        yield conn
-    except sqlite3.Error as exc:
-        conn.rollback()
-        logging.debug(traceback.format_exc())
-        raise exc
-    finally:
-        conn.commit()
-        conn.close()
-
-@contextlib.contextmanager
-def cursor(conn: DbConnection) -> Iterator[DbCursor]:
-    """Get a cursor from the given connection to the auth database."""
-    cur = conn.cursor()
-    try:
-        yield cur
-    except sqlite3.Error as exc:
-        conn.rollback()
-        logging.debug(traceback.format_exc())
-        raise exc
-    finally:
-        conn.commit()
-        cur.close()
diff --git a/gn_auth/auth/db/__init__.py b/gn_auth/auth/db/__init__.py
new file mode 100644
index 0000000..eab58ef
--- /dev/null
+++ b/gn_auth/auth/db/__init__.py
@@ -0,0 +1 @@
+from .protocols import DbCursor, DbConnection
diff --git a/gn_auth/auth/db/mariadb.py b/gn_auth/auth/db/mariadb.py
new file mode 100644
index 0000000..a934fd9
--- /dev/null
+++ b/gn_auth/auth/db/mariadb.py
@@ -0,0 +1,26 @@
+"""Connections to MariaDB"""
+import traceback
+import contextlib
+from typing import Iterator
+
+import MySQLdb as mdb
+
+from .protocols import DbConnection
+
+@contextlib.contextmanager
+def database_connection(sql_uri) -> Iterator[DbConnection]:
+    """Connect to MySQL database."""
+    host, user, passwd, db_name, port = parse_db_url(sql_uri)
+    connection = mdb.connect(db=db_name,
+                             user=user,
+                             passwd=passwd or '',
+                             host=host,
+                             port=port or 3306)
+    try:
+        yield connection
+    except Exception as _exc: # TODO: Make the Exception class less general
+        logging.debug(traceback.format_exc())
+        connection.rollback()
+    finally:
+        connection.commit()
+        connection.close()
diff --git a/gn_auth/auth/db/protocols.py b/gn_auth/auth/db/protocols.py
new file mode 100644
index 0000000..c089cfe
--- /dev/null
+++ b/gn_auth/auth/db/protocols.py
@@ -0,0 +1,41 @@
+"""Common Database connection protocols."""
+from typing import Any, Protocol
+
+class DbConnection(Protocol):
+    """Type annotation for a generic database connection object."""
+    def cursor(self, *args, **kwargs) -> Any:
+        """A cursor object"""
+        ...
+
+    def commit(self) -> Any:
+        """Commit the transaction."""
+        ...
+
+    def rollback(self) -> Any:
+        """Rollback the transaction."""
+        ...
+
+class DbCursor(Protocol):
+    """Type annotation for a generic database cursor object."""
+    def execute(self, *args, **kwargs) -> Any:
+        """Execute a single query"""
+        ...
+
+    def executemany(self, *args, **kwargs) -> Any:
+        """
+        Execute parameterized SQL statement sql against all parameter sequences
+        or mappings found in the sequence parameters.
+        """
+        ...
+
+    def fetchone(self, *args, **kwargs):
+        """Fetch single result if present, or `None`."""
+        ...
+
+    def fetchmany(self, *args, **kwargs):
+        """Fetch many results if present or `None`."""
+        ...
+
+    def fetchall(self, *args, **kwargs):
+        """Fetch all results if present or `None`."""
+        ...
diff --git a/gn_auth/auth/db/sqlite3.py b/gn_auth/auth/db/sqlite3.py
new file mode 100644
index 0000000..3d94832
--- /dev/null
+++ b/gn_auth/auth/db/sqlite3.py
@@ -0,0 +1,50 @@
+"""Handle connection to auth database."""
+import sqlite3
+import logging
+import contextlib
+from typing import Any, Callable, Iterator
+
+import traceback
+
+from .protocols import DbCursor, DbConnection
+
+@contextlib.contextmanager
+def connection(db_path: str, row_factory: Callable = sqlite3.Row) -> Iterator[DbConnection]:
+    """Create the connection to the auth database."""
+    logging.debug("SQLite3 DB Path: '%s'.", db_path)
+    conn = sqlite3.connect(db_path)
+    conn.row_factory = row_factory
+    conn.set_trace_callback(logging.debug)
+    conn.execute("PRAGMA foreign_keys = ON")
+    try:
+        yield conn
+    except sqlite3.Error as exc:
+        conn.rollback()
+        logging.debug(traceback.format_exc())
+        raise exc
+    finally:
+        conn.commit()
+        conn.close()
+
+@contextlib.contextmanager
+def cursor(conn: DbConnection) -> Iterator[DbCursor]:
+    """Get a cursor from the given connection to the auth database."""
+    cur = conn.cursor()
+    try:
+        yield cur
+    except sqlite3.Error as exc:
+        conn.rollback()
+        logging.debug(traceback.format_exc())
+        raise exc
+    finally:
+        conn.commit()
+        cur.close()
+
+def with_db_connection(func: Callable[[DbConnection], Any]) -> Any:
+    """
+    Takes a function of one argument `func`, whose one argument is a database
+    connection.
+    """
+    db_uri = current_app.config["AUTH_DB"]
+    with connection(db_uri) as conn:
+        return func(conn)
diff --git a/gn_auth/auth/db_utils.py b/gn_auth/auth/db_utils.py
deleted file mode 100644
index c06b026..0000000
--- a/gn_auth/auth/db_utils.py
+++ /dev/null
@@ -1,14 +0,0 @@
-"""Some common auth db utilities"""
-from typing import Any, Callable
-from flask import current_app
-
-from . import db
-
-def with_db_connection(func: Callable[[db.DbConnection], Any]) -> Any:
-    """
-    Takes a function of one argument `func`, whose one argument is a database
-    connection.
-    """
-    db_uri = current_app.config["AUTH_DB"]
-    with db.connection(db_uri) as conn:
-        return func(conn)