about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-11-15 06:21:38 +0300
committerFrederick Muriuki Muriithi2022-11-15 06:21:38 +0300
commit3dc9a8a4f413d142e84a81f9c1abafedb779d7dd (patch)
tree92e39a964446dcda1d2871ee5f4bdbbc886925f2
parentb1ee0958815cbb7265d2c5ea3a8374b532054f3b (diff)
downloadgenenetwork3-3dc9a8a4f413d142e84a81f9c1abafedb779d7dd.tar.gz
auth: Add some typing information to the functions
-rw-r--r--gn3/auth/db.py44
1 files changed, 42 insertions, 2 deletions
diff --git a/gn3/auth/db.py b/gn3/auth/db.py
index e732a03..8760153 100644
--- a/gn3/auth/db.py
+++ b/gn3/auth/db.py
@@ -1,9 +1,49 @@
 """Handle connection to auth database."""
 import sqlite3
 import contextlib
+from typing import Any, Iterator, Protocol
+
+class DbConnection(Protocol):
+    """Type annotation for a generic database connection object."""
+    def cursor(self) -> Any:
+        """A cursor object"""
+        ...
+
+    def commit(self) -> Any:
+        """Commit the transaction."""
+        ...
+
+    def rollback(self) -> Any:
+        """Rollback the transaction."""
+        ...
+
+class DbCursor(Protocol):
+    """Type annotation for a generic database cursor object."""
+    def execute(self, *args, **kwargs) -> Any:
+        """Execute a single query"""
+        ...
+
+    def executemany(self, *args, **kwargs) -> Any:
+        """
+        Execute parameterized SQL statement sql against all parameter sequences
+        or mappings found in the sequence parameters.
+        """
+        ...
+
+    def fetchone(self, *args, **kwargs):
+        """Fetch single result if present, or `None`."""
+        ...
+
+    def fetchmany(self, *args, **kwargs):
+        """Fetch many results if present or `None`."""
+        ...
+
+    def fetchall(self, *args, **kwargs):
+        """Fetch all results if present or `None`."""
+        ...
 
 @contextlib.contextmanager
-def connection(db_path: str):
+def connection(db_path: str) -> Iterator[DbConnection]:
     """Create the connection to the auth database."""
     conn = sqlite3.connect(db_path)
     try:
@@ -15,7 +55,7 @@ def connection(db_path: str):
         conn.close()
 
 @contextlib.contextmanager
-def cursor(conn):
+def cursor(conn: DbConnection) -> Iterator[DbCursor]:
     """Get a cursor from the given connection to the auth database."""
     cur = conn.cursor()
     try: