From a0b150cf2eaed00ca56634da404e621eb6e73484 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Sat, 26 Mar 2022 14:07:18 +0300 Subject: Make creation of database connections more flexible * Pass the URI string to parse into `gn3.db_utils.parse_db_url` rather than relying on a global variable. * Pass the URI string to use to generate the database connection to the `gn3.db_utils.database_connector` function rather than depending on a global variable. Use the global `SQL_URI` variable as a default value if one is not provided when calling the function. The changes above make the creation of the database connection more flexible, since the database URI passed into the function can be changed at the site where the call is made. The port is also parsed, and used where present, to allow for either a socket connection or one based on a port. --- gn3/db_utils.py | 19 ++++++++++++------- tests/unit/test_db_utils.py | 8 ++++---- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/gn3/db_utils.py b/gn3/db_utils.py index 3b72d28..2f08e1b 100644 --- a/gn3/db_utils.py +++ b/gn3/db_utils.py @@ -5,16 +5,21 @@ import MySQLdb as mdb from gn3.settings import SQL_URI -def parse_db_url() -> Tuple: +def parse_db_url(db_uri: str) -> Tuple: """function to parse SQL_URI env variable note:there\ is a default value for SQL_URI so a tuple result is\ always expected""" - parsed_db = urlparse(SQL_URI) - return (parsed_db.hostname, parsed_db.username, - parsed_db.password, parsed_db.path[1:]) + parsed_db = urlparse(db_uri) + return (parsed_db.hostname, parsed_db.username, parsed_db.password, + parsed_db.path[1:], parsed_db.port) -def database_connector() -> mdb.Connection: +def database_connector(db_uri: str = SQL_URI) -> mdb.Connection: """function to create db connector""" - host, user, passwd, db_name = parse_db_url() - return mdb.connect(host, user, passwd, db_name) + return mdb.connect(**{ + key: val for key, val in + dict(zip( + ("host", "user", "passwd", "db", "port"), + parse_db_url(db_uri))).items() + if bool(val) + }) diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py index 7dc66c0..039bc39 100644 --- a/tests/unit/test_db_utils.py +++ b/tests/unit/test_db_utils.py @@ -18,7 +18,7 @@ class TestDatabase(TestCase): @mock.patch("gn3.db_utils.parse_db_url") def test_database_connector(self, mock_db_parser, mock_sql): """test for creating database connection""" - mock_db_parser.return_value = ("localhost", "guest", "4321", "users") + mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None) callable_cursor = lambda: SimpleNamespace(execute=3) cursor_object = SimpleNamespace(cursor=callable_cursor) mock_sql.connect.return_value = cursor_object @@ -26,7 +26,7 @@ class TestDatabase(TestCase): results = database_connector() mock_sql.connect.assert_called_with( - "localhost", "guest", "4321", "users") + host="localhost", user="guest", passwd="4321", db="users") self.assertIsInstance( results, SimpleNamespace, "database not created successfully") @@ -35,6 +35,6 @@ class TestDatabase(TestCase): "mysql://username:4321@localhost/test") def test_parse_db_url(self): """test for parsing db_uri env variable""" - results = parse_db_url() - expected_results = ("localhost", "username", "4321", "test") + results = parse_db_url("mysql://username:4321@localhost/test") + expected_results = ("localhost", "username", "4321", "test", None) self.assertEqual(results, expected_results) -- cgit v1.2.3