From 86b81be8c67cf82e8f26a7f84cbdc5df7198e214 Mon Sep 17 00:00:00 2001 From: John Nduli Date: Fri, 18 Oct 2024 15:06:36 +0300 Subject: refactor: replace gn3.auth.db with gn3.sqlite_db_utils and drop all refs to gn3.auth --- gn3/api/llm.py | 2 +- gn3/auth/__init__.py | 1 - gn3/auth/db.py | 72 -------------------------------------------------- gn3/auth/db_utils.py | 14 ---------- gn3/auth/dictify.py | 12 --------- gn3/sqlite_db_utils.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++ tests/unit/conftest.py | 2 +- 7 files changed, 74 insertions(+), 101 deletions(-) delete mode 100644 gn3/auth/__init__.py delete mode 100644 gn3/auth/db.py delete mode 100644 gn3/auth/db_utils.py delete mode 100644 gn3/auth/dictify.py create mode 100644 gn3/sqlite_db_utils.py diff --git a/gn3/api/llm.py b/gn3/api/llm.py index 9a44440..d6cd737 100644 --- a/gn3/api/llm.py +++ b/gn3/api/llm.py @@ -12,7 +12,7 @@ from gn3.llms.process import get_gnqa from gn3.llms.errors import LLMError from gn3.oauth2.authorisation import require_token -from gn3.auth import db +from gn3 import sqlite_db_utils as db gnqa = Blueprint("gnqa", __name__) diff --git a/gn3/auth/__init__.py b/gn3/auth/__init__.py deleted file mode 100644 index d9caec9..0000000 --- a/gn3/auth/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Top-Level `Auth` module""" diff --git a/gn3/auth/db.py b/gn3/auth/db.py deleted file mode 100644 index 5cd230f..0000000 --- a/gn3/auth/db.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Handle connection to auth database.""" -import sqlite3 -import logging -import contextlib -from typing import Any, Callable, Iterator, Protocol - -import traceback - -class DbConnection(Protocol): - """Type annotation for a generic database connection object.""" - def cursor(self) -> Any: - """A cursor object""" - - def commit(self) -> Any: - """Commit the transaction.""" - - def rollback(self) -> Any: - """Rollback the transaction.""" - - -class DbCursor(Protocol): - """Type annotation for a generic database cursor object.""" - def execute(self, *args, **kwargs) -> Any: - """Execute a single query""" - - def executemany(self, *args, **kwargs) -> Any: - """ - Execute parameterized SQL statement sql against all parameter sequences - or mappings found in the sequence parameters. - """ - - def fetchone(self, *args, **kwargs): - """Fetch single result if present, or `None`.""" - - def fetchmany(self, *args, **kwargs): - """Fetch many results if present or `None`.""" - - def fetchall(self, *args, **kwargs): - """Fetch all results if present or `None`.""" - - -@contextlib.contextmanager -def connection(db_path: str, row_factory: Callable = sqlite3.Row) -> Iterator[DbConnection]: - """Create the connection to the auth database.""" - logging.debug("SQLite3 DB Path: '%s'.", db_path) - conn = sqlite3.connect(db_path) - conn.row_factory = row_factory - conn.set_trace_callback(logging.debug) - conn.execute("PRAGMA foreign_keys = ON") - try: - yield conn - except sqlite3.Error as exc: - conn.rollback() - logging.debug(traceback.format_exc()) - raise exc - finally: - conn.commit() - conn.close() - -@contextlib.contextmanager -def cursor(conn: DbConnection) -> Iterator[DbCursor]: - """Get a cursor from the given connection to the auth database.""" - cur = conn.cursor() - try: - yield cur - except sqlite3.Error as exc: - conn.rollback() - logging.debug(traceback.format_exc()) - raise exc - finally: - conn.commit() - cur.close() diff --git a/gn3/auth/db_utils.py b/gn3/auth/db_utils.py deleted file mode 100644 index c06b026..0000000 --- a/gn3/auth/db_utils.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Some common auth db utilities""" -from typing import Any, Callable -from flask import current_app - -from . import db - -def with_db_connection(func: Callable[[db.DbConnection], Any]) -> Any: - """ - Takes a function of one argument `func`, whose one argument is a database - connection. - """ - db_uri = current_app.config["AUTH_DB"] - with db.connection(db_uri) as conn: - return func(conn) diff --git a/gn3/auth/dictify.py b/gn3/auth/dictify.py deleted file mode 100644 index f9337f6..0000000 --- a/gn3/auth/dictify.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Module for dictifying objects""" - -from typing import Any, Protocol - -class Dictifiable(Protocol):# pylint: disable=[too-few-public-methods] - """Type annotation for generic object with a `dictify` method.""" - def dictify(self): - """Convert the object to a dict""" - -def dictify(obj: Dictifiable) -> dict[str, Any]: - """Turn `obj` to a dict representation.""" - return obj.dictify() diff --git a/gn3/sqlite_db_utils.py b/gn3/sqlite_db_utils.py new file mode 100644 index 0000000..5cd230f --- /dev/null +++ b/gn3/sqlite_db_utils.py @@ -0,0 +1,72 @@ +"""Handle connection to auth database.""" +import sqlite3 +import logging +import contextlib +from typing import Any, Callable, Iterator, Protocol + +import traceback + +class DbConnection(Protocol): + """Type annotation for a generic database connection object.""" + def cursor(self) -> Any: + """A cursor object""" + + def commit(self) -> Any: + """Commit the transaction.""" + + def rollback(self) -> Any: + """Rollback the transaction.""" + + +class DbCursor(Protocol): + """Type annotation for a generic database cursor object.""" + def execute(self, *args, **kwargs) -> Any: + """Execute a single query""" + + def executemany(self, *args, **kwargs) -> Any: + """ + Execute parameterized SQL statement sql against all parameter sequences + or mappings found in the sequence parameters. + """ + + def fetchone(self, *args, **kwargs): + """Fetch single result if present, or `None`.""" + + def fetchmany(self, *args, **kwargs): + """Fetch many results if present or `None`.""" + + def fetchall(self, *args, **kwargs): + """Fetch all results if present or `None`.""" + + +@contextlib.contextmanager +def connection(db_path: str, row_factory: Callable = sqlite3.Row) -> Iterator[DbConnection]: + """Create the connection to the auth database.""" + logging.debug("SQLite3 DB Path: '%s'.", db_path) + conn = sqlite3.connect(db_path) + conn.row_factory = row_factory + conn.set_trace_callback(logging.debug) + conn.execute("PRAGMA foreign_keys = ON") + try: + yield conn + except sqlite3.Error as exc: + conn.rollback() + logging.debug(traceback.format_exc()) + raise exc + finally: + conn.commit() + conn.close() + +@contextlib.contextmanager +def cursor(conn: DbConnection) -> Iterator[DbCursor]: + """Get a cursor from the given connection to the auth database.""" + cur = conn.cursor() + try: + yield cur + except sqlite3.Error as exc: + conn.rollback() + logging.debug(traceback.format_exc()) + raise exc + finally: + conn.commit() + cur.close() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 8005c8e..d9d5492 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -15,7 +15,7 @@ def fxtr_app(): testdb = Path(testdir).joinpath( f'testdb_{datetime.now().strftime("%Y%m%dT%H%M%S")}') app = create_app({ - "TESTING": True, "AUTH_DB": testdb, + "TESTING": True, "OAUTH2_ACCESS_TOKEN_GENERATOR": "tests.unit.auth.test_token.gen_token" }) app.testing = True -- cgit 1.4.1