about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-12-09 12:03:24 -0600
committerFrederick Muriuki Muriithi2024-12-09 12:03:24 -0600
commit48ee2d70bf31d576f87e6a02498b135f96591ed4 (patch)
tree95a4e92f252457eb37086b6c7481038ee88da2e7
parentd8eb7c2d983f4b7382db47c384bc673baccd170c (diff)
downloadgn-libs-48ee2d70bf31d576f87e6a02498b135f96591ed4.tar.gz
Parse SSL options for the connection string.
-rw-r--r--gn_libs/mysqldb.py15
-rw-r--r--tests/unit/test_mysqldb.py16
2 files changed, 30 insertions, 1 deletions
diff --git a/gn_libs/mysqldb.py b/gn_libs/mysqldb.py
index 445b073..aa11356 100644
--- a/gn_libs/mysqldb.py
+++ b/gn_libs/mysqldb.py
@@ -43,6 +43,19 @@ def __parse_ssl_mode_options__(val: str) -> str:
     if(_val not in mode_opts):
         raise InvalidOptionValue(f"Invalid ssl_mode option: {_val}")
     return _val
+
+
+def __parse_ssl_options__(val: str) -> dict:
+    allowed_keys = ("key", "cert", "ca", "capath", "cipher")
+    opts = {
+        key.strip(): val.strip() for key,val in
+        (keyval.split(";") for keyval in val.split(","))
+    }
+    disallowed = tuple(key for key in opts.keys() if key not in allowed_keys)
+    assert len(disallowed) == 0, f"Invalid SSL keys: {', '.join(disallowed)}"
+    return opts
+
+
 def __parse_db_opts__(opts: str) -> dict:
     """Parse database options into their appropriate values.
 
@@ -60,9 +73,9 @@ def __parse_db_opts__(opts: str) -> dict:
         "use_unicode": __parse_boolean__,
         # "cursorclass": __load_cursor_class__
         "client_flag": int,
-        # "ssl": __parse_ssl_options__,
         "multi_statements": __parse_boolean__,
         "ssl_mode": __parse_ssl_mode_options__,
+        "ssl": __parse_ssl_options__,
         "local_infile": __parse_boolean__,
         "autocommit": __parse_boolean__,
         "binary_prefix": __parse_boolean__
diff --git a/tests/unit/test_mysqldb.py b/tests/unit/test_mysqldb.py
index 793dc1d..df8fd49 100644
--- a/tests/unit/test_mysqldb.py
+++ b/tests/unit/test_mysqldb.py
@@ -62,6 +62,22 @@ def test_database_connection(mock_db_parser, mock_sql):
           "user": None,
           "password": None,
           "database": "users"
+      }),
+     (("mysql://localhost/users?ssl=key;keyname,cert;/path/to/cert,ca;caname,"
+       "capath;/path/to/certificate/authority/files,cipher;ciphername"),
+      {
+          "host": "localhost",
+          "port": 3306,
+          "user": None,
+          "password": None,
+          "database": "users",
+          "ssl": {
+              "key": "keyname",
+              "cert": "/path/to/cert",
+              "ca": "caname",
+              "capath": "/path/to/certificate/authority/files",
+              "cipher": "ciphername"
+          }
       })))
 def test_parse_db_url(sql_uri, expected):
     """Test that valid URIs are passed into valid connection dicts"""