diff options
author | Frederick Muriuki Muriithi | 2024-12-09 12:03:24 -0600 |
---|---|---|
committer | Frederick Muriuki Muriithi | 2024-12-09 12:03:24 -0600 |
commit | 48ee2d70bf31d576f87e6a02498b135f96591ed4 (patch) | |
tree | 95a4e92f252457eb37086b6c7481038ee88da2e7 | |
parent | d8eb7c2d983f4b7382db47c384bc673baccd170c (diff) | |
download | gn-libs-48ee2d70bf31d576f87e6a02498b135f96591ed4.tar.gz |
Parse SSL options for the connection string.
-rw-r--r-- | gn_libs/mysqldb.py | 15 | ||||
-rw-r--r-- | tests/unit/test_mysqldb.py | 16 |
2 files changed, 30 insertions, 1 deletions
diff --git a/gn_libs/mysqldb.py b/gn_libs/mysqldb.py index 445b073..aa11356 100644 --- a/gn_libs/mysqldb.py +++ b/gn_libs/mysqldb.py @@ -43,6 +43,19 @@ def __parse_ssl_mode_options__(val: str) -> str: if(_val not in mode_opts): raise InvalidOptionValue(f"Invalid ssl_mode option: {_val}") return _val + + +def __parse_ssl_options__(val: str) -> dict: + allowed_keys = ("key", "cert", "ca", "capath", "cipher") + opts = { + key.strip(): val.strip() for key,val in + (keyval.split(";") for keyval in val.split(",")) + } + disallowed = tuple(key for key in opts.keys() if key not in allowed_keys) + assert len(disallowed) == 0, f"Invalid SSL keys: {', '.join(disallowed)}" + return opts + + def __parse_db_opts__(opts: str) -> dict: """Parse database options into their appropriate values. @@ -60,9 +73,9 @@ def __parse_db_opts__(opts: str) -> dict: "use_unicode": __parse_boolean__, # "cursorclass": __load_cursor_class__ "client_flag": int, - # "ssl": __parse_ssl_options__, "multi_statements": __parse_boolean__, "ssl_mode": __parse_ssl_mode_options__, + "ssl": __parse_ssl_options__, "local_infile": __parse_boolean__, "autocommit": __parse_boolean__, "binary_prefix": __parse_boolean__ diff --git a/tests/unit/test_mysqldb.py b/tests/unit/test_mysqldb.py index 793dc1d..df8fd49 100644 --- a/tests/unit/test_mysqldb.py +++ b/tests/unit/test_mysqldb.py @@ -62,6 +62,22 @@ def test_database_connection(mock_db_parser, mock_sql): "user": None, "password": None, "database": "users" + }), + (("mysql://localhost/users?ssl=key;keyname,cert;/path/to/cert,ca;caname," + "capath;/path/to/certificate/authority/files,cipher;ciphername"), + { + "host": "localhost", + "port": 3306, + "user": None, + "password": None, + "database": "users", + "ssl": { + "key": "keyname", + "cert": "/path/to/cert", + "ca": "caname", + "capath": "/path/to/certificate/authority/files", + "cipher": "ciphername" + } }))) def test_parse_db_url(sql_uri, expected): """Test that valid URIs are passed into valid connection dicts""" |