aboutsummaryrefslogtreecommitdiff
path: root/gn3/db_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/db_utils.py')
-rw-r--r--gn3/db_utils.py24
1 files changed, 23 insertions, 1 deletions
diff --git a/gn3/db_utils.py b/gn3/db_utils.py
index 4865131..69e88a5 100644
--- a/gn3/db_utils.py
+++ b/gn3/db_utils.py
@@ -1,5 +1,6 @@
"""module contains all db related stuff"""
-from typing import Tuple
+import contextlib
+from typing import Any, Iterator, Protocol, Tuple
from urllib.parse import urlparse
import MySQLdb as mdb
from gn3.settings import SQL_URI
@@ -15,8 +16,29 @@ def parse_db_url() -> Tuple:
parsed_db.path[1:], parsed_db.port)
+# This function is deprecated. Use database_connection instead.
def database_connector() -> mdb.Connection:
"""function to create db connector"""
host, user, passwd, db_name, db_port = parse_db_url()
return mdb.connect(host, user, passwd, db_name, port=(db_port or 3306))
+
+
+class Connection(Protocol):
+ def cursor(self) -> Any:
+ ...
+
+
+@contextlib.contextmanager
+def database_connection() -> Iterator[Connection]:
+ """Connect to MySQL database."""
+ host, user, passwd, db_name, port = parse_db_url()
+ connection = mdb.connect(db=db_name,
+ user=user,
+ passwd=passwd or '',
+ host=host,
+ port=port or 3306)
+ try:
+ yield connection
+ finally:
+ connection.close()