diff options
Diffstat (limited to 'gn_libs')
| -rw-r--r-- | gn_libs/debug.py | 29 | ||||
| -rw-r--r-- | gn_libs/jobs/__init__.py | 11 | ||||
| -rw-r--r-- | gn_libs/jobs/jobs.py | 213 | ||||
| -rw-r--r-- | gn_libs/jobs/launcher.py | 108 | ||||
| -rw-r--r-- | gn_libs/jobs/migrations.py | 69 | ||||
| -rw-r--r-- | gn_libs/monadic_requests.py | 74 | ||||
| -rw-r--r-- | gn_libs/mysqldb.py | 19 | ||||
| -rw-r--r-- | gn_libs/privileges.py | 166 | ||||
| -rw-r--r-- | gn_libs/protocols/__init__.py | 2 | ||||
| -rw-r--r-- | gn_libs/protocols/db.py | 35 | ||||
| -rw-r--r-- | gn_libs/sqlite3.py | 45 |
11 files changed, 764 insertions, 7 deletions
diff --git a/gn_libs/debug.py b/gn_libs/debug.py index 6b7173b..7ad10e0 100644 --- a/gn_libs/debug.py +++ b/gn_libs/debug.py @@ -1,6 +1,7 @@ """Debug utilities""" import logging -from flask import current_app +import importlib.util +from typing import Callable __this_module_name__ = __name__ @@ -8,10 +9,16 @@ __this_module_name__ = __name__ # pylint: disable=invalid-name def getLogger(name: str): """Return a logger""" - return ( - logging.getLogger(name) - if not bool(current_app) - else current_app.logger) + flask_spec = importlib.util.find_spec("flask") + if bool(flask_spec): + current_app = importlib.import_module("flask").current_app + return ( + logging.getLogger(name) + if not bool(current_app) + else current_app.logger) + + return logging.getLogger(name) + def __pk__(*args): """Format log entry""" @@ -20,3 +27,15 @@ def __pk__(*args): logger = getLogger(__this_module_name__) logger.debug("%s: %s", title_vals, value) return value + + +def make_peeker(logger: logging.Logger) -> Callable: + """Make a peeker function that's very much like __pk__ but that uses the + given logger.""" + def peeker(*args): + value = args[-1] + title_vals = " => ".join(args[0:-1]) + logger.debug("%s: %s", title_vals, value) + return value + + return peeker diff --git a/gn_libs/jobs/__init__.py b/gn_libs/jobs/__init__.py new file mode 100644 index 0000000..d6e4ce3 --- /dev/null +++ b/gn_libs/jobs/__init__.py @@ -0,0 +1,11 @@ +"""This package deals with launching and managing background/async jobs.""" +from .migrations import run_migrations +from .jobs import (job, + launch_job, + initialise_job, + push_to_stream, + update_metadata) + +def init_app(flask_app): + """Initialise the migrations for flask""" + run_migrations(flask_app.config["ASYNCHRONOUS_JOBS_SQLITE_DB"]) diff --git a/gn_libs/jobs/jobs.py b/gn_libs/jobs/jobs.py new file mode 100644 index 0000000..ec1c3a8 --- /dev/null +++ b/gn_libs/jobs/jobs.py @@ -0,0 +1,213 @@ +"""Handle asynchronous/background jobs. Job data is stored in SQLite database(s).""" +import os +import sys +import uuid +import shlex +import logging +import subprocess +from pathlib import Path +from functools import partial +from typing import Union, Optional +from datetime import datetime, timezone, timedelta + +from gn_libs.sqlite3 import DbCursor, DbConnection, cursor as _cursor + +_logger_ = logging.getLogger(__name__) +_DEFAULT_EXPIRY_SECONDS_ = 2 * 24 * 60 * 60 # 2 days, in seconds + + +class JobNotFound(Exception): + """Raised if we try to retrieve a non-existent job.""" + + +def __job_metadata__(cursor: DbCursor, job_id: Union[str, uuid.UUID]) -> dict: + """Fetch extra job metadata.""" + cursor.execute("SELECT * FROM jobs_metadata WHERE job_id=?", (str(job_id),)) + return { + row["metadata_key"]: row["metadata_value"] + for row in cursor.fetchall() + } + + +def job_stdstream_outputs(conn, job_id, streamname: str): + """Fetch the standard-error output for the job.""" + with _cursor(conn) as cursor: + cursor.execute( + "SELECT * FROM jobs_standard_outputs " + "WHERE job_id=? AND output_stream=?", + (str(job_id), streamname)) + return dict(cursor.fetchone() or {}).get("value") + + +job_stderr = partial(job_stdstream_outputs, streamname="stderr") +job_stdout = partial(job_stdstream_outputs, streamname="stdout") + + +def job(conn: DbConnection, job_id: Union[str, uuid.UUID], fulldetails: bool = False) -> dict: + """Fetch the job details for a job with a particular ID""" + with _cursor(conn) as cursor: + cursor.execute("SELECT * FROM jobs WHERE job_id=?", (str(job_id),)) + _job = dict(cursor.fetchone() or {}) + if not bool(_job): + raise JobNotFound(f"Could not find job with ID {job_id}") + + _job["metadata"] = __job_metadata__(cursor, job_id) + + if fulldetails: + _job["stderr"] = job_stderr(conn, job_id) + _job["stdout"] = job_stdout(conn, job_id) + + return _job + + +def __save_job__(conn: DbConnection, the_job: dict, expiry_seconds: int) -> dict: + """Save the job to database.""" + + with _cursor(conn) as cursor: + job_id = str(the_job["job_id"]) + expires = ((the_job["created"] + timedelta(seconds=expiry_seconds)) + if expiry_seconds > 0 else None) + cursor.execute("INSERT INTO jobs(job_id, created, expires, command) " + "VALUES(:job_id, :created, :expires, :command)", + { + "job_id": job_id, + "created": the_job["created"].isoformat(), + "expires": (expires and expires.isoformat()), + "command": the_job["command"] + }) + metadata = tuple({"job_id": job_id, "key": key, "value": value} + for key,value in the_job["metadata"].items()) + if len(metadata) > 0: + cursor.executemany( + "INSERT INTO jobs_metadata(job_id, metadata_key, metadata_value) " + "VALUES (:job_id, :key, :value)", + metadata) + + return the_job + + +def initialise_job(# pylint: disable=[too-many-arguments, too-many-positional-arguments] + conn: DbConnection, + job_id: uuid.UUID, + command: list, + job_type: str, + extra_meta: Optional[dict] = None, + expiry_seconds: int = _DEFAULT_EXPIRY_SECONDS_ +) -> dict: + """Initialise the job and put the details in a SQLite3 database.""" + if extra_meta is None: + extra_meta = {} + + _job = { + "job_id": job_id, + "command": shlex.join(command), + "created": datetime.now(timezone.utc), + "metadata": { + "status": "pending", + "percent": 0, + "job-type": job_type, + **extra_meta + } + } + return __save_job__(conn, _job, expiry_seconds) + + +def output_file(jobid: uuid.UUID, outdir: Path, stream: str) -> Path: + """Compute the path for the file where the launcher's `stream` output goes""" + assert stream in ("stdout", "stderr"), f"Invalid stream '{stream}'" + return outdir.joinpath(f"launcher_job_{jobid}.{stream}") + + +stdout_filename = partial(output_file, stream="stdout") +stderr_filename = partial(output_file, stream="stderr") + + +def build_environment(extras: Optional[dict[str, str]] = None) -> dict[str, str]: + """Setup the runtime environment variables for the background script.""" + if extras is None: + extras = {} + + return { + **dict(os.environ), + "PYTHONPATH": ":".join(sys.path), + **extras + } + + +def launch_job( + the_job: dict, + sqlite3_url: str, + error_dir: Path, + worker_manager: str = "gn_libs.jobs.launcher", + loglevel: str = "info" +) -> dict: + """Launch a job in the background""" + if not os.path.exists(error_dir): + os.mkdir(error_dir) + + job_id = str(the_job["job_id"]) + with (open(stderr_filename(jobid=the_job["job_id"], outdir=error_dir), + "w", + encoding="utf-8") as stderrfile, + open(stdout_filename(jobid=the_job["job_id"], outdir=error_dir), + "w", + encoding="utf-8") as stdoutfile): + subprocess.Popen( # pylint: disable=[consider-using-with] + [ + sys.executable, "-u", + "-m", worker_manager, + sqlite3_url, + job_id, + str(error_dir), + "--log-level", + loglevel + ], + stdout=stdoutfile, + stderr=stderrfile, + env=build_environment()) + + return the_job + + +def update_metadata(conn: DbConnection, job_id: Union[str, uuid.UUID], key: str, value: str): + """Update the value of a metadata item.""" + with _cursor(conn) as cursor: + cursor.execute( + "INSERT INTO jobs_metadata(job_id, metadata_key, metadata_value) " + "VALUES (:job_id, :key, :value) " + "ON CONFLICT (job_id, metadata_key) DO UPDATE " + "SET metadata_value=:value " + "WHERE job_id=:job_id AND metadata_key=:key", + { + "job_id": str(job_id), + "key": key, + "value": value + }) + + +def push_to_stream( + conn: DbConnection, + job_id: Union[str, uuid.UUID], + stream_name: str, content: str +): + """Initialise, and keep adding content to the stream from the provided content.""" + with _cursor(conn) as cursor: + cursor.execute("SELECT * FROM jobs_standard_outputs " + "WHERE job_id=:job_id AND output_stream=:stream", + { + "job_id": str(job_id), + "stream": stream_name + }) + result = cursor.fetchone() + new_content = ((bool(result) and result["value"]) or "") + content + cursor.execute( + "INSERT INTO jobs_standard_outputs(job_id, output_stream, value) " + "VALUES(:job_id, :stream, :content) " + "ON CONFLICT (job_id, output_stream) DO UPDATE " + "SET value=:content " + "WHERE job_id=:job_id AND output_stream=:stream", + { + "job_id": str(job_id), + "stream": stream_name, + "content": new_content + }) diff --git a/gn_libs/jobs/launcher.py b/gn_libs/jobs/launcher.py new file mode 100644 index 0000000..d565f9e --- /dev/null +++ b/gn_libs/jobs/launcher.py @@ -0,0 +1,108 @@ +"""Default launcher/manager script for background jobs.""" +import os +import sys +import time +import shlex +import logging +import argparse +import traceback +import subprocess +from uuid import UUID +from pathlib import Path + +from gn_libs import jobs, sqlite3 + +logger = logging.getLogger(__name__) + + +def run_job(conn, job, outputs_directory: Path): + """Run the job.""" + logger.info("Setting up the job.") + job_id = job["job_id"] + stdout_file = outputs_directory.joinpath(f"{job_id}.stdout") + stderr_file = outputs_directory.joinpath(f"{job_id}.stderr") + jobs.update_metadata(conn, job_id, "stdout-file", str(stdout_file)) + jobs.update_metadata(conn, job_id, "stderr-file", str(stderr_file)) + try: + logger.info("Launching the job in a separate process.") + with (stdout_file.open(mode="w") as outfile, + stderr_file.open(mode="w") as errfile, + stdout_file.open(mode="r") as stdout_in, + stderr_file.open(mode="r") as stderr_in, + subprocess.Popen( + shlex.split(job["command"]), + encoding="utf-8", + stdout=outfile, + stderr=errfile) as process): + while process.poll() is None: + jobs.update_metadata(conn, job_id, "status", "running") + jobs.push_to_stream(conn, job_id, "stdout", stdout_in.read()) + jobs.push_to_stream(conn, job_id, "stderr", stderr_in.read()) + time.sleep(1) + + # Fetch any remaining content. + jobs.push_to_stream(conn, job_id, "stdout", stdout_in.read()) + jobs.push_to_stream(conn, job_id, "stderr", stderr_in.read()) + logger.info("Job completed. Cleaning up.") + + os.remove(stdout_file) + os.remove(stderr_file) + exit_status = process.poll() + if exit_status == 0: + jobs.update_metadata(conn, job_id, "status", "completed") + else: + jobs.update_metadata(conn, job_id, "status", "error") + + logger.info("exiting job manager/launcher") + return exit_status + except Exception as _exc:# pylint: disable=[broad-exception-caught] + logger.error("An exception was raised when attempting to run the job", + exc_info=True) + jobs.update_metadata(conn, job_id, "status", "error") + jobs.push_to_stream(conn, job_id, "stderr", traceback.format_exc()) + return 4 + + +def parse_args(): + """Define and parse CLI args.""" + parser = argparse.ArgumentParser( + prog="GN Jobs Launcher", + description = ( + "Generic launcher and manager of jobs defined with gn-libs")) + parser.add_argument( + "jobs_db_uri", + help="The URI to the SQLite3 database holding the jobs' details") + parser.add_argument( + "job_id", help="The id of the job being processed", type=UUID) + parser.add_argument("outputs_directory", + help="Directory where output files will be created", + type=Path) + parser.add_argument( + "--log-level", + type=str, + help="Determines what is logged out.", + choices=("debug", "info", "warning", "error", "critical"), + default="info") + return parser.parse_args() + + +def main(): + """Entry-point to this program.""" + args = parse_args() + logger.setLevel(args.log_level.upper()) + args.outputs_directory.mkdir(parents=True, exist_ok=True) + with sqlite3.connection(args.jobs_db_uri) as conn: + job = jobs.job(conn, args.job_id) + if job: + return run_job(conn, job, args.outputs_directory) + + jobs.update_metadata(conn, args.job_id, "status", "error") + jobs.push_to_stream(conn, args.job_id, "stderr", "Job not found!") + return 2 + + return 3 + + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/gn_libs/jobs/migrations.py b/gn_libs/jobs/migrations.py new file mode 100644 index 0000000..0c9825b --- /dev/null +++ b/gn_libs/jobs/migrations.py @@ -0,0 +1,69 @@ +"""Database migrations for the jobs to ensure consistency downstream.""" +from gn_libs.protocols import DbCursor +from gn_libs.sqlite3 import connection, cursor as acquire_cursor + +def __create_table_jobs__(cursor: DbCursor): + """Create the jobs table""" + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS jobs( + job_id TEXT PRIMARY KEY NOT NULL, + created TEXT NOT NULL, + expires TEXT, + command TEXT NOT NULL + ) WITHOUT ROWID + """) + + +def __create_table_jobs_metadata__(cursor: DbCursor): + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS jobs_metadata( + job_id TEXT, + metadata_key TEXT NOT NULL, + metadata_value TEXT NOT NULL, + FOREIGN KEY(job_id) REFERENCES jobs(job_id) + ON UPDATE CASCADE ON DELETE RESTRICT, + PRIMARY KEY(job_id, metadata_key) + ) WITHOUT ROWID + """) + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_tbl_jobs_metadata_cols_job_id + ON jobs_metadata(job_id) + """) + + +def __create_table_jobs_output_streams__(cursor: DbCursor): + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS jobs_standard_outputs( + job_id TEXT NOT NULL, + output_stream TEXT, + value TEXT, + FOREIGN KEY(job_id) REFERENCES jobs(job_id) + ON UPDATE CASCADE ON DELETE RESTRICT, + CHECK (output_stream IN ('stdout', 'stderr')), + PRIMARY KEY(job_id, output_stream) + ) WITHOUT ROWID + """) + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_tbl_jobs_standard_outputs_cols_job_id + ON jobs_standard_outputs(job_id) + """) + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS + idx_tbl_jobs_standard_outputs_cols_job_id_output_stream + ON jobs_standard_outputs(job_id, output_stream) + """) + + +def run_migrations(sqlite_url: str): + """Run the migrations to setup the background jobs database.""" + with (connection(sqlite_url) as conn, + acquire_cursor(conn) as curr): + __create_table_jobs__(curr) + __create_table_jobs_metadata__(curr) + __create_table_jobs_output_streams__(curr) diff --git a/gn_libs/monadic_requests.py b/gn_libs/monadic_requests.py new file mode 100644 index 0000000..a09acc5 --- /dev/null +++ b/gn_libs/monadic_requests.py @@ -0,0 +1,74 @@ +"""Wrap requests functions with monads.""" +import logging + +import requests +from requests.models import Response +from pymonad.either import Left, Right, Either + +logger = logging.getLogger(__name__) + +# HTML Status codes indicating a successful request. +SUCCESS_CODES = (200, 201, 202, 203, 204, 205, 206, 207, 208, 226) + + +def get(url, params=None, **kwargs) -> Either: + """ + A wrapper around `requests.get` function. + + Takes the same arguments as `requests.get`. + + :rtype: pymonad.either.Either + """ + timeout = kwargs.get("timeout") + kwargs = {key: val for key,val in kwargs.items() if key != "timeout"} + if timeout is None: + timeout = (9.13, 20) + + try: + resp = requests.get(url, params=params, timeout=timeout, **kwargs) + if resp.status_code in SUCCESS_CODES: + return Right(resp.json()) + return Left(resp) + except requests.exceptions.RequestException as exc: + return Left(exc) + + +def post(url, data=None, json=None, **kwargs) -> Either: + """ + A wrapper around `requests.post` function. + + Takes the same arguments as `requests.post`. + + :rtype: pymonad.either.Either + """ + timeout = kwargs.get("timeout") + kwargs = {key: val for key,val in kwargs.items() if key != "timeout"} + if timeout is None: + timeout = (9.13, 20) + + try: + resp = requests.post(url, data=data, json=json, timeout=timeout, **kwargs) + if resp.status_code in SUCCESS_CODES: + return Right(resp.json()) + return Left(resp) + except requests.exceptions.RequestException as exc: + return Left(exc) + + +def make_either_error_handler(msg): + """Make generic error handler for pymonads Either objects.""" + def __fail__(error): + if issubclass(type(error), Exception): + logger.debug("\n\n%s (Exception)\n\n", msg, exc_info=True) + raise error + if issubclass(type(error), Response): + try: + _data = error.json() + except Exception as _exc: + raise Exception(error.content) from _exc# pylint: disable=[broad-exception-raised] + raise Exception(_data)# pylint: disable=[broad-exception-raised] + + logger.debug("\n\n%s\n\n", msg) + raise Exception(error)# pylint: disable=[broad-exception-raised] + + return __fail__ diff --git a/gn_libs/mysqldb.py b/gn_libs/mysqldb.py index 64a649d..3f6390e 100644 --- a/gn_libs/mysqldb.py +++ b/gn_libs/mysqldb.py @@ -9,7 +9,7 @@ import MySQLdb as mdb from MySQLdb.cursors import Cursor -_logger = logging.getLogger(__file__) +_logger = logging.getLogger(__name__) class InvalidOptionValue(Exception): """Raised whenever a parsed value is invalid for the specific option.""" @@ -46,6 +46,12 @@ def __parse_ssl_mode_options__(val: str) -> str: def __parse_ssl_options__(val: str) -> dict: + if val.strip() == "" or val.strip().lower() == "false": + return False + + if val.strip().lower() == "true": + return True + allowed_keys = ("key", "cert", "ca", "capath", "cipher") opts = { key.strip(): val.strip() for key,val in @@ -61,6 +67,7 @@ def __parse_db_opts__(opts: str) -> dict: This assumes use of python-mysqlclient library.""" allowed_opts = ( + # See: https://mysqlclient.readthedocs.io/user_guide.html#functions-and-attributes "unix_socket", "connect_timeout", "compress", "named_pipe", "init_command", "read_default_file", "read_default_group", "cursorclass", "use_unicode", "charset", "collation", "auth_plugin", @@ -124,13 +131,21 @@ class Connection(Protocol): @contextlib.contextmanager def database_connection(sql_uri: str, logger: logging.Logger = _logger) -> Iterator[Connection]: """Connect to MySQL database.""" - connection = mdb.connect(**parse_db_url(sql_uri)) + _conn_opts = parse_db_url(sql_uri) + _logger.debug("Connecting to database with the following options: %s", + _conn_opts) + connection = mdb.connect(**_conn_opts) try: yield connection connection.commit() except mdb.Error as _mbde: logger.error("DB error encountered", exc_info=True) connection.rollback() + raise _mbde from None + except Exception as _exc: + connection.rollback() + logger.error("General exception encountered", exc_info=True) + raise _exc from None finally: connection.close() diff --git a/gn_libs/privileges.py b/gn_libs/privileges.py new file mode 100644 index 0000000..32c943d --- /dev/null +++ b/gn_libs/privileges.py @@ -0,0 +1,166 @@ +"""Utilities for handling privileges.""" +import logging +from functools import reduce +from typing import Union, Sequence, Iterator, TypeAlias, TypedDict + +logger = logging.getLogger(__name__) + +Operator: TypeAlias = str # Valid operators: "AND", "OR" +Privilege: TypeAlias = str +PrivilegesList: TypeAlias = Sequence[Privilege] +ParseTree = tuple[Operator, + # Leaves (`PrivilegesList` objects) on the left, + # trees (`ParseTree` objects) on the right + Union[PrivilegesList, tuple[PrivilegesList, 'ParseTree']]] + + +class SpecificationValueError(ValueError): + """Raised when there is an error in the specification string.""" + + +_OPERATORS_ = ("OR", "AND") +_EMPTY_SPEC_ERROR_ = SpecificationValueError( + "Empty specification. I do not know what to do.") + + +def __add_leaves__( + index: int, + tree: tuple[Operator], + leaves: dict +) -> Union[tuple[Operator], Union[ParseTree, tuple]]: + """Add leaves to the tree.""" + if leaves.get(index): + return tree + (leaves[index],) + return tree + (tuple()) + + +class ParsingState(TypedDict): + """Class to create a state object. Mostly used to silence MyPy""" + tokens: list[str] + trees: list[tuple[int, int, str, int, int]]#(name, parent, operator, start, end) + open_parens: int + current_tree: int + leaves: dict[int, tuple[str, ...]]#[parent-tree, [index, index, ...]] + + +def __build_tree__(tree_state: ParsingState) -> ParseTree: + """Given computed state, build the actual tree.""" + _built = [] + for idx, tree in enumerate(tree_state["trees"]): + _built.append(__add_leaves__(idx, (tree[2],), tree_state["leaves"])) + + logger.debug("Number of built trees: %s, %s", len(_built), _built) + _num_trees = len(_built) + for idx in range(0, _num_trees): + _last_tree = _built.pop() + logger.debug("LAST TREE: %s, %s", _last_tree, len(_last_tree)) + if len(_last_tree) <= 1:# Has no leaves or subtrees + _last_tree = None# type: ignore[assignment] + continue# more evil + _name = tree_state["trees"][_num_trees - 1 - idx][0] + _parent = tree_state["trees"][ + tree_state["trees"][_num_trees - 1 - idx][1]] + _op = tree_state["trees"][_num_trees - 1 - idx][2] + logger.debug("TREE => name: %s, operation: %s, parent: %s", + _name, _op, _parent) + if _name != _parent[0]:# not root tree + if _op == _parent[2]: + _built[_parent[0]] = ( + _built[_parent[0]][0],# Operator + _built[_parent[0]][1] + _last_tree[1]# merge leaves + ) + _last_tree[2:]#Add any trees left over + else: + _built[_parent[0]] += (_last_tree,) + + if _last_tree is None: + raise _EMPTY_SPEC_ERROR_ + return _last_tree + + +def __parse_tree__(tokens: Iterator[str]) -> ParseTree: + """Parse the tokens into a tree.""" + _state = ParsingState( + tokens=[], trees=[], open_parens=0, current_tree=0, leaves={}) + for _idx, _token in enumerate(tokens): + _state["tokens"].append(_token) + + if _idx==0: + if _token[1:].upper() not in _OPERATORS_: + raise SpecificationValueError(f"Invalid operator: {_token[1:]}") + _state["open_parens"] += 1 + _state["trees"].append((0, 0, _token[1:].upper(), _idx, -1)) + _state["current_tree"] = 0 + continue# this is bad! + + if _token == ")":# end a tree + logger.debug("ENDING A TREE: %s", _state) + _state["open_parens"] -= 1 + _state["trees"][_state["current_tree"]] = ( + _state["trees"][_state["current_tree"]][0:-1] + (_idx,)) + # We go back to the parent below. + _state["current_tree"] = _state["trees"][_state["current_tree"]][1] + continue# still really bad! + + if _token[1:].upper() in _OPERATORS_:# new child tree + _state["open_parens"] += 1 + _state["trees"].append((len(_state["trees"]), + _state["current_tree"], + _token[1:].upper(), + _idx, + -1)) + _state["current_tree"] = len(_state["trees"]) - 1 + continue# more evil still + + logger.debug("state: %s", _state) + # leaves + _state["leaves"][_state["current_tree"]] = _state["leaves"].get( + _state["current_tree"], tuple()) + (_token,) + + # Build parse-tree from state + if _state["open_parens"] != 0: + raise SpecificationValueError("Unbalanced parentheses.") + return __build_tree__(_state) + + +def __tokenise__(spec: str) -> Iterator[str]: + """Clean up and tokenise the string.""" + return (token.strip() + for token in spec.replace( + "(", " (" + ).replace( + ")", " ) " + ).replace( + "( ", "(" + ).split()) + + +def parse(spec: str) -> ParseTree: + """Parse a string specification for privileges and return a tree of data + objects of the form (<operator> (<check>))""" + if spec.strip() == "": + raise _EMPTY_SPEC_ERROR_ + + return __parse_tree__(__tokenise__(spec)) + + +def __make_checker__(check_fn): + def __checker__(privileges, *checks): + def __check__(acc, curr): + if curr[0] in _OPERATORS_: + return acc + (_OPERATOR_FUNCTION_[curr[0]]( + privileges, *curr[1:]),) + return acc + (check_fn((priv in privileges) for priv in curr),) + results = reduce(__check__, checks, tuple()) + return len(results) > 0 and check_fn(results) + + return __checker__ + + +_OPERATOR_FUNCTION_ = { + "OR": __make_checker__(any), + "AND": __make_checker__(all) +} +def check(spec: str, privileges: tuple[str, ...]) -> bool: + """Check that the sequence of `privileges` satisfies `spec`.""" + _spec = parse(spec) + return _OPERATOR_FUNCTION_[_spec[0]](privileges, *_spec[1:]) diff --git a/gn_libs/protocols/__init__.py b/gn_libs/protocols/__init__.py new file mode 100644 index 0000000..83a31a8 --- /dev/null +++ b/gn_libs/protocols/__init__.py @@ -0,0 +1,2 @@ +"""This package is a collection of major Protocols/Interfaces definitions.""" +from .db import DbCursor, DbConnection diff --git a/gn_libs/protocols/db.py b/gn_libs/protocols/db.py new file mode 100644 index 0000000..b365f8b --- /dev/null +++ b/gn_libs/protocols/db.py @@ -0,0 +1,35 @@ +"""Generic database protocols.""" +from typing import Any, Protocol + + +class DbCursor(Protocol): + """Type annotation for a generic database cursor object.""" + def execute(self, *args, **kwargs) -> Any: + """Execute a single query""" + + def executemany(self, *args, **kwargs) -> Any: + """ + Execute parameterized SQL statement sql against all parameter sequences + or mappings found in the sequence parameters. + """ + + def fetchone(self, *args, **kwargs): + """Fetch single result if present, or `None`.""" + + def fetchmany(self, *args, **kwargs): + """Fetch many results if present or `None`.""" + + def fetchall(self, *args, **kwargs): + """Fetch all results if present or `None`.""" + + +class DbConnection(Protocol): + """Type annotation for a generic database connection object.""" + def cursor(self) -> Any: + """A cursor object""" + + def commit(self) -> Any: + """Commit the transaction.""" + + def rollback(self) -> Any: + """Rollback the transaction.""" diff --git a/gn_libs/sqlite3.py b/gn_libs/sqlite3.py new file mode 100644 index 0000000..78e1c41 --- /dev/null +++ b/gn_libs/sqlite3.py @@ -0,0 +1,45 @@ +"""This module deals with connections to a(n) SQLite3 database.""" +import logging +import traceback +import contextlib +from typing import Callable, Iterator + +import sqlite3 + +from .protocols import DbCursor, DbConnection + +_logger_ = logging.getLogger(__name__) + + +@contextlib.contextmanager +def connection(db_path: str, row_factory: Callable = sqlite3.Row) -> Iterator[DbConnection]: + """Create the connection to the auth database.""" + logging.debug("SQLite3 DB Path: '%s'.", db_path) + conn = sqlite3.connect(db_path) + conn.row_factory = row_factory + conn.set_trace_callback(logging.debug) + conn.execute("PRAGMA foreign_keys = ON") + try: + yield conn + except sqlite3.Error as exc: + conn.rollback() + _logger_.debug(traceback.format_exc()) + raise exc + finally: + conn.commit() + conn.close() + + +@contextlib.contextmanager +def cursor(conn: DbConnection) -> Iterator[DbCursor]: + """Get a cursor from the given connection to the auth database.""" + cur = conn.cursor() + try: + yield cur + conn.commit() + except sqlite3.Error as exc: + conn.rollback() + _logger_.debug(traceback.format_exc()) + raise exc + finally: + cur.close() |
