aboutsummaryrefslogtreecommitdiff
path: root/gn_libs/mysqldb.py
blob: 764412faba211f0d4e7b84a4dc1347049e0a6cc1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""This module handles connections to MariaDB using python-mysqlclient."""
import logging
import contextlib
from logging import Logger
from urllib.parse import urlparse
from typing import Any, Iterator, Protocol, Callable

import MySQLdb as mdb
from MySQLdb.cursors import Cursor


_logger = logging.getLogger(__file__)

class InvalidOptionValue(Exception):
    """Signal that a connection-option value could not be parsed.

    The option converters in this module raise it when a value taken from the
    database URI is not acceptable for its option."""


def __parse_boolean__(val: str) -> bool:
    """Check whether the variable 'val' has the string value `true`."""
    true_vals = ("t", "T", "true", "TRUE", "True")
    false_vals = ("f", "F", "false", "FALSE", "False")
    if (val.strip() not in (true_vals + false_vals)):
        raise InvalidOptionValue(f"Invalid value: {val}")
    return val.strip().lower() in true_vals


def __non_negative_int__(val: str) -> int:
    """Convert a value to a non-negative int."""
    error_message = f"Expected a non-negative value. Got {val}"
    try:
        _val = int(val)
        if (_val < 0):
            raise InvalidOptionValue(error_message)
        return _val
    except ValueError as verr:
        raise InvalidOptionValue(error_message) from verr


def __parse_ssl_mode_options__(val: str) -> str:
    mode_opts = (
        "DISABLED", "PREFERRED", "REQUIRED", "VERIFY_CA", "VERIFY_IDENTITY")
    _val = val.strip().upper()
    if(_val not in mode_opts):
        raise InvalidOptionValue(f"Invalid ssl_mode option: {_val}")
    return _val


def __parse_ssl_options__(val: str) -> dict:
    allowed_keys = ("key", "cert", "ca", "capath", "cipher")
    opts = {
        key.strip(): val.strip() for key,val in
        (keyval.split(";") for keyval in val.split(","))
    }
    disallowed = tuple(key for key in opts.keys() if key not in allowed_keys)
    assert len(disallowed) == 0, f"Invalid SSL keys: {', '.join(disallowed)}"
    return opts


def __parse_db_opts__(opts: str) -> dict:
    """Parse database options into their appropriate values.

    `opts` is the query-string part of a connection URI: `&`-separated
    `name=value` items. Every name must be one of the connection options
    listed in `allowed_opts`, and its value is converted by the matching
    entry of `conversion_fns` (plain `str` when the option has no dedicated
    converter). Names and values are stripped of surrounding whitespace, and
    blank items (e.g. from `a=1&&b=2`) are skipped.

    Only the first `=` of an item separates the name from the value, so a
    value that itself contains `=` (such as an `init_command` of the form
    `SET sql_mode=...`) is kept intact.

    This assumes use of python-mysqlclient library.

    Raises:
        ValueError: an item has no `=` separator.
        AssertionError: an option name is not in `allowed_opts`.
        Converter errors (e.g. `InvalidOptionValue`) propagate unchanged.
    """
    allowed_opts = (
        "unix_socket", "connect_timeout", "compress", "named_pipe",
        "init_command", "read_default_file", "read_default_group",
        "cursorclass", "use_unicode", "charset", "collation", "auth_plugin",
        "sql_mode", "client_flag", "multi_statements", "ssl_mode", "ssl",
        "local_infile", "autocommit", "binary_prefix")
    conversion_fns: dict[str, Callable] = {
        # Default: pass the raw string through unchanged.
        **{opt: str for opt in allowed_opts},
        "connect_timeout": __non_negative_int__,
        "compress": __parse_boolean__,
        "use_unicode": __parse_boolean__,
        # TODO: "cursorclass" is passed through as a plain string; MySQLdb
        # presumably expects a cursor class object here (a
        # `__load_cursor_class__` converter was sketched but never added).
        "client_flag": int,
        "multi_statements": __parse_boolean__,
        "ssl_mode": __parse_ssl_mode_options__,
        "ssl": __parse_ssl_options__,
        "local_infile": __parse_boolean__,
        "autocommit": __parse_boolean__,
        "binary_prefix": __parse_boolean__,
    }
    parsed: dict = {}
    for item in opts.split("&"):
        if not item.strip():
            continue
        name, sep, raw_value = item.partition("=")
        if not sep:
            raise ValueError(
                f"Malformed database connection option ({item.strip()}): "
                "expected 'name=value'.")
        name = name.strip()
        if name not in allowed_opts:
            # Raised explicitly rather than with `assert` so the check still
            # runs under `python -O`; the exception type is unchanged.
            raise AssertionError(
                f"Invalid database connection option ({name}) provided.")
        parsed[name] = conversion_fns[name](raw_value.strip())
    return parsed


def parse_db_url(sql_uri: str) -> dict:
    """Parse the `sql_uri` variable into a dict of connection parameters.

    `sql_uri` has the form `scheme://user:password@host:port/database?opts`.
    The result holds `host`, `port` (3306 when the URI gives none), `user`,
    `password` and `database`, followed by whatever `__parse_db_opts__`
    returns for the query string. `database_connection` passes this dict as
    keyword arguments to `MySQLdb.connect`.

    `urlparse` does not percent-decode the user-info or the path, so they are
    decoded here. Characters that cannot appear verbatim in a URI (such as
    `/`, `?` or `#` in a password) can therefore be supplied percent-encoded
    (`%2F`, `%3F`, `%23`). A literal `%` should be written as `%25`.
    """
    # Local import keeps the module-level import block unchanged.
    from urllib.parse import unquote

    parsed_db = urlparse(sql_uri)
    user = parsed_db.username
    password = parsed_db.password
    return {
        "host": parsed_db.hostname,
        "port": parsed_db.port or 3306,
        "user": None if user is None else unquote(user),
        "password": None if password is None else unquote(password),
        "database": unquote(parsed_db.path.strip("/").strip()),
        **__parse_db_opts__(parsed_db.query)
    }


class Connection(Protocol):
    """Type Annotation for MySQLdb's connection object.

    A structural type listing only the connection methods this module and its
    callers rely on; the object yielded by `database_connection` is expected
    to satisfy it.
    """

    def commit(self):
        """Finish a transaction and commit the changes."""

    def rollback(self):
        """Cancel the current transaction and roll back the changes."""

    def cursor(self, *args, **kwargs) -> Any:
        """A cursor in which queries may be performed"""

    def close(self):
        """Close the connection to the database server."""

@contextlib.contextmanager
def database_connection(sql_uri: str, logger: logging.Logger = _logger) -> Iterator[Connection]:
    """Connect to MySQL database.

    Yields a connection opened with the parameters from `parse_db_url`.
    - If the `with` body finishes normally, the transaction is committed.
    - If the body or the commit raises `MySQLdb.Error`, the error is logged
      with its traceback through `logger` and the transaction is rolled back.
    - The connection is closed in every case.

    NOTE(review): the `MySQLdb.Error` is not re-raised, so the `with`
    statement completes normally and callers cannot tell that their changes
    were discarded. Confirm this is intended. Other exception types propagate
    after the connection is closed, without an explicit rollback.
    """
    conn = mdb.connect(**parse_db_url(sql_uri))
    # `closing` guarantees conn.close() on every exit path, like try/finally.
    with contextlib.closing(conn):
        try:
            yield conn
            conn.commit()
        except mdb.Error:
            logger.error("DB error encountered", exc_info=True)
            conn.rollback()


def debug_query(cursor: Cursor, logger: Logger) -> None:
    """Log, at DEBUG level, the SQL last executed on `cursor`.

    The first attribute present on the cursor out of `_executed`,
    `statement` and `_last_executed` is logged. Presumably this is because
    different MySQLdb versions or cursor types store the query under
    different names. Nothing is logged when none of them exist.
    """
    candidates = ("_executed", "statement", "_last_executed")
    found = next((name for name in candidates if hasattr(cursor, name)), None)
    if found is None:
        return
    logger.debug("MySQLdb QUERY: %s", getattr(cursor, found))