Diffstat (limited to 'gn_libs')

 gn_libs/debug.py              |  13
 gn_libs/http_logging.py       |  56
 gn_libs/jobs/__init__.py      |   8
 gn_libs/jobs/jobs.py          | 146
 gn_libs/jobs/launcher.py      |  41
 gn_libs/jobs/migrations.py    |  29
 gn_libs/logging.py            |  41
 gn_libs/monadic_requests.py   |  20
 gn_libs/mysqldb.py            |  23
 gn_libs/privileges.py         | 166
 gn_libs/protocols/__init__.py |   1
 gn_libs/sqlite3.py            |   3

 12 files changed, 508 insertions(+), 39 deletions(-)
diff --git a/gn_libs/debug.py b/gn_libs/debug.py
index c1b896e..7ad10e0 100644
--- a/gn_libs/debug.py
+++ b/gn_libs/debug.py
@@ -1,6 +1,7 @@
 """Debug utilities"""
 import logging
 import importlib.util
+from typing import Callable
 
 __this_module_name__ = __name__
 
@@ -26,3 +27,15 @@ def __pk__(*args):
     logger = getLogger(__this_module_name__)
     logger.debug("%s: %s", title_vals, value)
     return value
+
+
+def make_peeker(logger: logging.Logger) -> Callable:
+    """Make a peeker function that's very much like __pk__ but that uses the
+    given logger."""
+    def peeker(*args):
+        value = args[-1]
+        title_vals = " => ".join(args[0:-1])
+        logger.debug("%s: %s", title_vals, value)
+        return value
+
+    return peeker
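The new `make_peeker` closes over a caller-supplied logger, so peeked values are logged under the caller's namespace instead of `gn_libs.debug`. A minimal usage sketch (the logger name is illustrative):

    import logging

    from gn_libs.debug import make_peeker

    logging.basicConfig(level=logging.DEBUG)
    peek = make_peeker(logging.getLogger("my_app"))  # "my_app" is illustrative

    # Logs "intermediate => total: 42" at DEBUG level and returns 42, so it
    # can wrap an expression without changing the surrounding code's result.
    total = peek("intermediate", "total", 42)
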
diff --git a/gn_libs/http_logging.py b/gn_libs/http_logging.py
new file mode 100644
index 0000000..c65e0a4
--- /dev/null
+++ b/gn_libs/http_logging.py
@@ -0,0 +1,56 @@
+"""Provide a way to emit logs to an HTTP endpoint"""
+import logging
+import json
+import traceback
+import urllib.request
+from datetime import datetime
+
+
+class SilentHTTPHandler(logging.Handler):
+    """A logging handler that emits logs to an HTTP endpoint silently.
+
+    This handler converts log records to JSON and sends them via POST
+    to a specified HTTP endpoint. Failures are suppressed to avoid
+    interfering with the main application.
+    """
+    def __init__(self, endpoint, timeout=0.1):
+        super().__init__()
+        self.endpoint = endpoint
+        self.timeout = timeout
+
+    def emit(self, record):
+        try:
+            payload = {
+                "timestamp": datetime.utcfromtimestamp(record.created).isoformat(),
+                "level": record.levelname.lower(),
+                "logger": record.name,
+                "message": record.getMessage(),
+            }
+            for attr in ("remote_addr", "user_agent", "extra"):
+                if hasattr(record, attr):
+                    payload.update({attr: getattr(record, attr)})
+
+            if record.exc_info:
+                payload["exception"] = "".join(
+                    traceback.format_exception(*record.exc_info)
+                )
+
+            # fire-and-forget
+            self._send(payload)
+
+        except Exception:# pylint: disable=[broad-exception-caught]
+            # absolute silence
+            pass
+
+    def _send(self, payload):
+        try:
+            req = urllib.request.Request(
+                url=self.endpoint,
+                data=json.dumps(payload).encode("utf-8"),
+                headers={"Content-Type": "application/json"},
+                method="POST",
+            )
+            with urllib.request.urlopen(req, timeout=self.timeout) as resp:
+                resp.read()  # ignore body
+        except Exception:# pylint: disable=[broad-exception-caught]
+            pass
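`SilentHTTPHandler` plugs into the standard logging machinery like any other handler; both serialisation and delivery failures are swallowed, so a dead log endpoint cannot take the application down with it. A sketch of attaching it, with a hypothetical endpoint URL:

    import logging

    from gn_libs.http_logging import SilentHTTPHandler

    logger = logging.getLogger("my_app")  # illustrative name
    logger.addHandler(
        SilentHTTPHandler("https://logs.example.org/ingest", timeout=0.5))

    try:
        1 / 0
    except ZeroDivisionError:
        # The record, including the formatted traceback, is POSTed as JSON.
        logger.exception("Division failed")
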
diff --git a/gn_libs/jobs/__init__.py b/gn_libs/jobs/__init__.py
index 6f400ef..7927f8d 100644
--- a/gn_libs/jobs/__init__.py
+++ b/gn_libs/jobs/__init__.py
@@ -1,9 +1,15 @@
+"""This package deals with launching and managing background/async jobs."""
 from .migrations import run_migrations
 from .jobs import (job,
+                   kill_job,
                    launch_job,
+                   delete_job,
+                   delete_jobs,
                    initialise_job,
                    push_to_stream,
-                   update_metadata)
+                   update_metadata,
+                   jobs_by_external_id,
+                   delete_expired_jobs)
 
 def init_app(flask_app):
     """Initialise the migrations for flask"""
diff --git a/gn_libs/jobs/jobs.py b/gn_libs/jobs/jobs.py
index 1f66772..bccddd5 100644
--- a/gn_libs/jobs/jobs.py
+++ b/gn_libs/jobs/jobs.py
@@ -6,7 +6,6 @@ import shlex
 import logging
 import subprocess
 from pathlib import Path
-from functools import reduce
 from functools import partial
 from typing import Union, Optional
 from datetime import datetime, timezone, timedelta
@@ -41,14 +40,14 @@ def job_stdstream_outputs(conn, job_id, streamname: str):
 
 
 job_stderr = partial(job_stdstream_outputs, streamname="stderr")
-job_stdout = partial(job_stdstream_outputs, streamname="stderr")
+job_stdout = partial(job_stdstream_outputs, streamname="stdout")
 
 
 def job(conn: DbConnection, job_id: Union[str, uuid.UUID], fulldetails: bool = False) -> dict:
     """Fetch the job details for a job with a particular ID"""
     with _cursor(conn) as cursor:
         cursor.execute("SELECT * FROM jobs WHERE job_id=?", (str(job_id),))
-        _job = dict(cursor.fetchone())
+        _job = dict(cursor.fetchone() or {})
         if not bool(_job):
             raise JobNotFound(f"Could not find job with ID {job_id}")
 
@@ -61,7 +60,34 @@ def job(conn: DbConnection, job_id: Union[str, uuid.UUID], fulldetails: bool = F
     return _job
 
 
-def __save_job__(conn: DbConnection, the_job: dict, expiry_seconds: int) -> dict:
+def jobs_by_external_id(conn: DbConnection, external_id: Union[str, uuid.UUID]) -> tuple[dict, ...]:
+    """Fetch jobs by their external IDs."""
+    with _cursor(conn) as cursor:
+        cursor.execute(
+            "SELECT jeids.external_id, jobs.* FROM jobs_external_ids AS jeids "
+            "INNER JOIN jobs ON jeids.job_id=jobs.job_id "
+            "WHERE jeids.external_id=? "
+            "ORDER BY jobs.created DESC",
+            (str(external_id),))
+        _jobs = {row["job_id"]: {**dict(row), "metadata": {}} for row in cursor.fetchall()}
+        _jobs_ids = tuple(_job["job_id"] for _job in _jobs.values())
+
+        _paramstr = ", ".join(["?"] * len(_jobs_ids))
+        cursor.execute(
+            f"SELECT * FROM jobs_metadata WHERE job_id IN ({_paramstr})",
+            _jobs_ids)
+        for row in cursor.fetchall():
+            _jobs[row["job_id"]]["metadata"][row["metadata_key"]] = row["metadata_value"]
+
+        return tuple(_jobs.values())
+
+
+def __save_job__(
+        conn: DbConnection,
+        the_job: dict,
+        expiry_seconds: int,
+        external_id: str = ""
+) -> dict:
     """Save the job to database."""
 
     with _cursor(conn) as cursor:
@@ -76,6 +102,11 @@ def __save_job__(conn: DbConnection, the_job: dict, expiry_seconds: int) -> dict
             "expires": (expires and expires.isoformat()),
             "command": the_job["command"]
         })
+        if bool(external_id.strip()):
+            cursor.execute(
+                "INSERT INTO jobs_external_ids(job_id, external_id) "
+                "VALUES(:job_id, :external_id)",
+                {"job_id": job_id, "external_id": external_id.strip()})
         metadata = tuple({"job_id": job_id, "key": key, "value": value}
                          for key,value in the_job["metadata"].items())
         if len(metadata) > 0:
@@ -87,16 +118,28 @@ def __save_job__(conn: DbConnection, the_job: dict, expiry_seconds: int) -> dict
     return the_job
 
 
-def initialise_job(
+def initialise_job(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
         conn: DbConnection,
         job_id: uuid.UUID,
         command: list,
        job_type: str,
-        extra_meta: dict = {},
-        expiry_seconds: Optional[int] = _DEFAULT_EXPIRY_SECONDS_
+        extra_meta: Optional[dict] = None,
+        expiry_seconds: int = _DEFAULT_EXPIRY_SECONDS_,
+        external_id: Optional[Union[str, uuid.UUID]] = None
 ) -> dict:
     """Initialise the job and put the details in a SQLite3 database."""
-
+    if extra_meta is None:
+        extra_meta = {}
+
+    def __process_external_id__(_id: Optional[Union[str, uuid.UUID]]) -> str:
+        if isinstance(_id, uuid.UUID):
+            return str(_id)
+
+        if _id is not None and bool(_id.strip()):
+            return str(_id.strip())
+        return ""
+
+    _ext_id = __process_external_id__(external_id)
     _job = {
         "job_id": job_id,
         "command": shlex.join(command),
@@ -105,18 +148,28 @@ def initialise_job(
             "status": "pending",
             "percent": 0,
             "job-type": job_type,
-            **extra_meta
+            **extra_meta,
+            **({"external_id": _ext_id} if bool(_ext_id) else {})
         }
     }
-    return __save_job__(conn, _job, expiry_seconds)
+    return __save_job__(conn, _job, expiry_seconds, _ext_id)
+
+
+def output_file(jobid: uuid.UUID, outdir: Path, stream: str) -> Path:
+    """Compute the path for the file where the launcher's `stream` output goes"""
+    assert stream in ("stdout", "stderr"), f"Invalid stream '{stream}'"
+    return outdir.joinpath(f"launcher_job_{jobid}.{stream}")
+
+stdout_filename = partial(output_file, stream="stdout")
+stderr_filename = partial(output_file, stream="stderr")
 
 
-def error_filename(jobid, error_dir):
-    "Compute the path of the file where errors will be dumped."
-    return f"{error_dir}/job_{jobid}.error"
+def build_environment(extras: Optional[dict[str, str]] = None) -> dict[str, str]:
+    """Setup the runtime environment variables for the background script."""
+    if extras is None:
+        extras = {}
 
-def build_environment(extras: dict[str, str] = {}):
     return {
         **dict(os.environ),
         "PYTHONPATH": ":".join(sys.path),
@@ -128,24 +181,32 @@ def launch_job(
         the_job: dict,
         sqlite3_url: str,
         error_dir: Path,
-        worker_manager: str = "gn_libs.jobs.launcher"
+        worker_manager: str = "gn_libs.jobs.launcher",
+        loglevel: str = "info"
 ) -> dict:
     """Launch a job in the background"""
     if not os.path.exists(error_dir):
         os.mkdir(error_dir)
 
     job_id = str(the_job["job_id"])
-    with open(error_filename(job_id, error_dir),
-              "w",
-              encoding="utf-8") as errorfile:
+    with (open(stderr_filename(jobid=the_job["job_id"], outdir=error_dir),
+               "w",
+               encoding="utf-8") as stderrfile,
+          open(stdout_filename(jobid=the_job["job_id"], outdir=error_dir),
+               "w",
+               encoding="utf-8") as stdoutfile):
         subprocess.Popen( # pylint: disable=[consider-using-with]
             [
                 sys.executable,
                 "-u",
                 "-m",
                 worker_manager,
                 sqlite3_url,
                 job_id,
-                str(error_dir)],
-            stderr=errorfile,
+                str(error_dir),
+                "--log-level",
+                loglevel
+            ],
+            stdout=stdoutfile,
+            stderr=stderrfile,
             env=build_environment())
 
     return the_job
@@ -167,7 +228,11 @@ def update_metadata(conn: DbConnection, job_id: Union[str, uuid.UUID], key: str,
         })
 
 
-def push_to_stream(conn: DbConnection, job_id: Union[str, uuid.UUID], stream_name: str, content: str):
+def push_to_stream(
+        conn: DbConnection,
+        job_id: Union[str, uuid.UUID],
+        stream_name: str, content: str
+):
     """Initialise, and keep adding content to the stream from the provided content."""
     with _cursor(conn) as cursor:
         cursor.execute("SELECT * FROM jobs_standard_outputs "
@@ -189,3 +254,42 @@ def push_to_stream(conn: DbConnection, job_id: Union[str, uuid.UUID], stream_nam
                 "stream": stream_name,
                 "content": new_content
             })
+
+
+def delete_jobs(
+        conn: DbConnection, job_ids: tuple[Union[uuid.UUID, str], ...]) -> None:
+    """Delete the given jobs."""
+    with _cursor(conn) as cursor:
+        _paramstr = ", ".join(["?"] * len(job_ids))
+        _params = tuple(str(job_id) for job_id in job_ids)
+        cursor.execute(
+            f"DELETE FROM jobs_standard_outputs WHERE job_id IN ({_paramstr})",
+            _params)
+        cursor.execute(
+            f"DELETE FROM jobs_metadata WHERE job_id IN ({_paramstr})",
+            _params)
+        cursor.execute(
+            f"DELETE FROM jobs_external_ids WHERE job_id IN ({_paramstr})",
+            _params)
+        cursor.execute(f"DELETE FROM jobs WHERE job_id IN ({_paramstr})",
+                       _params)
+
+
+def delete_job(conn: DbConnection, job_id: Union[uuid.UUID, str]) -> None:
+    """Delete a specific job."""
+    return delete_jobs(conn, (job_id,))
+
+
+def delete_expired_jobs(conn: DbConnection) -> None:
+    """Delete all jobs that are expired."""
+    with _cursor(conn) as cursor:
+        cursor.execute(
+            "SELECT job_id FROM jobs WHERE datetime(expires) <= datetime()")
+        return delete_jobs(
+            conn, tuple(row["job_id"] for row in cursor.fetchall()))
+
+
+def kill_job(conn: DbConnection, job_id: Union[uuid.UUID, str]) -> None:
+    """Send a request to kill the job."""
+    return update_metadata(
+        conn, job_id, "hangup_request", datetime.now(timezone.utc).isoformat())
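Taken together, `initialise_job` records the job (optionally linked to an external ID) and `launch_job` spawns the worker-manager process that actually runs it. A sketch of the calling side, with illustrative paths and IDs:

    import uuid
    from pathlib import Path

    from gn_libs import jobs, sqlite3

    JOBS_DB = "/tmp/jobs.db"  # illustrative

    with sqlite3.connection(JOBS_DB) as conn:
        _job = jobs.initialise_job(
            conn,
            uuid.uuid4(),
            ["python3", "-c", "print('hello')"],
            job_type="example",
            external_id="user-1234")  # e.g. the user that triggered the job
        jobs.launch_job(_job, JOBS_DB, Path("/tmp/job-outputs"),
                        loglevel="debug")

Later, `jobs.jobs_by_external_id(conn, "user-1234")` recovers every job recorded against that external ID, with its metadata attached.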
diff --git a/gn_libs/jobs/launcher.py b/gn_libs/jobs/launcher.py
index 5edcd07..f915b81 100644
--- a/gn_libs/jobs/launcher.py
+++ b/gn_libs/jobs/launcher.py
@@ -1,7 +1,10 @@
+"""Default launcher/manager script for background jobs."""
 import os
 import sys
 import time
 import shlex
+import signal
+import logging
 import argparse
 import traceback
 import subprocess
@@ -10,15 +13,19 @@ from pathlib import Path
 
 from gn_libs import jobs, sqlite3
 
+logger = logging.getLogger(__name__)
+
 
 def run_job(conn, job, outputs_directory: Path):
     """Run the job."""
+    logger.info("Setting up the job.")
     job_id = job["job_id"]
     stdout_file = outputs_directory.joinpath(f"{job_id}.stdout")
     stderr_file = outputs_directory.joinpath(f"{job_id}.stderr")
     jobs.update_metadata(conn, job_id, "stdout-file", str(stdout_file))
     jobs.update_metadata(conn, job_id, "stderr-file", str(stderr_file))
     try:
+        logger.info("Launching the job in a separate process.")
         with (stdout_file.open(mode="w") as outfile,
               stderr_file.open(mode="w") as errfile,
               stdout_file.open(mode="r") as stdout_in,
@@ -28,8 +35,13 @@
                   encoding="utf-8",
                   stdout=outfile,
                   stderr=errfile) as process):
+            jobs.update_metadata(conn, job_id, "status", "running")
             while process.poll() is None:
-                jobs.update_metadata(conn, job_id, "status", "running")
+                _job = jobs.job(conn, job_id, True)
+                if bool(_job["metadata"].get("hangup_request")):
+                    process.send_signal(signal.SIGHUP)
+                    jobs.update_metadata(conn, job_id, "status", "stopped")
+                    break
                 jobs.push_to_stream(conn, job_id, "stdout", stdout_in.read())
                 jobs.push_to_stream(conn, job_id, "stderr", stderr_in.read())
                 time.sleep(1)
@@ -37,11 +49,23 @@
             # Fetch any remaining content.
             jobs.push_to_stream(conn, job_id, "stdout", stdout_in.read())
             jobs.push_to_stream(conn, job_id, "stderr", stderr_in.read())
+            logger.info("Job completed. Cleaning up.")
             os.remove(stdout_file)
             os.remove(stderr_file)
-            return process.poll()
-    except:
+            exit_status = process.poll()
+            if exit_status == 0:
+                jobs.update_metadata(conn, job_id, "status", "completed")
+            else:
+                _job = jobs.job(conn, job_id, True)
+                if _job["metadata"]["status"] != "stopped":
+                    jobs.update_metadata(conn, job_id, "status", "error")
+
+            logger.info("exiting job manager/launcher")
+            return exit_status
+    except Exception as _exc:# pylint: disable=[broad-exception-caught]
+        logger.error("An exception was raised when attempting to run the job",
+                     exc_info=True)
         jobs.update_metadata(conn, job_id, "status", "error")
         jobs.push_to_stream(conn, job_id, "stderr", traceback.format_exc())
         return 4
@@ -61,14 +85,21 @@ def parse_args():
     parser.add_argument("outputs_directory",
                         help="Directory where output files will be created",
                         type=Path)
+    parser.add_argument(
+        "--log-level",
+        type=str,
+        help="Determines what is logged out.",
+        choices=("debug", "info", "warning", "error", "critical"),
+        default="info")
     return parser.parse_args()
 
+
 def main():
     """Entry-point to this program."""
     args = parse_args()
+    logger.setLevel(args.log_level.upper())
     args.outputs_directory.mkdir(parents=True, exist_ok=True)
-    with (sqlite3.connection(args.jobs_db_uri) as conn,
-          sqlite3.cursor(conn) as cursor):
+    with sqlite3.connection(args.jobs_db_uri) as conn:
         job = jobs.job(conn, args.job_id)
         if job:
             return run_job(conn, job, args.outputs_directory)
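The manager polls roughly once a second; a `hangup_request` metadata entry makes it send SIGHUP to the worker and mark the job `stopped`, which is what `kill_job` relies on. From the application side, a sketch (path and job ID are illustrative):

    import uuid

    from gn_libs import jobs, sqlite3

    JOBS_DB = "/tmp/jobs.db"  # illustrative
    job_id = uuid.UUID("00000000-0000-0000-0000-000000000000")  # a running job

    with sqlite3.connection(JOBS_DB) as conn:
        # Writes the "hangup_request" metadata entry; the launcher's polling
        # loop picks it up within about a second and signals the worker.
        jobs.kill_job(conn, job_id)
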
diff --git a/gn_libs/jobs/migrations.py b/gn_libs/jobs/migrations.py
index 86fb958..2af16ae 100644
--- a/gn_libs/jobs/migrations.py
+++ b/gn_libs/jobs/migrations.py
@@ -1,6 +1,6 @@
 """Database migrations for the jobs to ensure consistency downstream."""
 from gn_libs.protocols import DbCursor
-from gn_libs.sqlite3 import cursor, connection
+from gn_libs.sqlite3 import connection, cursor as acquire_cursor
 
 def __create_table_jobs__(cursor: DbCursor):
     """Create the jobs table"""
@@ -60,9 +60,34 @@ def __create_table_jobs_output_streams__(cursor: DbCursor):
         """)
 
 
+def __create_table_jobs_external_ids__(cursor: DbCursor):
+    """Create the jobs_external_ids table.
+
+    The purpose of this table is to allow external systems to link background
+    jobs to specific users/events that triggered them. What the external IDs
+    are is irrelevant to the background jobs system here, and should not
+    affect how the system works."""
+    cursor.execute(
+        """
+        CREATE TABLE IF NOT EXISTS jobs_external_ids(
+            job_id TEXT PRIMARY KEY NOT NULL,
+            external_id TEXT NOT NULL,
+            FOREIGN KEY(job_id) REFERENCES jobs(job_id)
+            ON UPDATE CASCADE ON DELETE RESTRICT
+        ) WITHOUT ROWID
+        """)
+    cursor.execute(
+        """
+        CREATE INDEX IF NOT EXISTS idx_tbl_jobs_external_ids_cols_external_id
+        ON jobs_external_ids(external_id)
+        """)
+
+
 def run_migrations(sqlite_url: str):
+    """Run the migrations to setup the background jobs database."""
     with (connection(sqlite_url) as conn,
-          cursor(conn) as curr):
+          acquire_cursor(conn) as curr):
         __create_table_jobs__(curr)
         __create_table_jobs_metadata__(curr)
         __create_table_jobs_output_streams__(curr)
+        __create_table_jobs_external_ids__(curr)
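All the DDL uses `IF NOT EXISTS`, so running the migrations is idempotent and safe to do at every start-up. A sketch, with an illustrative database path:

    from gn_libs.jobs import run_migrations

    run_migrations("/tmp/jobs.db")  # creates any missing tables and indexes

For Flask applications, `init_app(flask_app)` in this package's `__init__.py` performs the same migrations initialisation.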
diff --git a/gn_libs/logging.py b/gn_libs/logging.py
new file mode 100644
index 0000000..952d30f
--- /dev/null
+++ b/gn_libs/logging.py
@@ -0,0 +1,41 @@
+"""Generalised setup for logging for Genenetwork systems."""
+import os
+import logging
+
+from flask import Flask
+
+logging.basicConfig(
+    encoding="utf-8",
+    format="%(asctime)s — %(filename)s:%(lineno)s — %(levelname)s: %(message)s")
+
+
+def __log_gunicorn__(app: Flask) -> Flask:
+    """Set up logging for the WSGI environment with GUnicorn"""
+    logger = logging.getLogger("gunicorn.error")
+    app.logger.handlers = logger.handlers
+    app.logger.setLevel(logger.level)
+    return app
+
+
+def __log_dev__(app: Flask) -> Flask:
+    """Set up logging for the development environment."""
+    root_logger = logging.getLogger()
+    root_logger.setLevel(
+        app.config.get("LOG_LEVEL", app.config.get("LOGLEVEL", "WARNING")))
+
+    return app
+
+
+def setup_logging(app: Flask) -> Flask:
+    """Set up logging for the application."""
+    software, *_version_and_comments = os.environ.get(
+        "SERVER_SOFTWARE", "").split('/')
+    return __log_gunicorn__(app) if bool(software) else __log_dev__(app)
+
+
+def setup_modules_logging(app_logger: logging.Logger, modules: tuple[str, ...]):
+    """Setup module-level loggers to the same log-level as the application."""
+    loglevel = logging.getLevelName(app_logger.getEffectiveLevel())
+    for module in modules:
+        _logger = logging.getLogger(module)
+        _logger.setLevel(loglevel)
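`setup_logging` keys off the `SERVER_SOFTWARE` environment variable that WSGI servers such as GUnicorn set: when it is present the app adopts the `gunicorn.error` handlers and level, otherwise the root logger is configured from the app config. A sketch (module names are illustrative):

    from flask import Flask

    from gn_libs.logging import setup_logging, setup_modules_logging

    app = Flask(__name__)
    app.config["LOG_LEVEL"] = "DEBUG"  # read by the development branch

    setup_logging(app)
    # Bring selected module loggers up to the app's effective level.
    setup_modules_logging(app.logger, ("gn_libs.jobs", "gn_libs.mysqldb"))
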
diff --git a/gn_libs/monadic_requests.py b/gn_libs/monadic_requests.py
index 0a3c282..a09acc5 100644
--- a/gn_libs/monadic_requests.py
+++ b/gn_libs/monadic_requests.py
@@ -19,8 +19,13 @@ def get(url, params=None, **kwargs) -> Either:
 
     :rtype: pymonad.either.Either
     """
+    timeout = kwargs.get("timeout")
+    kwargs = {key: val for key,val in kwargs.items() if key != "timeout"}
+    if timeout is None:
+        timeout = (9.13, 20)
+
     try:
-        resp = requests.get(url, params=params, **kwargs)
+        resp = requests.get(url, params=params, timeout=timeout, **kwargs)
         if resp.status_code in SUCCESS_CODES:
             return Right(resp.json())
         return Left(resp)
@@ -36,8 +41,13 @@ def post(url, data=None, json=None, **kwargs) -> Either:
 
     :rtype: pymonad.either.Either
     """
+    timeout = kwargs.get("timeout")
+    kwargs = {key: val for key,val in kwargs.items() if key != "timeout"}
+    if timeout is None:
+        timeout = (9.13, 20)
+
     try:
-        resp = requests.post(url, data=data, json=json, **kwargs)
+        resp = requests.post(url, data=data, json=json, timeout=timeout, **kwargs)
         if resp.status_code in SUCCESS_CODES:
             return Right(resp.json())
         return Left(resp)
@@ -55,10 +65,10 @@ def make_either_error_handler(msg):
             try:
                 _data = error.json()
             except Exception as _exc:
-                raise Exception(error.content) from _exc
-            raise Exception(_data)
+                raise Exception(error.content) from _exc# pylint: disable=[broad-exception-raised]
+            raise Exception(_data)# pylint: disable=[broad-exception-raised]
 
         logger.debug("\n\n%s\n\n", msg)
-        raise Exception(error)
+        raise Exception(error)# pylint: disable=[broad-exception-raised]
 
     return __fail__
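With the fallback timeout of `(9.13, 20)` seconds (connect, read), requests can no longer hang indefinitely, while callers keep the usual Either-based flow. A sketch against a hypothetical endpoint:

    from gn_libs.monadic_requests import get

    # Right(parsed_json) on a 2xx response, Left(response) otherwise.
    result = get("https://example.org/api/items", params={"page": 1})
    items = result.either(
        lambda response: [],      # fallback on failure
        lambda parsed: parsed)    # pass the parsed JSON through

Callers that prefer to fail loudly can pass `make_either_error_handler("fetch failed")` as the left branch instead of a fallback.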
diff --git a/gn_libs/mysqldb.py b/gn_libs/mysqldb.py
index 64a649d..0239f7e 100644
--- a/gn_libs/mysqldb.py
+++ b/gn_libs/mysqldb.py
@@ -3,13 +3,13 @@
 import logging
 import contextlib
 from logging import Logger
 from urllib.parse import urlparse
-from typing import Any, Iterator, Protocol, Callable
+from typing import Any, Union, Iterator, Protocol, Callable
 
 import MySQLdb as mdb
 from MySQLdb.cursors import Cursor
 
-_logger = logging.getLogger(__file__)
+_logger = logging.getLogger(__name__)
 
 class InvalidOptionValue(Exception):
     """Raised whenever a parsed value is invalid for the specific option."""
@@ -45,7 +45,13 @@ def __parse_ssl_mode_options__(val: str) -> str:
     return _val
 
 
-def __parse_ssl_options__(val: str) -> dict:
+def __parse_ssl_options__(val: str) -> Union[dict, bool]:
+    if val.strip() == "" or val.strip().lower() == "false":
+        return False
+
+    if val.strip().lower() == "true":
+        return True
+
     allowed_keys = ("key", "cert", "ca", "capath", "cipher")
     opts = {
         key.strip(): val.strip() for key,val in
@@ -61,6 +67,7 @@ def __parse_db_opts__(opts: str) -> dict:
 
     This assumes use of python-mysqlclient library."""
     allowed_opts = (
+        # See: https://mysqlclient.readthedocs.io/user_guide.html#functions-and-attributes
        "unix_socket", "connect_timeout", "compress", "named_pipe",
         "init_command", "read_default_file", "read_default_group",
         "cursorclass", "use_unicode", "charset", "collation", "auth_plugin",
@@ -124,13 +131,21 @@ class Connection(Protocol):
 
 @contextlib.contextmanager
 def database_connection(sql_uri: str, logger: logging.Logger = _logger) -> Iterator[Connection]:
     """Connect to MySQL database."""
-    connection = mdb.connect(**parse_db_url(sql_uri))
+    _conn_opts = parse_db_url(sql_uri)
+    logger.debug("Connecting to database with the following options: %s",
+                 _conn_opts)
+    connection = mdb.connect(**_conn_opts)
     try:
         yield connection
         connection.commit()
     except mdb.Error as _mbde:
         logger.error("DB error encountered", exc_info=True)
         connection.rollback()
+        raise _mbde from None
+    except Exception as _exc:
+        connection.rollback()
+        logger.error("General exception encountered", exc_info=True)
+        raise _exc from None
     finally:
         connection.close()
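`database_connection` now distinguishes MySQL errors from other exceptions, but both paths roll back and re-raise; commit happens only on a clean exit from the block. A usage sketch, with an illustrative URI:

    from gn_libs.mysqldb import database_connection

    # URI form: mysql://user:password@host[:port]/dbname?options
    with database_connection("mysql://gn:secret@localhost/gn_db") as conn:
        with conn.cursor() as cursor:
            cursor.execute("SELECT VERSION()")
            print(cursor.fetchone())
    # Commit fires after the block exits cleanly; any exception rolls back.
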
diff --git a/gn_libs/privileges.py b/gn_libs/privileges.py
new file mode 100644
index 0000000..32c943d
--- /dev/null
+++ b/gn_libs/privileges.py
@@ -0,0 +1,166 @@
+"""Utilities for handling privileges."""
+import logging
+from functools import reduce
+from typing import Union, Sequence, Iterator, TypeAlias, TypedDict
+
+logger = logging.getLogger(__name__)
+
+Operator: TypeAlias = str # Valid operators: "AND", "OR"
+Privilege: TypeAlias = str
+PrivilegesList: TypeAlias = Sequence[Privilege]
+ParseTree = tuple[Operator,
+                  # Leaves (`PrivilegesList` objects) on the left,
+                  # trees (`ParseTree` objects) on the right
+                  Union[PrivilegesList, tuple[PrivilegesList, 'ParseTree']]]
+
+
+class SpecificationValueError(ValueError):
+    """Raised when there is an error in the specification string."""
+
+
+_OPERATORS_ = ("OR", "AND")
+_EMPTY_SPEC_ERROR_ = SpecificationValueError(
+    "Empty specification. I do not know what to do.")
+
+
+def __add_leaves__(
+        index: int,
+        tree: tuple[Operator],
+        leaves: dict
+) -> Union[tuple[Operator], Union[ParseTree, tuple]]:
+    """Add leaves to the tree."""
+    if leaves.get(index):
+        return tree + (leaves[index],)
+    return tree + (tuple())
+
+
+class ParsingState(TypedDict):
+    """Class to create a state object. Mostly used to silence MyPy"""
+    tokens: list[str]
+    trees: list[tuple[int, int, str, int, int]]#(name, parent, operator, start, end)
+    open_parens: int
+    current_tree: int
+    leaves: dict[int, tuple[str, ...]]#[parent-tree, [index, index, ...]]
+
+
+def __build_tree__(tree_state: ParsingState) -> ParseTree:
+    """Given computed state, build the actual tree."""
+    _built = []
+    for idx, tree in enumerate(tree_state["trees"]):
+        _built.append(__add_leaves__(idx, (tree[2],), tree_state["leaves"]))
+
+    logger.debug("Number of built trees: %s, %s", len(_built), _built)
+    _num_trees = len(_built)
+    for idx in range(0, _num_trees):
+        _last_tree = _built.pop()
+        logger.debug("LAST TREE: %s, %s", _last_tree, len(_last_tree))
+        if len(_last_tree) <= 1:# Has no leaves or subtrees
+            _last_tree = None# type: ignore[assignment]
+            continue# more evil
+        _name = tree_state["trees"][_num_trees - 1 - idx][0]
+        _parent = tree_state["trees"][
+            tree_state["trees"][_num_trees - 1 - idx][1]]
+        _op = tree_state["trees"][_num_trees - 1 - idx][2]
+        logger.debug("TREE => name: %s, operation: %s, parent: %s",
+                     _name, _op, _parent)
+        if _name != _parent[0]:# not root tree
+            if _op == _parent[2]:
+                _built[_parent[0]] = (
+                    _built[_parent[0]][0],# Operator
+                    _built[_parent[0]][1] + _last_tree[1]# merge leaves
+                ) + _last_tree[2:]#Add any trees left over
+            else:
+                _built[_parent[0]] += (_last_tree,)
+
+    if _last_tree is None:
+        raise _EMPTY_SPEC_ERROR_
+    return _last_tree
+
+
+def __parse_tree__(tokens: Iterator[str]) -> ParseTree:
+    """Parse the tokens into a tree."""
+    _state = ParsingState(
+        tokens=[], trees=[], open_parens=0, current_tree=0, leaves={})
+    for _idx, _token in enumerate(tokens):
+        _state["tokens"].append(_token)
+
+        if _idx==0:
+            if _token[1:].upper() not in _OPERATORS_:
+                raise SpecificationValueError(f"Invalid operator: {_token[1:]}")
+            _state["open_parens"] += 1
+            _state["trees"].append((0, 0, _token[1:].upper(), _idx, -1))
+            _state["current_tree"] = 0
+            continue# this is bad!
+
+        if _token == ")":# end a tree
+            logger.debug("ENDING A TREE: %s", _state)
+            _state["open_parens"] -= 1
+            _state["trees"][_state["current_tree"]] = (
+                _state["trees"][_state["current_tree"]][0:-1] + (_idx,))
+            # We go back to the parent below.
+            _state["current_tree"] = _state["trees"][_state["current_tree"]][1]
+            continue# still really bad!
+
+        if _token[1:].upper() in _OPERATORS_:# new child tree
+            _state["open_parens"] += 1
+            _state["trees"].append((len(_state["trees"]),
+                                    _state["current_tree"],
+                                    _token[1:].upper(),
+                                    _idx,
+                                    -1))
+            _state["current_tree"] = len(_state["trees"]) - 1
+            continue# more evil still
+
+        logger.debug("state: %s", _state)
+        # leaves
+        _state["leaves"][_state["current_tree"]] = _state["leaves"].get(
+            _state["current_tree"], tuple()) + (_token,)
+
+    # Build parse-tree from state
+    if _state["open_parens"] != 0:
+        raise SpecificationValueError("Unbalanced parentheses.")
+    return __build_tree__(_state)
+
+
+def __tokenise__(spec: str) -> Iterator[str]:
+    """Clean up and tokenise the string."""
+    return (token.strip()
+            for token in spec.replace(
+                "(", " ("
+            ).replace(
+                ")", " ) "
+            ).replace(
+                "( ", "("
+            ).split())
+
+
+def parse(spec: str) -> ParseTree:
+    """Parse a string specification for privileges and return a tree of data
+    objects of the form (<operator> (<check>))"""
+    if spec.strip() == "":
+        raise _EMPTY_SPEC_ERROR_
+
+    return __parse_tree__(__tokenise__(spec))
+
+
+def __make_checker__(check_fn):
+    def __checker__(privileges, *checks):
+        def __check__(acc, curr):
+            if curr[0] in _OPERATORS_:
+                return acc + (_OPERATOR_FUNCTION_[curr[0]](
+                    privileges, *curr[1:]),)
+            return acc + (check_fn((priv in privileges) for priv in curr),)
+        results = reduce(__check__, checks, tuple())
+        return len(results) > 0 and check_fn(results)
+
+    return __checker__
+
+
+_OPERATOR_FUNCTION_ = {
+    "OR": __make_checker__(any),
+    "AND": __make_checker__(all)
+}
+def check(spec: str, privileges: tuple[str, ...]) -> bool:
+    """Check that the sequence of `privileges` satisfies `spec`."""
+    _spec = parse(spec)
+    return _OPERATOR_FUNCTION_[_spec[0]](privileges, *_spec[1:])
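The specification mini-language is prefix boolean logic: an operator fused to its opening parenthesis, as in `(and ...)` or `(or ...)`, with privilege names as leaves and arbitrary nesting. A worked example (privilege names are illustrative):

    from gn_libs.privileges import check

    SPEC = "(and system:admin (or group:edit group:view))"

    # system:admin is present and the OR branch is satisfied by group:view.
    print(check(SPEC, ("system:admin", "group:view")))  # True

    # The AND fails because system:admin is missing.
    print(check(SPEC, ("group:view",)))  # False
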
diff --git a/gn_libs/protocols/__init__.py b/gn_libs/protocols/__init__.py
index e71f1ce..83a31a8 100644
--- a/gn_libs/protocols/__init__.py
+++ b/gn_libs/protocols/__init__.py
@@ -1 +1,2 @@
+"""This package is a collection of major Protocols/Interfaces definitions."""
 from .db import DbCursor, DbConnection
diff --git a/gn_libs/sqlite3.py b/gn_libs/sqlite3.py
index 1dcdf29..78e1c41 100644
--- a/gn_libs/sqlite3.py
+++ b/gn_libs/sqlite3.py
@@ -1,7 +1,8 @@
+"""This module deals with connections to an SQLite3 database."""
 import logging
 import traceback
 import contextlib
-from typing import Any, Protocol, Callable, Iterator
+from typing import Callable, Iterator
 
 import sqlite3