about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/db_utils.py19
-rw-r--r--tests/unit/test_db_utils.py8
2 files changed, 16 insertions, 11 deletions
diff --git a/gn3/db_utils.py b/gn3/db_utils.py
index 3b72d28..2f08e1b 100644
--- a/gn3/db_utils.py
+++ b/gn3/db_utils.py
@@ -5,16 +5,21 @@ import MySQLdb as mdb
 from gn3.settings import SQL_URI
 
 
-def parse_db_url() -> Tuple:
+def parse_db_url(db_uri: str) -> Tuple:
     """function to parse SQL_URI env variable note:there\
     is a default value for SQL_URI so a tuple result is\
     always expected"""
-    parsed_db = urlparse(SQL_URI)
-    return (parsed_db.hostname, parsed_db.username,
-            parsed_db.password, parsed_db.path[1:])
+    parsed_db = urlparse(db_uri)
+    return (parsed_db.hostname, parsed_db.username, parsed_db.password,
+            parsed_db.path[1:], parsed_db.port)
 
 
-def database_connector() -> mdb.Connection:
+def database_connector(db_uri: str = SQL_URI) -> mdb.Connection:
     """function to create db connector"""
-    host, user, passwd, db_name = parse_db_url()
-    return mdb.connect(host, user, passwd, db_name)
+    return mdb.connect(**{
+        key: val for key, val in
+        dict(zip(
+            ("host", "user", "passwd", "db", "port"),
+            parse_db_url(db_uri))).items()
+        if bool(val)
+    })
diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py
index 7dc66c0..039bc39 100644
--- a/tests/unit/test_db_utils.py
+++ b/tests/unit/test_db_utils.py
@@ -18,7 +18,7 @@ class TestDatabase(TestCase):
     @mock.patch("gn3.db_utils.parse_db_url")
     def test_database_connector(self, mock_db_parser, mock_sql):
         """test for creating database connection"""
-        mock_db_parser.return_value = ("localhost", "guest", "4321", "users")
+        mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None)
         callable_cursor = lambda: SimpleNamespace(execute=3)
         cursor_object = SimpleNamespace(cursor=callable_cursor)
         mock_sql.connect.return_value = cursor_object
@@ -26,7 +26,7 @@ class TestDatabase(TestCase):
         results = database_connector()
 
         mock_sql.connect.assert_called_with(
-            "localhost", "guest", "4321", "users")
+            host="localhost", user="guest", passwd="4321", db="users")
         self.assertIsInstance(
             results, SimpleNamespace, "database not created successfully")
 
@@ -35,6 +35,6 @@ class TestDatabase(TestCase):
                 "mysql://username:4321@localhost/test")
     def test_parse_db_url(self):
         """test for parsing db_uri env variable"""
-        results = parse_db_url()
-        expected_results = ("localhost", "username", "4321", "test")
+        results = parse_db_url("mysql://username:4321@localhost/test")
+        expected_results = ("localhost", "username", "4321", "test", None)
         self.assertEqual(results, expected_results)