aboutsummaryrefslogtreecommitdiff
path: root/gn3
diff options
context:
space:
mode:
authorArun Isaac2022-10-18 16:25:36 +0530
committerArun Isaac2022-10-18 16:26:43 +0530
commitacc97e523ecdc6938b5744d4de4781e5d65f6d25 (patch)
tree9767b1e9fb0f6bdd495d37620384253d40adc21c /gn3
parentd7aa54361c1f7e6acb222f3ff65576d14eaff916 (diff)
downloadgenenetwork3-acc97e523ecdc6938b5744d4de4781e5d65f6d25.tar.gz
Add database connection context manager.
* gn3/db_utils.py: Import contextlib. Import Any, Iterator, Protocol and Tuple from typing. (database_connector): Deprecate function. (Connection): New class. (database_connection): New function.
Diffstat (limited to 'gn3')
-rw-r--r--gn3/db_utils.py24
1 files changed, 23 insertions, 1 deletions
diff --git a/gn3/db_utils.py b/gn3/db_utils.py
index 4865131..69e88a5 100644
--- a/gn3/db_utils.py
+++ b/gn3/db_utils.py
@@ -1,5 +1,6 @@
"""module contains all db related stuff"""
-from typing import Tuple
+import contextlib
+from typing import Any, Iterator, Protocol, Tuple
from urllib.parse import urlparse
import MySQLdb as mdb
from gn3.settings import SQL_URI
@@ -15,8 +16,29 @@ def parse_db_url() -> Tuple:
parsed_db.path[1:], parsed_db.port)
+# This function is deprecated. Use database_connection instead.
def database_connector() -> mdb.Connection:
"""function to create db connector"""
host, user, passwd, db_name, db_port = parse_db_url()
return mdb.connect(host, user, passwd, db_name, port=(db_port or 3306))
+
+
+class Connection(Protocol):
+ def cursor(self) -> Any:
+ ...
+
+
+@contextlib.contextmanager
+def database_connection() -> Iterator[Connection]:
+ """Connect to MySQL database."""
+ host, user, passwd, db_name, port = parse_db_url()
+ connection = mdb.connect(db=db_name,
+ user=user,
+ passwd=passwd or '',
+ host=host,
+ port=port or 3306)
+ try:
+ yield connection
+ finally:
+ connection.close()