diff options
Diffstat (limited to 'gn3')
-rw-r--r-- | gn3/auth/db.py | 44 |
1 files changed, 42 insertions, 2 deletions
diff --git a/gn3/auth/db.py b/gn3/auth/db.py index e732a03..8760153 100644 --- a/gn3/auth/db.py +++ b/gn3/auth/db.py @@ -1,9 +1,49 @@ """Handle connection to auth database.""" import sqlite3 import contextlib +from typing import Any, Iterator, Protocol + +class DbConnection(Protocol): + """Type annotation for a generic database connection object.""" + def cursor(self) -> Any: + """A cursor object""" + ... + + def commit(self) -> Any: + """Commit the transaction.""" + ... + + def rollback(self) -> Any: + """Rollback the transaction.""" + ... + +class DbCursor(Protocol): + """Type annotation for a generic database cursor object.""" + def execute(self, *args, **kwargs) -> Any: + """Execute a single query""" + ... + + def executemany(self, *args, **kwargs) -> Any: + """ + Execute parameterized SQL statement sql against all parameter sequences + or mappings found in the sequence parameters. + """ + ... + + def fetchone(self, *args, **kwargs): + """Fetch single result if present, or `None`.""" + ... + + def fetchmany(self, *args, **kwargs): + """Fetch many results if present or `None`.""" + ... + + def fetchall(self, *args, **kwargs): + """Fetch all results if present or `None`.""" + ... @contextlib.contextmanager -def connection(db_path: str): +def connection(db_path: str) -> Iterator[DbConnection]: """Create the connection to the auth database.""" conn = sqlite3.connect(db_path) try: @@ -15,7 +55,7 @@ def connection(db_path: str): conn.close() @contextlib.contextmanager -def cursor(conn): +def cursor(conn: DbConnection) -> Iterator[DbCursor]: """Get a cursor from the given connection to the auth database.""" cur = conn.cursor() try: |