aboutsummaryrefslogtreecommitdiff
path: root/wqflask
diff options
context:
space:
mode:
authorMunyoki Kilyungi2022-09-08 11:19:33 +0300
committerBonfaceKilz2022-09-08 14:26:19 +0300
commit80550706fb28ca197208d428e48ead02285e8499 (patch)
tree9f621aea741923d27ae99fd330236d7c08b001c1 /wqflask
parente5904de6a569eb58ead73ccb4329b0088896c39e (diff)
downloadgenenetwork2-80550706fb28ca197208d428e48ead02285e8499.tar.gz
Add type-hints to database_connection
* wqflask/wqflask/database.py: Import Protocol, Any and Iterator. (Connection): New protocol class for type-hints. (read_from_pyfile): Add type-hints. (sql_uri): Ditto. (database_connection): Ditto.
Diffstat (limited to 'wqflask')
-rw-r--r--wqflask/wqflask/database.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/wqflask/wqflask/database.py b/wqflask/wqflask/database.py
index ec616d07..2fd34078 100644
--- a/wqflask/wqflask/database.py
+++ b/wqflask/wqflask/database.py
@@ -1,15 +1,21 @@
# Module to initialize sqlalchemy with flask
import os
import sys
-from typing import Tuple
+from typing import Tuple, Protocol, Any, Iterator
from urllib.parse import urlparse
import importlib
import contextlib
+#: type: ignore
import MySQLdb
-def read_from_pyfile(pyfile, setting):
+class Connection(Protocol):
+ def cursor(self) -> Any:
+ ...
+
+
+def read_from_pyfile(pyfile: str, setting: str) -> Any:
orig_sys_path = sys.path[:]
sys.path.insert(0, os.path.dirname(pyfile))
module = importlib.import_module(os.path.basename(pyfile).strip(".py"))
@@ -17,7 +23,7 @@ def read_from_pyfile(pyfile, setting):
return module.__dict__.get(setting)
-def sql_uri():
+def sql_uri() -> str:
"""Read the SQL_URI from the environment or settings file."""
return os.environ.get(
"SQL_URI", read_from_pyfile(
@@ -25,6 +31,7 @@ def sql_uri():
"GN2_SETTINGS", os.path.abspath("../etc/default_settings.py")),
"SQL_URI"))
+
def parse_db_url(sql_uri: str) -> Tuple:
"""
Parse SQL_URI env variable from an sql URI
@@ -37,7 +44,7 @@ def parse_db_url(sql_uri: str) -> Tuple:
@contextlib.contextmanager
-def database_connection():
+def database_connection() -> Iterator[Connection]:
"""Provide a context manager for opening, closing, and rolling
back - if supported - a database connection. Should an error occur,
and if the table supports transactions, the connection will be