aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-03-26 14:07:18 +0300
committerFrederick Muriuki Muriithi2022-03-26 14:47:56 +0300
commita0b150cf2eaed00ca56634da404e621eb6e73484 (patch)
tree0cbd3340e913c398fea20882bebe87cccaa7382c
parent5906098d32e7aa1ab155d5dd49cd3b04b06684eb (diff)
downloadgenenetwork3-a0b150cf2eaed00ca56634da404e621eb6e73484.tar.gz
Make creation of database connections more flexible
* Pass the URI string to parse into `gn3.db_utils.parse_db_url` rather than relying on a global variable. * Pass the URI string to use to generate the database connection to the `gn3.db_utils.database_connector` function rather than depending on a global variable. Use the global `SQL_URI` variable as a default value if one is not provided when calling the function. The changes above make the creation of the database connection more flexible, since the database URI passed into the function can be changed at the site where the call is made. The port is also parsed, and used where present, to allow for either a socket connection or one based on a port.
-rw-r--r--gn3/db_utils.py19
-rw-r--r--tests/unit/test_db_utils.py8
2 files changed, 16 insertions, 11 deletions
diff --git a/gn3/db_utils.py b/gn3/db_utils.py
index 3b72d28..2f08e1b 100644
--- a/gn3/db_utils.py
+++ b/gn3/db_utils.py
@@ -5,16 +5,21 @@ import MySQLdb as mdb
from gn3.settings import SQL_URI
-def parse_db_url() -> Tuple:
+def parse_db_url(db_uri: str) -> Tuple:
"""function to parse SQL_URI env variable note:there\
is a default value for SQL_URI so a tuple result is\
always expected"""
- parsed_db = urlparse(SQL_URI)
- return (parsed_db.hostname, parsed_db.username,
- parsed_db.password, parsed_db.path[1:])
+ parsed_db = urlparse(db_uri)
+ return (parsed_db.hostname, parsed_db.username, parsed_db.password,
+ parsed_db.path[1:], parsed_db.port)
-def database_connector() -> mdb.Connection:
+def database_connector(db_uri: str = SQL_URI) -> mdb.Connection:
"""function to create db connector"""
- host, user, passwd, db_name = parse_db_url()
- return mdb.connect(host, user, passwd, db_name)
+ return mdb.connect(**{
+ key: val for key, val in
+ dict(zip(
+ ("host", "user", "passwd", "db", "port"),
+ parse_db_url(db_uri))).items()
+ if bool(val)
+ })
diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py
index 7dc66c0..039bc39 100644
--- a/tests/unit/test_db_utils.py
+++ b/tests/unit/test_db_utils.py
@@ -18,7 +18,7 @@ class TestDatabase(TestCase):
@mock.patch("gn3.db_utils.parse_db_url")
def test_database_connector(self, mock_db_parser, mock_sql):
"""test for creating database connection"""
- mock_db_parser.return_value = ("localhost", "guest", "4321", "users")
+ mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None)
callable_cursor = lambda: SimpleNamespace(execute=3)
cursor_object = SimpleNamespace(cursor=callable_cursor)
mock_sql.connect.return_value = cursor_object
@@ -26,7 +26,7 @@ class TestDatabase(TestCase):
results = database_connector()
mock_sql.connect.assert_called_with(
- "localhost", "guest", "4321", "users")
+ host="localhost", user="guest", passwd="4321", db="users")
self.assertIsInstance(
results, SimpleNamespace, "database not created successfully")
@@ -35,6 +35,6 @@ class TestDatabase(TestCase):
"mysql://username:4321@localhost/test")
def test_parse_db_url(self):
"""test for parsing db_uri env variable"""
- results = parse_db_url()
- expected_results = ("localhost", "username", "4321", "test")
+ results = parse_db_url("mysql://username:4321@localhost/test")
+ expected_results = ("localhost", "username", "4321", "test", None)
self.assertEqual(results, expected_results)