about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/db_utils.py68
-rw-r--r--tests/unit/test_db_utils.py69
2 files changed, 117 insertions, 20 deletions
diff --git a/gn3/db_utils.py b/gn3/db_utils.py
index 0d9bd0a..a855137 100644
--- a/gn3/db_utils.py
+++ b/gn3/db_utils.py
@@ -1,7 +1,7 @@
 """module contains all db related stuff"""
 import contextlib
 import logging
-from typing import Any, Iterator, Protocol, Tuple
+from typing import Any, Iterator, Protocol
 from urllib.parse import urlparse
 import MySQLdb as mdb
 import xapian
@@ -10,14 +10,59 @@ import xapian
 LOGGER = logging.getLogger(__file__)
 
 
-def parse_db_url(sql_uri: str) -> Tuple:
-    """function to parse SQL_URI env variable note:there\
-    is a default value for SQL_URI so a tuple result is\
-    always expected"""
+def __check_true__(val: str) -> bool:
+    """Check whether the variable 'val' has the string value `true`."""
+    return val.strip().lower() == "true"
+
+
+def __parse_db_opts__(opts: str) -> dict:
+    """Parse database options into their appropriate values.
+
+    This assumes use of python-mysqlclient library."""
+    allowed_opts = (
+        "unix_socket", "connect_timeout", "compress", "named_pipe",
+        "init_command", "read_default_file", "read_default_group",
+        "cursorclass", "use_unicode", "charset", "collation", "auth_plugin",
+        "sql_mode", "client_flag", "multi_statements", "ssl_mode", "ssl",
+        "local_infile", "autocommit", "binary_prefix")
+    conversion_fns = {
+        **{opt: str for opt in allowed_opts},
+        "connect_timeout": int,
+        "compress": __check_true__,
+        "use_unicode": __check_true__,
+        # "cursorclass": __load_cursor_class__
+        "client_flag": int,
+        "multi_statements": __check_true__,
+        # "ssl": __parse_ssl_options__,
+        "local_infile": __check_true__,
+        "autocommit": __check_true__,
+        "binary_prefix": __check_true__
+    }
+    keyvals = tuple(filter(bool, opts.split("&")))
+    if len(keyvals) > 0:
+        keyvals = tuple(tuple(item.strip() for item in keyval.split("=")) for keyval in keyvals)
+        def __check_opt__(opt):
+            assert opt in allowed_opts, (
+                f"Invalid database connection option ({opt}) provided.")
+            return opt
+        return {
+            __check_opt__(key): conversion_fns[key](val)
+            for key, val in keyvals
+        }
+    return {}
+
+
+def parse_db_url(sql_uri: str) -> dict:
+    """Parse the `sql_uri` variable into a dict of connection parameters."""
     parsed_db = urlparse(sql_uri)
-    return (
-        parsed_db.hostname, parsed_db.username, parsed_db.password,
-        parsed_db.path[1:], parsed_db.port)
+    return {
+        "host": parsed_db.hostname,
+        "port": parsed_db.port or 3306,
+        "user": parsed_db.username,
+        "password": parsed_db.password,
+        "database": parsed_db.path.strip("/").strip(),
+        **__parse_db_opts__(parsed_db.query)
+    }
 
 
 # pylint: disable=missing-class-docstring, missing-function-docstring, too-few-public-methods
@@ -30,12 +75,7 @@ class Connection(Protocol):
 @contextlib.contextmanager
 def database_connection(sql_uri: str, logger: logging.Logger = LOGGER) -> Iterator[Connection]:
     """Connect to MySQL database."""
-    host, user, passwd, db_name, port = parse_db_url(sql_uri)
-    connection = mdb.connect(db=db_name,
-                             user=user,
-                             passwd=passwd or '',
-                             host=host,
-                             port=port or 3306)
+    connection = mdb.connect(**parse_db_url(sql_uri))
     try:
         yield connection
     except mdb.Error as _mbde:
diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py
index beb7169..3c7ce59 100644
--- a/tests/unit/test_db_utils.py
+++ b/tests/unit/test_db_utils.py
@@ -10,16 +10,73 @@ from gn3.db_utils import parse_db_url, database_connection
 @mock.patch("gn3.db_utils.parse_db_url")
 def test_database_connection(mock_db_parser, mock_sql):
     """test for creating database connection"""
-    mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None)
+    mock_db_parser.return_value = {
+        "host": "localhost",
+        "user": "guest",
+        "password": "4321",
+        "database": "users",
+        "port": 3306
+    }
 
+    mock_sql.Error = Exception
     with database_connection("mysql://guest:4321@localhost/users") as _conn:
         mock_sql.connect.assert_called_with(
             db="users", user="guest", passwd="4321", host="localhost",
             port=3306)
 
+
 @pytest.mark.unit_test
-def test_parse_db_url():
-    """test for parsing db_uri env variable"""
-    results = parse_db_url("mysql://username:4321@localhost/test")
-    expected_results = ("localhost", "username", "4321", "test", None)
-    assert results == expected_results
+@pytest.mark.parametrize(
+    "sql_uri,expected",
+    (("mysql://theuser:passwd@thehost:3306/thedb",
+      {
+          "host": "thehost",
+          "port": 3306,
+          "user": "theuser",
+          "password": "passwd",
+          "database": "thedb"
+      }),
+     (("mysql://auser:passwd@somehost:3307/thedb?"
+       "unix_socket=/run/mysqld/mysqld.sock&connect_timeout=30"),
+      {
+          "host": "somehost",
+          "port": 3307,
+          "user": "auser",
+          "password": "passwd",
+          "database": "thedb",
+          "unix_socket": "/run/mysqld/mysqld.sock",
+          "connect_timeout": 30
+      }),
+     ("mysql://guest:4321@localhost/users",
+      {
+          "host": "localhost",
+          "port": 3306,
+          "user": "guest",
+          "password": "4321",
+          "database": "users"
+      }),
+     ("mysql://localhost/users",
+      {
+          "host": "localhost",
+          "port": 3306,
+          "user": None,
+          "password": None,
+          "database": "users"
+      })))
+def test_parse_db_url(sql_uri, expected):
+    """Test that valid URIs are passed into valid connection dicts"""
+    assert parse_db_url(sql_uri) == expected
+
+
+@pytest.mark.unit_test
+@pytest.mark.parametrize(
+    "sql_uri,invalidopt",
+    (("mysql://localhost/users?socket=/run/mysqld/mysqld.sock", "socket"),
+     ("mysql://localhost/users?connect_timeout=30&notavalidoption=value",
+      "notavalidoption")))
+def test_parse_db_url_with_invalid_options(sql_uri, invalidopt):
+    """Test that invalid options cause the function to raise an exception."""
+    with pytest.raises(AssertionError) as exc_info:
+        parse_db_url(sql_uri)
+
+    assert exc_info.value.args[0] == f"Invalid database connection option ({invalidopt}) provided."