aboutsummaryrefslogtreecommitdiff
"""This module handles connections to MariaDB using python-mysqlclient."""
import logging
import contextlib
from logging import Logger
from urllib.parse import urlparse
from typing import Any, Iterator, Protocol, Callable

import MySQLdb as mdb
from MySQLdb.cursors import Cursor


_logger = logging.getLogger(__file__)

class InvalidOptionValue(Exception):
    """Raised whenever a parsed value is invalid for the specific option."""


def __parse_boolean__(val: str) -> bool:
    """Check whether the variable 'val' has the string value `true`."""
    true_vals = ("t", "T", "true", "TRUE", "True")
    false_vals = ("f", "F", "false", "FALSE", "False")
    if val.strip() not in true_vals + false_vals:
        raise InvalidOptionValue(f"Invalid value: {val}")
    return val.strip().lower() in true_vals


def __non_negative_int__(val: str) -> int:
    """Convert a value to a non-negative int."""
    error_message = f"Expected a non-negative value. Got {val}"
    try:
        _val = int(val)
        if _val < 0:
            raise InvalidOptionValue(error_message)
        return _val
    except ValueError as verr:
        raise InvalidOptionValue(error_message) from verr


def __parse_ssl_mode_options__(val: str) -> str:
    mode_opts = (
        "DISABLED", "PREFERRED", "REQUIRED", "VERIFY_CA", "VERIFY_IDENTITY")
    _val = val.strip().upper()
    if _val not in mode_opts:
        raise InvalidOptionValue(f"Invalid ssl_mode option: {_val}")
    return _val


def __parse_ssl_options__(val: str) -> dict:
    allowed_keys = ("key", "cert", "ca", "capath", "cipher")
    opts = {
        key.strip(): val.strip() for key,val in
        (keyval.split(";") for keyval in val.split(","))
    }
    disallowed = tuple(key for key in opts.keys() if key not in allowed_keys)
    assert len(disallowed) == 0, f"Invalid SSL keys: {', '.join(disallowed)}"
    return opts


def __parse_db_opts__(opts: str) -> dict:
    """Parse database options into their appropriate values.

    This assumes use of python-mysqlclient library."""
    allowed_opts = (
        "unix_socket", "connect_timeout", "compress", "named_pipe",
        "init_command", "read_default_file", "read_default_group",
        "cursorclass", "use_unicode", "charset", "collation", "auth_plugin",
        "sql_mode", "client_flag", "multi_statements", "ssl_mode", "ssl",
        "local_infile", "autocommit", "binary_prefix")
    conversion_fns: dict[str, Callable] = {
        **{opt: str for opt in allowed_opts},
        "connect_timeout": __non_negative_int__,
        "compress": __parse_boolean__,
        "use_unicode": __parse_boolean__,
        # "cursorclass": __load_cursor_class__
        "client_flag": int,
        "multi_statements": __parse_boolean__,
        "ssl_mode": __parse_ssl_mode_options__,
        "ssl": __parse_ssl_options__,
        "local_infile": __parse_boolean__,
        "autocommit": __parse_boolean__,
        "binary_prefix": __parse_boolean__
    }
    queries = tuple(filter(bool, opts.split("&")))
    if len(queries) > 0:
        keyvals: tuple[tuple[str, ...], ...] = tuple(
            tuple(item.strip() for item in query.split("="))
            for query in queries)
        def __check_opt__(opt):
            assert opt in allowed_opts, (
                f"Invalid database connection option ({opt}) provided.")
            return opt
        return {
            __check_opt__(key): conversion_fns[key](val)
            for key, val in keyvals
        }
    return {}


def parse_db_url(sql_uri: str) -> dict:
    """Parse the `sql_uri` variable into a dict of connection parameters."""
    parsed_db = urlparse(sql_uri)
    return {
        "host": parsed_db.hostname,
        "port": parsed_db.port or 3306,
        "user": parsed_db.username,
        "password": parsed_db.password,
        "database": parsed_db.path.strip("/").strip(),
        **__parse_db_opts__(parsed_db.query)
    }


class Connection(Protocol):
    """Type Annotation for MySQLdb's connection object"""

    def commit(self):
        """Finish a transaction and commit the changes."""

    def rollback(self):
        """Cancel the current transaction and roll back the changes."""

    def cursor(self, *args, **kwargs) -> Any:
        """A cursor in which queries may be performed"""

@contextlib.contextmanager
def database_connection(sql_uri: str, logger: logging.Logger = _logger) -> Iterator[Connection]:
    """Connect to MySQL database."""
    connection = mdb.connect(**parse_db_url(sql_uri))
    try:
        yield connection
        connection.commit()
    except mdb.Error as _mbde:
        logger.error("DB error encountered", exc_info=True)
        connection.rollback()
    finally:
        connection.close()


def debug_query(cursor: Cursor, logger: Logger) -> None:
    """Debug the actual query run with MySQLdb"""
    for attr in ("_executed", "statement", "_last_executed"):
        if hasattr(cursor, attr):
            logger.debug("MySQLdb QUERY: %s", getattr(cursor, attr))
            break