diff options
-rw-r--r-- | gn_libs/mysqldb.py | 20 | ||||
-rw-r--r-- | tests/unit/test_mysqldb.py | 13 |
2 files changed, 24 insertions, 9 deletions
diff --git a/gn_libs/mysqldb.py b/gn_libs/mysqldb.py index 6c05a6b..c738d96 100644 --- a/gn_libs/mysqldb.py +++ b/gn_libs/mysqldb.py @@ -15,9 +15,13 @@ class InvalidOptionValue(Exception): """Raised whenever a parsed value is invalid for the specific option.""" -def __check_true__(val: str) -> bool: +def __parse_boolean__(val: str) -> bool: """Check whether the variable 'val' has the string value `true`.""" - return val.strip().lower() == "true" + true_vals = ("t", "T", "true", "TRUE", "True") + false_vals = ("f", "F", "false", "FALSE", "False") + if (val.strip() not in (true_vals + false_vals)): + raise InvalidOptionValue(f"Invalid value: {val}") + return val.strip().lower() in true_vals def __parse_db_opts__(opts: str) -> dict: @@ -33,15 +37,15 @@ def __parse_db_opts__(opts: str) -> dict: conversion_fns: dict[str, Callable] = { **{opt: str for opt in allowed_opts}, "connect_timeout": int, - "compress": __check_true__, - "use_unicode": __check_true__, + "compress": __parse_boolean__, + "use_unicode": __parse_boolean__, # "cursorclass": __load_cursor_class__ "client_flag": int, - "multi_statements": __check_true__, # "ssl": __parse_ssl_options__, - "local_infile": __check_true__, - "autocommit": __check_true__, - "binary_prefix": __check_true__ + "multi_statements": __parse_boolean__, + "local_infile": __parse_boolean__, + "autocommit": __parse_boolean__, + "binary_prefix": __parse_boolean__ } queries = tuple(filter(bool, opts.split("&"))) if len(queries) > 0: diff --git a/tests/unit/test_mysqldb.py b/tests/unit/test_mysqldb.py index 5beaae3..b75f1e0 100644 --- a/tests/unit/test_mysqldb.py +++ b/tests/unit/test_mysqldb.py @@ -3,7 +3,7 @@ from unittest import mock import pytest -from gn_libs.mysqldb import parse_db_url, database_connection +from gn_libs.mysqldb import parse_db_url, database_connection, InvalidOptionValue @pytest.mark.unit_test @mock.patch("gn_libs.mysqldb.mdb") @@ -80,3 +80,14 @@ def test_parse_db_url_with_invalid_options(sql_uri, invalidopt): parse_db_url(sql_uri) assert exc_info.value.args[0] == f"Invalid database connection option ({invalidopt}) provided." + + +@pytest.mark.unit_test +@pytest.mark.parametrize( + "sql_uri", + (("mysql://auser:passwd@somehost:3307/thedb?use_unicode=fire"), + ("mysql://auser:passwd@somehost:3307/thedb?use_unicode=3"))) +def test_parse_db_url_with_invalid_options_values(sql_uri): + """Test parsing with invalid options' values.""" + with pytest.raises(InvalidOptionValue) as iov: + parse_db_url(sql_uri) |