diff options
Diffstat (limited to 'gn3/db_utils.py')
-rw-r--r-- | gn3/db_utils.py | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/gn3/db_utils.py b/gn3/db_utils.py index 4865131..69e88a5 100644 --- a/gn3/db_utils.py +++ b/gn3/db_utils.py @@ -1,5 +1,6 @@ """module contains all db related stuff""" -from typing import Tuple +import contextlib +from typing import Any, Iterator, Protocol, Tuple from urllib.parse import urlparse import MySQLdb as mdb from gn3.settings import SQL_URI @@ -15,8 +16,29 @@ def parse_db_url() -> Tuple: parsed_db.path[1:], parsed_db.port) +# This function is deprecated. Use database_connection instead. def database_connector() -> mdb.Connection: """function to create db connector""" host, user, passwd, db_name, db_port = parse_db_url() return mdb.connect(host, user, passwd, db_name, port=(db_port or 3306)) + + +class Connection(Protocol): + def cursor(self) -> Any: + ... + + +@contextlib.contextmanager +def database_connection() -> Iterator[Connection]: + """Connect to MySQL database.""" + host, user, passwd, db_name, port = parse_db_url() + connection = mdb.connect(db=db_name, + user=user, + passwd=passwd or '', + host=host, + port=port or 3306) + try: + yield connection + finally: + connection.close() |