diff options
Diffstat (limited to 'gn3/db_utils.py')
| -rw-r--r-- | gn3/db_utils.py | 93 |
1 files changed, 60 insertions, 33 deletions
diff --git a/gn3/db_utils.py b/gn3/db_utils.py index 0d9bd0a..e1816b0 100644 --- a/gn3/db_utils.py +++ b/gn3/db_utils.py @@ -1,49 +1,76 @@ """module contains all db related stuff""" -import contextlib import logging -from typing import Any, Iterator, Protocol, Tuple +import contextlib from urllib.parse import urlparse -import MySQLdb as mdb +from typing import Callable + import xapian +# XXXX: Replace instances that call db_utils.Connection or +# db_utils.database_connection with a direct call to gn_libs. +# pylint: disable=[W0611] +from gn_libs.mysqldb import Connection, database_connection # type: ignore + LOGGER = logging.getLogger(__file__) -def parse_db_url(sql_uri: str) -> Tuple: - """function to parse SQL_URI env variable note:there\ - is a default value for SQL_URI so a tuple result is\ - always expected""" - parsed_db = urlparse(sql_uri) - return ( - parsed_db.hostname, parsed_db.username, parsed_db.password, - parsed_db.path[1:], parsed_db.port) +def __check_true__(val: str) -> bool: + """Check whether the variable 'val' has the string value `true`.""" + return val.strip().lower() == "true" -# pylint: disable=missing-class-docstring, missing-function-docstring, too-few-public-methods -class Connection(Protocol): - """Type Annotation for MySQLdb's connection object""" - def cursor(self, *args, **kwargs) -> Any: - """A cursor in which queries may be performed""" +def __parse_db_opts__(opts: str) -> dict: + """Parse database options into their appropriate values. + This assumes use of python-mysqlclient library.""" + allowed_opts = ( + "unix_socket", "connect_timeout", "compress", "named_pipe", + "init_command", "read_default_file", "read_default_group", + "cursorclass", "use_unicode", "charset", "collation", "auth_plugin", + "sql_mode", "client_flag", "multi_statements", "ssl_mode", "ssl", + "local_infile", "autocommit", "binary_prefix") + conversion_fns: dict[str, Callable] = { + **{opt: str for opt in allowed_opts}, + "connect_timeout": int, + "compress": __check_true__, + "use_unicode": __check_true__, + # "cursorclass": __load_cursor_class__ + "client_flag": int, + "multi_statements": __check_true__, + # "ssl": __parse_ssl_options__, + "local_infile": __check_true__, + "autocommit": __check_true__, + "binary_prefix": __check_true__ + } + queries = tuple(filter(bool, opts.split("&"))) + if len(queries) > 0: + keyvals: tuple[tuple[str, ...], ...] = tuple( + tuple(item.strip() for item in query.split("=")) + for query in queries) -@contextlib.contextmanager -def database_connection(sql_uri: str, logger: logging.Logger = LOGGER) -> Iterator[Connection]: - """Connect to MySQL database.""" - host, user, passwd, db_name, port = parse_db_url(sql_uri) - connection = mdb.connect(db=db_name, - user=user, - passwd=passwd or '', - host=host, - port=port or 3306) - try: - yield connection - except mdb.Error as _mbde: - logger.error("DB error encountered", exc_info=True) - connection.rollback() - finally: - connection.commit() - connection.close() + def __check_opt__(opt): + assert opt in allowed_opts, ( + f"Invalid database connection option ({opt}) provided.") + return opt + return { + __check_opt__(key): conversion_fns[key](val) + for key, val in keyvals + } + return {} + + +def parse_db_url(sql_uri: str) -> dict: + """Parse the `sql_uri` variable into a dict of connection parameters.""" + parsed_db = urlparse(sql_uri) + return { + "host": parsed_db.hostname, + "port": parsed_db.port or 3306, + "user": parsed_db.username, + "password": parsed_db.password, + "database": parsed_db.path.strip("/").strip(), + **__parse_db_opts__(parsed_db.query) + } @contextlib.contextmanager |
