From acc97e523ecdc6938b5744d4de4781e5d65f6d25 Mon Sep 17 00:00:00 2001 From: Arun Isaac Date: Tue, 18 Oct 2022 16:25:36 +0530 Subject: Add database connection context manager. * gn3/db_utils.py: Import contextlib. Import Any, Iterator, Protocol and Tuple from typing. (database_connector): Deprecate function. (Connection): New class. (database_connection): New function. --- gn3/db_utils.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/gn3/db_utils.py b/gn3/db_utils.py index 4865131..69e88a5 100644 --- a/gn3/db_utils.py +++ b/gn3/db_utils.py @@ -1,5 +1,6 @@ """module contains all db related stuff""" -from typing import Tuple +import contextlib +from typing import Any, Iterator, Protocol, Tuple from urllib.parse import urlparse import MySQLdb as mdb from gn3.settings import SQL_URI @@ -15,8 +16,29 @@ def parse_db_url() -> Tuple: parsed_db.path[1:], parsed_db.port) +# This function is deprecated. Use database_connection instead. def database_connector() -> mdb.Connection: """function to create db connector""" host, user, passwd, db_name, db_port = parse_db_url() return mdb.connect(host, user, passwd, db_name, port=(db_port or 3306)) + + +class Connection(Protocol): + def cursor(self) -> Any: + ... + + +@contextlib.contextmanager +def database_connection() -> Iterator[Connection]: + """Connect to MySQL database.""" + host, user, passwd, db_name, port = parse_db_url() + connection = mdb.connect(db=db_name, + user=user, + passwd=passwd or '', + host=host, + port=port or 3306) + try: + yield connection + finally: + connection.close() -- cgit v1.2.3