aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gn3/auth/db.py44
1 files changed, 42 insertions, 2 deletions
diff --git a/gn3/auth/db.py b/gn3/auth/db.py
index e732a03..8760153 100644
--- a/gn3/auth/db.py
+++ b/gn3/auth/db.py
@@ -1,9 +1,49 @@
"""Handle connection to auth database."""
import sqlite3
import contextlib
+from typing import Any, Iterator, Protocol
+
+class DbConnection(Protocol):
+ """Type annotation for a generic database connection object."""
+ def cursor(self) -> Any:
+ """A cursor object"""
+ ...
+
+ def commit(self) -> Any:
+ """Commit the transaction."""
+ ...
+
+ def rollback(self) -> Any:
+ """Rollback the transaction."""
+ ...
+
+class DbCursor(Protocol):
+ """Type annotation for a generic database cursor object."""
+ def execute(self, *args, **kwargs) -> Any:
+ """Execute a single query"""
+ ...
+
+ def executemany(self, *args, **kwargs) -> Any:
+ """
+ Execute parameterized SQL statement sql against all parameter sequences
+ or mappings found in the sequence parameters.
+ """
+ ...
+
+ def fetchone(self, *args, **kwargs):
+ """Fetch single result if present, or `None`."""
+ ...
+
+ def fetchmany(self, *args, **kwargs):
+ """Fetch many results if present or `None`."""
+ ...
+
+ def fetchall(self, *args, **kwargs):
+ """Fetch all results if present or `None`."""
+ ...
@contextlib.contextmanager
-def connection(db_path: str):
+def connection(db_path: str) -> Iterator[DbConnection]:
"""Create the connection to the auth database."""
conn = sqlite3.connect(db_path)
try:
@@ -15,7 +55,7 @@ def connection(db_path: str):
conn.close()
@contextlib.contextmanager
-def cursor(conn):
+def cursor(conn: DbConnection) -> Iterator[DbCursor]:
"""Get a cursor from the given connection to the auth database."""
cur = conn.cursor()
try: