diff options
-rw-r--r-- | gn3/db_utils.py | 19 | ||||
-rw-r--r-- | tests/unit/test_db_utils.py | 8 |
2 files changed, 16 insertions, 11 deletions
diff --git a/gn3/db_utils.py b/gn3/db_utils.py index 3b72d28..2f08e1b 100644 --- a/gn3/db_utils.py +++ b/gn3/db_utils.py @@ -5,16 +5,21 @@ import MySQLdb as mdb from gn3.settings import SQL_URI -def parse_db_url() -> Tuple: +def parse_db_url(db_uri: str) -> Tuple: """function to parse SQL_URI env variable note:there\ is a default value for SQL_URI so a tuple result is\ always expected""" - parsed_db = urlparse(SQL_URI) - return (parsed_db.hostname, parsed_db.username, - parsed_db.password, parsed_db.path[1:]) + parsed_db = urlparse(db_uri) + return (parsed_db.hostname, parsed_db.username, parsed_db.password, + parsed_db.path[1:], parsed_db.port) -def database_connector() -> mdb.Connection: +def database_connector(db_uri: str = SQL_URI) -> mdb.Connection: """function to create db connector""" - host, user, passwd, db_name = parse_db_url() - return mdb.connect(host, user, passwd, db_name) + return mdb.connect(**{ + key: val for key, val in + dict(zip( + ("host", "user", "passwd", "db", "port"), + parse_db_url(db_uri))).items() + if bool(val) + }) diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py index 7dc66c0..039bc39 100644 --- a/tests/unit/test_db_utils.py +++ b/tests/unit/test_db_utils.py @@ -18,7 +18,7 @@ class TestDatabase(TestCase): @mock.patch("gn3.db_utils.parse_db_url") def test_database_connector(self, mock_db_parser, mock_sql): """test for creating database connection""" - mock_db_parser.return_value = ("localhost", "guest", "4321", "users") + mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None) callable_cursor = lambda: SimpleNamespace(execute=3) cursor_object = SimpleNamespace(cursor=callable_cursor) mock_sql.connect.return_value = cursor_object @@ -26,7 +26,7 @@ class TestDatabase(TestCase): results = database_connector() mock_sql.connect.assert_called_with( - "localhost", "guest", "4321", "users") + host="localhost", user="guest", passwd="4321", db="users") self.assertIsInstance( results, SimpleNamespace, "database not created successfully") @@ -35,6 +35,6 @@ class TestDatabase(TestCase): "mysql://username:4321@localhost/test") def test_parse_db_url(self): """test for parsing db_uri env variable""" - results = parse_db_url() - expected_results = ("localhost", "username", "4321", "test") + results = parse_db_url("mysql://username:4321@localhost/test") + expected_results = ("localhost", "username", "4321", "test", None) self.assertEqual(results, expected_results) |