aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gn_libs/mysqldb.py15
-rw-r--r--tests/unit/test_mysqldb.py16
2 files changed, 30 insertions, 1 deletions
diff --git a/gn_libs/mysqldb.py b/gn_libs/mysqldb.py
index 445b073..aa11356 100644
--- a/gn_libs/mysqldb.py
+++ b/gn_libs/mysqldb.py
@@ -43,6 +43,19 @@ def __parse_ssl_mode_options__(val: str) -> str:
if(_val not in mode_opts):
raise InvalidOptionValue(f"Invalid ssl_mode option: {_val}")
return _val
+
+
+def __parse_ssl_options__(val: str) -> dict:
+ allowed_keys = ("key", "cert", "ca", "capath", "cipher")
+ opts = {
+ key.strip(): val.strip() for key,val in
+ (keyval.split(";") for keyval in val.split(","))
+ }
+ disallowed = tuple(key for key in opts.keys() if key not in allowed_keys)
+ assert len(disallowed) == 0, f"Invalid SSL keys: {', '.join(disallowed)}"
+ return opts
+
+
def __parse_db_opts__(opts: str) -> dict:
"""Parse database options into their appropriate values.
@@ -60,9 +73,9 @@ def __parse_db_opts__(opts: str) -> dict:
"use_unicode": __parse_boolean__,
# "cursorclass": __load_cursor_class__
"client_flag": int,
- # "ssl": __parse_ssl_options__,
"multi_statements": __parse_boolean__,
"ssl_mode": __parse_ssl_mode_options__,
+ "ssl": __parse_ssl_options__,
"local_infile": __parse_boolean__,
"autocommit": __parse_boolean__,
"binary_prefix": __parse_boolean__
diff --git a/tests/unit/test_mysqldb.py b/tests/unit/test_mysqldb.py
index 793dc1d..df8fd49 100644
--- a/tests/unit/test_mysqldb.py
+++ b/tests/unit/test_mysqldb.py
@@ -62,6 +62,22 @@ def test_database_connection(mock_db_parser, mock_sql):
"user": None,
"password": None,
"database": "users"
+ }),
+ (("mysql://localhost/users?ssl=key;keyname,cert;/path/to/cert,ca;caname,"
+ "capath;/path/to/certificate/authority/files,cipher;ciphername"),
+ {
+ "host": "localhost",
+ "port": 3306,
+ "user": None,
+ "password": None,
+ "database": "users",
+ "ssl": {
+ "key": "keyname",
+ "cert": "/path/to/cert",
+ "ca": "caname",
+ "capath": "/path/to/certificate/authority/files",
+ "cipher": "ciphername"
+ }
})))
def test_parse_db_url(sql_uri, expected):
"""Test that valid URIs are passed into valid connection dicts"""