diff options
| -rw-r--r-- | gn3/db_utils.py | 68 | ||||
| -rw-r--r-- | tests/unit/test_db_utils.py | 69 |
2 files changed, 117 insertions, 20 deletions
diff --git a/gn3/db_utils.py b/gn3/db_utils.py index 0d9bd0a..a855137 100644 --- a/gn3/db_utils.py +++ b/gn3/db_utils.py @@ -1,7 +1,7 @@ """module contains all db related stuff""" import contextlib import logging -from typing import Any, Iterator, Protocol, Tuple +from typing import Any, Iterator, Protocol from urllib.parse import urlparse import MySQLdb as mdb import xapian @@ -10,14 +10,59 @@ import xapian LOGGER = logging.getLogger(__file__) -def parse_db_url(sql_uri: str) -> Tuple: - """function to parse SQL_URI env variable note:there\ - is a default value for SQL_URI so a tuple result is\ - always expected""" +def __check_true__(val: str) -> bool: + """Check whether the variable 'val' has the string value `true`.""" + return val.strip().lower() == "true" + + +def __parse_db_opts__(opts: str) -> dict: + """Parse database options into their appropriate values. + + This assumes use of python-mysqlclient library.""" + allowed_opts = ( + "unix_socket", "connect_timeout", "compress", "named_pipe", + "init_command", "read_default_file", "read_default_group", + "cursorclass", "use_unicode", "charset", "collation", "auth_plugin", + "sql_mode", "client_flag", "multi_statements", "ssl_mode", "ssl", + "local_infile", "autocommit", "binary_prefix") + conversion_fns = { + **{opt: str for opt in allowed_opts}, + "connect_timeout": int, + "compress": __check_true__, + "use_unicode": __check_true__, + # "cursorclass": __load_cursor_class__ + "client_flag": int, + "multi_statements": __check_true__, + # "ssl": __parse_ssl_options__, + "local_infile": __check_true__, + "autocommit": __check_true__, + "binary_prefix": __check_true__ + } + keyvals = tuple(filter(bool, opts.split("&"))) + if len(keyvals) > 0: + keyvals = tuple(tuple(item.strip() for item in keyval.split("=")) for keyval in keyvals) + def __check_opt__(opt): + assert opt in allowed_opts, ( + f"Invalid database connection option ({opt}) provided.") + return opt + return { + __check_opt__(key): conversion_fns[key](val) + for key, val in keyvals + } + return {} + + +def parse_db_url(sql_uri: str) -> dict: + """Parse the `sql_uri` variable into a dict of connection parameters.""" parsed_db = urlparse(sql_uri) - return ( - parsed_db.hostname, parsed_db.username, parsed_db.password, - parsed_db.path[1:], parsed_db.port) + return { + "host": parsed_db.hostname, + "port": parsed_db.port or 3306, + "user": parsed_db.username, + "password": parsed_db.password, + "database": parsed_db.path.strip("/").strip(), + **__parse_db_opts__(parsed_db.query) + } # pylint: disable=missing-class-docstring, missing-function-docstring, too-few-public-methods @@ -30,12 +75,7 @@ class Connection(Protocol): @contextlib.contextmanager def database_connection(sql_uri: str, logger: logging.Logger = LOGGER) -> Iterator[Connection]: """Connect to MySQL database.""" - host, user, passwd, db_name, port = parse_db_url(sql_uri) - connection = mdb.connect(db=db_name, - user=user, - passwd=passwd or '', - host=host, - port=port or 3306) + connection = mdb.connect(**parse_db_url(sql_uri)) try: yield connection except mdb.Error as _mbde: diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py index beb7169..3c7ce59 100644 --- a/tests/unit/test_db_utils.py +++ b/tests/unit/test_db_utils.py @@ -10,16 +10,73 @@ from gn3.db_utils import parse_db_url, database_connection @mock.patch("gn3.db_utils.parse_db_url") def test_database_connection(mock_db_parser, mock_sql): """test for creating database connection""" - mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None) + mock_db_parser.return_value = { + "host": "localhost", + "user": "guest", + "password": "4321", + "database": "users", + "port": 3306 + } + mock_sql.Error = Exception with database_connection("mysql://guest:4321@localhost/users") as _conn: mock_sql.connect.assert_called_with( db="users", user="guest", passwd="4321", host="localhost", port=3306) + @pytest.mark.unit_test -def test_parse_db_url(): - """test for parsing db_uri env variable""" - results = parse_db_url("mysql://username:4321@localhost/test") - expected_results = ("localhost", "username", "4321", "test", None) - assert results == expected_results +@pytest.mark.parametrize( + "sql_uri,expected", + (("mysql://theuser:passwd@thehost:3306/thedb", + { + "host": "thehost", + "port": 3306, + "user": "theuser", + "password": "passwd", + "database": "thedb" + }), + (("mysql://auser:passwd@somehost:3307/thedb?" + "unix_socket=/run/mysqld/mysqld.sock&connect_timeout=30"), + { + "host": "somehost", + "port": 3307, + "user": "auser", + "password": "passwd", + "database": "thedb", + "unix_socket": "/run/mysqld/mysqld.sock", + "connect_timeout": 30 + }), + ("mysql://guest:4321@localhost/users", + { + "host": "localhost", + "port": 3306, + "user": "guest", + "password": "4321", + "database": "users" + }), + ("mysql://localhost/users", + { + "host": "localhost", + "port": 3306, + "user": None, + "password": None, + "database": "users" + }))) +def test_parse_db_url(sql_uri, expected): + """Test that valid URIs are passed into valid connection dicts""" + assert parse_db_url(sql_uri) == expected + + +@pytest.mark.unit_test +@pytest.mark.parametrize( + "sql_uri,invalidopt", + (("mysql://localhost/users?socket=/run/mysqld/mysqld.sock", "socket"), + ("mysql://localhost/users?connect_timeout=30¬avalidoption=value", + "notavalidoption"))) +def test_parse_db_url_with_invalid_options(sql_uri, invalidopt): + """Test that invalid options cause the function to raise an exception.""" + with pytest.raises(AssertionError) as exc_info: + parse_db_url(sql_uri) + + assert exc_info.value.args[0] == f"Invalid database connection option ({invalidopt}) provided." |
