diff options
-rw-r--r-- | gn_libs/mysqldb.py | 15 | ||||
-rw-r--r-- | tests/unit/test_mysqldb.py | 16 |
2 files changed, 30 insertions, 1 deletions
diff --git a/gn_libs/mysqldb.py b/gn_libs/mysqldb.py index 445b073..aa11356 100644 --- a/gn_libs/mysqldb.py +++ b/gn_libs/mysqldb.py @@ -43,6 +43,19 @@ def __parse_ssl_mode_options__(val: str) -> str: if(_val not in mode_opts): raise InvalidOptionValue(f"Invalid ssl_mode option: {_val}") return _val + + +def __parse_ssl_options__(val: str) -> dict: + allowed_keys = ("key", "cert", "ca", "capath", "cipher") + opts = { + key.strip(): val.strip() for key,val in + (keyval.split(";") for keyval in val.split(",")) + } + disallowed = tuple(key for key in opts.keys() if key not in allowed_keys) + assert len(disallowed) == 0, f"Invalid SSL keys: {', '.join(disallowed)}" + return opts + + def __parse_db_opts__(opts: str) -> dict: """Parse database options into their appropriate values. @@ -60,9 +73,9 @@ def __parse_db_opts__(opts: str) -> dict: "use_unicode": __parse_boolean__, # "cursorclass": __load_cursor_class__ "client_flag": int, - # "ssl": __parse_ssl_options__, "multi_statements": __parse_boolean__, "ssl_mode": __parse_ssl_mode_options__, + "ssl": __parse_ssl_options__, "local_infile": __parse_boolean__, "autocommit": __parse_boolean__, "binary_prefix": __parse_boolean__ diff --git a/tests/unit/test_mysqldb.py b/tests/unit/test_mysqldb.py index 793dc1d..df8fd49 100644 --- a/tests/unit/test_mysqldb.py +++ b/tests/unit/test_mysqldb.py @@ -62,6 +62,22 @@ def test_database_connection(mock_db_parser, mock_sql): "user": None, "password": None, "database": "users" + }), + (("mysql://localhost/users?ssl=key;keyname,cert;/path/to/cert,ca;caname," + "capath;/path/to/certificate/authority/files,cipher;ciphername"), + { + "host": "localhost", + "port": 3306, + "user": None, + "password": None, + "database": "users", + "ssl": { + "key": "keyname", + "cert": "/path/to/cert", + "ca": "caname", + "capath": "/path/to/certificate/authority/files", + "cipher": "ciphername" + } }))) def test_parse_db_url(sql_uri, expected): """Test that valid URIs are passed into valid connection dicts""" |