Diffstat (limited to 'gn_libs')
-rw-r--r--  gn_libs/debug.py              |  13
-rw-r--r--  gn_libs/jobs/__init__.py      |   1
-rw-r--r--  gn_libs/jobs/jobs.py          |  79
-rw-r--r--  gn_libs/jobs/launcher.py      |  34
-rw-r--r--  gn_libs/jobs/migrations.py    |   5
-rw-r--r--  gn_libs/monadic_requests.py   |  20
-rw-r--r--  gn_libs/mysqldb.py            |  19
-rw-r--r--  gn_libs/privileges.py         | 166
-rw-r--r--  gn_libs/protocols/__init__.py |   1
-rw-r--r--  gn_libs/sqlite3.py            |   3
10 files changed, 311 insertions, 30 deletions
diff --git a/gn_libs/debug.py b/gn_libs/debug.py
index c1b896e..7ad10e0 100644
--- a/gn_libs/debug.py
+++ b/gn_libs/debug.py
@@ -1,6 +1,7 @@
 """Debug utilities"""
 import logging
 import importlib.util
+from typing import Callable
 
 __this_module_name__ = __name__
 
@@ -26,3 +27,15 @@ def __pk__(*args):
     logger = getLogger(__this_module_name__)
     logger.debug("%s: %s", title_vals, value)
     return value
+
+
+def make_peeker(logger: logging.Logger) -> Callable:
+    """Make a peeker function that's very much like __pk__ but that uses the
+    given logger."""
+    def peeker(*args):
+        value = args[-1]
+        title_vals = " => ".join(args[0:-1])
+        logger.debug("%s: %s", title_vals, value)
+        return value
+
+    return peeker
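
A minimal usage sketch of the new make_peeker helper (the logger name here is hypothetical):

    import logging

    logger = logging.getLogger("my.module")  # any pre-configured logger
    peek = make_peeker(logger)

    # Logs "result => total: 42" at DEBUG level and returns 42, so the call
    # can be dropped into the middle of an expression without changing it.
    total = peek("result", "total", 42)
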
diff --git a/gn_libs/jobs/__init__.py b/gn_libs/jobs/__init__.py
index 6f400ef..d6e4ce3 100644
--- a/gn_libs/jobs/__init__.py
+++ b/gn_libs/jobs/__init__.py
@@ -1,3 +1,4 @@
+"""This package deals with launching and managing background/async jobs."""
 from .migrations import run_migrations
 from .jobs import (job,
                    launch_job,
diff --git a/gn_libs/jobs/jobs.py b/gn_libs/jobs/jobs.py
index 38cd9c0..ec1c3a8 100644
--- a/gn_libs/jobs/jobs.py
+++ b/gn_libs/jobs/jobs.py
@@ -6,7 +6,7 @@ import shlex
 import logging
 import subprocess
 from pathlib import Path
-from functools import reduce
+from functools import partial
 from typing import Union, Optional
 from datetime import datetime, timezone, timedelta
 
@@ -29,11 +29,25 @@ def __job_metadata__(cursor: DbCursor, job_id: Union[str, uuid.UUID]) -> dict:
     }
 
 
+def job_stdstream_outputs(conn, job_id, streamname: str):
+    """Fetch the standard-error output for the job."""
+    with _cursor(conn) as cursor:
+        cursor.execute(
+            "SELECT * FROM jobs_standard_outputs "
+            "WHERE job_id=? AND output_stream=?",
+            (str(job_id), streamname))
+        return dict(cursor.fetchone() or {}).get("value")
+
+
+job_stderr = partial(job_stdstream_outputs, streamname="stderr")
+job_stdout = partial(job_stdstream_outputs, streamname="stdout")
+
+
 def job(conn: DbConnection, job_id: Union[str, uuid.UUID], fulldetails: bool = False) -> dict:
     """Fetch the job details for a job with a particular ID"""
     with _cursor(conn) as cursor:
         cursor.execute("SELECT * FROM jobs WHERE job_id=?", (str(job_id),))
-        _job = dict(cursor.fetchone())
+        _job = dict(cursor.fetchone() or {})
         if not bool(_job):
             raise JobNotFound(f"Could not find job with ID {job_id}")
 
@@ -72,16 +86,18 @@ def __save_job__(conn: DbConnection, the_job: dict, expiry_seconds: int) -> dict
     return the_job
 
 
-def initialise_job(
+def initialise_job(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
         conn: DbConnection,
         job_id: uuid.UUID,
         command: list,
         job_type: str,
-        extra_meta: dict = {},
-        expiry_seconds: Optional[int] = _DEFAULT_EXPIRY_SECONDS_
+        extra_meta: Optional[dict] = None,
+        expiry_seconds: int = _DEFAULT_EXPIRY_SECONDS_
 ) -> dict:
     """Initialise the job and put the details in a SQLite3 database."""
-    
+    if extra_meta is None:
+        extra_meta = {}
+
     _job = {
         "job_id": job_id,
         "command": shlex.join(command),
@@ -96,34 +112,59 @@ def initialise_job(
     return __save_job__(conn, _job, expiry_seconds)
 
 
-def error_filename(jobid, error_dir):
-    "Compute the path of the file where errors will be dumped."
-    return f"{error_dir}/job_{jobid}.error"
+def output_file(jobid: uuid.UUID, outdir: Path, stream: str) -> Path:
+    """Compute the path for the file where the launcher's `stream` output goes"""
+    assert stream in ("stdout", "stderr"), f"Invalid stream '{stream}'"
+    return outdir.joinpath(f"launcher_job_{jobid}.{stream}")
+
+
+stdout_filename = partial(output_file, stream="stdout")
+stderr_filename = partial(output_file, stream="stderr")
+
+
+def build_environment(extras: Optional[dict[str, str]] = None) -> dict[str, str]:
+    """Setup the runtime environment variables for the background script."""
+    if extras is None:
+        extras = {}
+
+    return {
+        **dict(os.environ),
+        "PYTHONPATH": ":".join(sys.path),
+        **extras
+    }
 
 
 def launch_job(
         the_job: dict,
         sqlite3_url: str,
         error_dir: Path,
-        worker_manager: str = "gn_libs.jobs.launcher"
+        worker_manager: str = "gn_libs.jobs.launcher",
+        loglevel: str = "info"
 ) -> dict:
     """Launch a job in the background"""
     if not os.path.exists(error_dir):
         os.mkdir(error_dir)
 
     job_id = str(the_job["job_id"])
-    with open(error_filename(job_id, error_dir),
-              "w",
-              encoding="utf-8") as errorfile:
+    with (open(stderr_filename(jobid=the_job["job_id"], outdir=error_dir),
+               "w",
+               encoding="utf-8") as stderrfile,
+          open(stdout_filename(jobid=the_job["job_id"], outdir=error_dir),
+               "w",
+               encoding="utf-8") as stdoutfile):
         subprocess.Popen( # pylint: disable=[consider-using-with]
             [
                 sys.executable, "-u",
                 "-m", worker_manager,
                 sqlite3_url,
                 job_id,
-                str(error_dir)],
-            stderr=errorfile,
-            env={"PYTHONPATH": ":".join(sys.path)})
+                str(error_dir),
+                "--log-level",
+                loglevel
+            ],
+            stdout=stdoutfile,
+            stderr=stderrfile,
+            env=build_environment())
 
     return the_job
 
@@ -144,7 +185,11 @@ def update_metadata(conn: DbConnection, job_id: Union[str, uuid.UUID], key: str,
             })
 
 
-def push_to_stream(conn: DbConnection, job_id: Union[str, uuid.UUID], stream_name: str, content: str):
+def push_to_stream(
+        conn: DbConnection,
+        job_id: Union[str, uuid.UUID],
+        stream_name: str, content: str
+):
     """Initialise, and keep adding content to the stream from the provided content."""
     with _cursor(conn) as cursor:
         cursor.execute("SELECT * FROM jobs_standard_outputs "
diff --git a/gn_libs/jobs/launcher.py b/gn_libs/jobs/launcher.py
index b7369a4..d565f9e 100644
--- a/gn_libs/jobs/launcher.py
+++ b/gn_libs/jobs/launcher.py
@@ -1,6 +1,9 @@
+"""Default launcher/manager script for background jobs."""
+import os
 import sys
 import time
 import shlex
+import logging
 import argparse
 import traceback
 import subprocess
@@ -9,15 +12,19 @@ from pathlib import Path
 
 from gn_libs import jobs, sqlite3
 
+logger = logging.getLogger(__name__)
+
 
 def run_job(conn, job, outputs_directory: Path):
     """Run the job."""
+    logger.info("Setting up the job.")
     job_id = job["job_id"]
     stdout_file = outputs_directory.joinpath(f"{job_id}.stdout")
     stderr_file = outputs_directory.joinpath(f"{job_id}.stderr")
     jobs.update_metadata(conn, job_id, "stdout-file", str(stdout_file))
     jobs.update_metadata(conn, job_id, "stderr-file", str(stderr_file))
     try:
+        logger.info("Launching the job in a separate process.")
         with (stdout_file.open(mode="w") as outfile,
               stderr_file.open(mode="w") as errfile,
               stdout_file.open(mode="r") as stdout_in,
@@ -36,7 +43,21 @@ def run_job(conn, job, outputs_directory: Path):
             # Fetch any remaining content.
             jobs.push_to_stream(conn, job_id, "stdout", stdout_in.read())
             jobs.push_to_stream(conn, job_id, "stderr", stderr_in.read())
-    except:
+            logger.info("Job completed. Cleaning up.")
+
+        os.remove(stdout_file)
+        os.remove(stderr_file)
+        exit_status = process.poll()
+        if exit_status == 0:
+            jobs.update_metadata(conn, job_id, "status", "completed")
+        else:
+            jobs.update_metadata(conn, job_id, "status", "error")
+
+        logger.info("exiting job manager/launcher")
+        return exit_status
+    except Exception as _exc:# pylint: disable=[broad-exception-caught]
+        logger.error("An exception was raised when attempting to run the job",
+                     exc_info=True)
         jobs.update_metadata(conn, job_id, "status", "error")
         jobs.push_to_stream(conn, job_id, "stderr", traceback.format_exc())
         return 4
@@ -56,14 +77,21 @@ def parse_args():
     parser.add_argument("outputs_directory",
                         help="Directory where output files will be created",
                         type=Path)
+    parser.add_argument(
+        "--log-level",
+        type=str,
+        help="Determines what is logged out.",
+        choices=("debug", "info", "warning", "error", "critical"),
+        default="info")
     return parser.parse_args()
 
+
 def main():
     """Entry-point to this program."""
     args = parse_args()
+    logger.setLevel(args.log_level.upper())
     args.outputs_directory.mkdir(parents=True, exist_ok=True)
-    with (sqlite3.connection(args.jobs_db_uri) as conn,
-          sqlite3.cursor(conn) as cursor):
+    with sqlite3.connection(args.jobs_db_uri) as conn:
         job = jobs.job(conn, args.job_id)
         if job:
             return run_job(conn, job, args.outputs_directory)
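
For reference, launch_job in jobs.py spawns this module with a command line of roughly the following shape (paths and IDs are placeholders):

    python -u -m gn_libs.jobs.launcher /tmp/jobs.db <job-id> /tmp/job-outputs --log-level info
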
diff --git a/gn_libs/jobs/migrations.py b/gn_libs/jobs/migrations.py
index 86fb958..0c9825b 100644
--- a/gn_libs/jobs/migrations.py
+++ b/gn_libs/jobs/migrations.py
@@ -1,6 +1,6 @@
 """Database migrations for the jobs to ensure consistency downstream."""
 from gn_libs.protocols import DbCursor
-from gn_libs.sqlite3 import cursor, connection
+from gn_libs.sqlite3 import connection, cursor as acquire_cursor
 
 def __create_table_jobs__(cursor: DbCursor):
     """Create the jobs table"""
@@ -61,8 +61,9 @@ def __create_table_jobs_output_streams__(cursor: DbCursor):
 
 
 def run_migrations(sqlite_url: str):
+    """Run the migrations to setup the background jobs database."""
     with (connection(sqlite_url) as conn,
-          cursor(conn) as curr):
+          acquire_cursor(conn) as curr):
         __create_table_jobs__(curr)
         __create_table_jobs_metadata__(curr)
         __create_table_jobs_output_streams__(curr)
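
Callers only need the one entry point; a setup sketch with a hypothetical database path:

    from gn_libs.jobs import run_migrations

    run_migrations("/tmp/jobs.db")  # creates the background-jobs tables
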
diff --git a/gn_libs/monadic_requests.py b/gn_libs/monadic_requests.py
index 0a3c282..a09acc5 100644
--- a/gn_libs/monadic_requests.py
+++ b/gn_libs/monadic_requests.py
@@ -19,8 +19,13 @@ def get(url, params=None, **kwargs) -> Either:
 
     :rtype: pymonad.either.Either
     """
+    timeout = kwargs.get("timeout")
+    kwargs = {key: val for key,val in kwargs.items() if key != "timeout"}
+    if timeout is None:
+        timeout = (9.13, 20)
+
     try:
-        resp = requests.get(url, params=params, **kwargs)
+        resp = requests.get(url, params=params, timeout=timeout, **kwargs)
         if resp.status_code in SUCCESS_CODES:
             return Right(resp.json())
         return Left(resp)
@@ -36,8 +41,13 @@ def post(url, data=None, json=None, **kwargs) -> Either:
 
     :rtype: pymonad.either.Either
     """
+    timeout = kwargs.get("timeout")
+    kwargs = {key: val for key,val in kwargs.items() if key != "timeout"}
+    if timeout is None:
+        timeout = (9.13, 20)
+
     try:
-        resp = requests.post(url, data=data, json=json, **kwargs)
+        resp = requests.post(url, data=data, json=json, timeout=timeout, **kwargs)
         if resp.status_code in SUCCESS_CODES:
             return Right(resp.json())
         return Left(resp)
@@ -55,10 +65,10 @@ def make_either_error_handler(msg):
             try:
                 _data = error.json()
             except Exception as _exc:
-                raise Exception(error.content) from _exc
-            raise Exception(_data)
+                raise Exception(error.content) from _exc# pylint: disable=[broad-exception-raised]
+            raise Exception(_data)# pylint: disable=[broad-exception-raised]
 
         logger.debug("\n\n%s\n\n", msg)
-        raise Exception(error)
+        raise Exception(error)# pylint: disable=[broad-exception-raised]
 
     return __fail__
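
A usage sketch for the timeout-aware get; it assumes pymonad's Either.either(on_left, on_right) method and a hypothetical URL:

    from gn_libs import monadic_requests

    result = monadic_requests.get("https://example.org/api/items",
                                  timeout=(3.05, 27))
    items = result.either(
        monadic_requests.make_either_error_handler("fetching items failed"),
        lambda payload: payload)  # the Right side carries the decoded JSON
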
diff --git a/gn_libs/mysqldb.py b/gn_libs/mysqldb.py
index 64a649d..3f6390e 100644
--- a/gn_libs/mysqldb.py
+++ b/gn_libs/mysqldb.py
@@ -9,7 +9,7 @@ import MySQLdb as mdb
 from MySQLdb.cursors import Cursor
 
 
-_logger = logging.getLogger(__file__)
+_logger = logging.getLogger(__name__)
 
 class InvalidOptionValue(Exception):
     """Raised whenever a parsed value is invalid for the specific option."""
@@ -46,6 +46,12 @@ def __parse_ssl_mode_options__(val: str) -> str:
 
 
 def __parse_ssl_options__(val: str) -> dict:
+    if val.strip() == "" or val.strip().lower() == "false":
+        return False
+
+    if val.strip().lower() == "true":
+        return True
+
     allowed_keys = ("key", "cert", "ca", "capath", "cipher")
     opts = {
         key.strip(): val.strip() for key,val in
@@ -61,6 +67,7 @@ def __parse_db_opts__(opts: str) -> dict:
 
     This assumes use of python-mysqlclient library."""
     allowed_opts = (
+        # See: https://mysqlclient.readthedocs.io/user_guide.html#functions-and-attributes
         "unix_socket", "connect_timeout", "compress", "named_pipe",
         "init_command", "read_default_file", "read_default_group",
         "cursorclass", "use_unicode", "charset", "collation", "auth_plugin",
@@ -124,13 +131,21 @@ class Connection(Protocol):
 @contextlib.contextmanager
 def database_connection(sql_uri: str, logger: logging.Logger = _logger) -> Iterator[Connection]:
     """Connect to MySQL database."""
-    connection = mdb.connect(**parse_db_url(sql_uri))
+    _conn_opts = parse_db_url(sql_uri)
+    _logger.debug("Connecting to database with the following options: %s",
+                  _conn_opts)
+    connection = mdb.connect(**_conn_opts)
     try:
         yield connection
         connection.commit()
     except mdb.Error as _mbde:
         logger.error("DB error encountered", exc_info=True)
         connection.rollback()
+        raise _mbde from None
+    except Exception as _exc:
+        connection.rollback()
+        logger.error("General exception encountered", exc_info=True)
+        raise _exc from None
     finally:
         connection.close()
 
diff --git a/gn_libs/privileges.py b/gn_libs/privileges.py
new file mode 100644
index 0000000..32c943d
--- /dev/null
+++ b/gn_libs/privileges.py
@@ -0,0 +1,166 @@
+"""Utilities for handling privileges."""
+import logging
+from functools import reduce
+from typing import Union, Sequence, Iterator, TypeAlias, TypedDict
+
+logger = logging.getLogger(__name__)
+
+Operator: TypeAlias = str # Valid operators: "AND", "OR"
+Privilege: TypeAlias = str
+PrivilegesList: TypeAlias = Sequence[Privilege]
+ParseTree = tuple[Operator,
+                  # Leaves (`PrivilegesList` objects) on the left,
+                  # trees (`ParseTree` objects) on the right
+                  Union[PrivilegesList, tuple[PrivilegesList, 'ParseTree']]]
+
+
+class SpecificationValueError(ValueError):
+    """Raised when there is an error in the specification string."""
+
+
+_OPERATORS_ = ("OR", "AND")
+_EMPTY_SPEC_ERROR_ = SpecificationValueError(
+    "Empty specification. I do not know what to do.")
+
+
+def __add_leaves__(
+        index: int,
+        tree: tuple[Operator],
+        leaves: dict
+) -> Union[tuple[Operator], Union[ParseTree, tuple]]:
+    """Add leaves to the tree."""
+    if leaves.get(index):
+        return tree + (leaves[index],)
+    return tree + (tuple())
+
+
+class ParsingState(TypedDict):
+    """Class to create a state object. Mostly used to silence MyPy"""
+    tokens: list[str]
+    trees: list[tuple[int, int, str, int, int]]#(name, parent, operator, start, end)
+    open_parens: int
+    current_tree: int
+    leaves: dict[int, tuple[str, ...]]# maps a tree's index to its leaf tokens
+
+
+def __build_tree__(tree_state: ParsingState) -> ParseTree:
+    """Given computed state, build the actual tree."""
+    _built = []
+    for idx, tree in enumerate(tree_state["trees"]):
+        _built.append(__add_leaves__(idx, (tree[2],), tree_state["leaves"]))
+
+    logger.debug("Number of built trees: %s, %s", len(_built), _built)
+    _num_trees = len(_built)
+    for idx in range(0, _num_trees):
+        _last_tree = _built.pop()
+        logger.debug("LAST TREE: %s, %s", _last_tree, len(_last_tree))
+        if len(_last_tree) <= 1:# Has no leaves or subtrees
+            _last_tree = None# type: ignore[assignment]
+            continue# more evil
+        _name = tree_state["trees"][_num_trees - 1 - idx][0]
+        _parent = tree_state["trees"][
+            tree_state["trees"][_num_trees - 1 - idx][1]]
+        _op = tree_state["trees"][_num_trees - 1 - idx][2]
+        logger.debug("TREE => name: %s, operation: %s, parent: %s",
+                     _name, _op, _parent)
+        if _name != _parent[0]:# not root tree
+            if _op == _parent[2]:
+                _built[_parent[0]] = (
+                    _built[_parent[0]][0],# Operator
+                    _built[_parent[0]][1] + _last_tree[1]# merge leaves
+                ) + _last_tree[2:]#Add any trees left over
+            else:
+                _built[_parent[0]] += (_last_tree,)
+
+    if _last_tree is None:
+        raise _EMPTY_SPEC_ERROR_
+    return _last_tree
+
+
+def __parse_tree__(tokens: Iterator[str]) -> ParseTree:
+    """Parse the tokens into a tree."""
+    _state = ParsingState(
+        tokens=[], trees=[], open_parens=0, current_tree=0, leaves={})
+    for _idx, _token in enumerate(tokens):
+        _state["tokens"].append(_token)
+
+        if _idx==0:
+            if _token[1:].upper() not in _OPERATORS_:
+                raise SpecificationValueError(f"Invalid operator: {_token[1:]}")
+            _state["open_parens"] += 1
+            _state["trees"].append((0, 0, _token[1:].upper(), _idx, -1))
+            _state["current_tree"] = 0
+            continue# this is bad!
+
+        if _token == ")":# end a tree
+            logger.debug("ENDING A TREE: %s", _state)
+            _state["open_parens"] -= 1
+            _state["trees"][_state["current_tree"]] = (
+                _state["trees"][_state["current_tree"]][0:-1] + (_idx,))
+            # We go back to the parent below.
+            _state["current_tree"] = _state["trees"][_state["current_tree"]][1]
+            continue# still really bad!
+
+        if _token[1:].upper() in _OPERATORS_:# new child tree
+            _state["open_parens"] += 1
+            _state["trees"].append((len(_state["trees"]),
+                                    _state["current_tree"],
+                                    _token[1:].upper(),
+                                    _idx,
+                                    -1))
+            _state["current_tree"] = len(_state["trees"]) - 1
+            continue# more evil still
+
+        logger.debug("state: %s", _state)
+        # leaves
+        _state["leaves"][_state["current_tree"]] = _state["leaves"].get(
+            _state["current_tree"], tuple()) + (_token,)
+
+    # Build parse-tree from state
+    if _state["open_parens"] != 0:
+        raise SpecificationValueError("Unbalanced parentheses.")
+    return __build_tree__(_state)
+
+
+def __tokenise__(spec: str) -> Iterator[str]:
+    """Clean up and tokenise the string."""
+    return (token.strip()
+            for token in spec.replace(
+                "(", " ("
+            ).replace(
+                ")", " ) "
+            ).replace(
+                "( ", "("
+            ).split())
+
+
+def parse(spec: str) -> ParseTree:
+    """Parse a string specification for privileges and return a tree of data
+    objects of the form (<operator> (<check>))"""
+    if spec.strip() == "":
+        raise _EMPTY_SPEC_ERROR_
+
+    return __parse_tree__(__tokenise__(spec))
+
+
+def __make_checker__(check_fn):
+    def __checker__(privileges, *checks):
+        def __check__(acc, curr):
+            if curr[0] in _OPERATORS_:
+                return acc + (_OPERATOR_FUNCTION_[curr[0]](
+                    privileges, *curr[1:]),)
+            return acc + (check_fn((priv in privileges) for priv in curr),)
+        results = reduce(__check__, checks, tuple())
+        return len(results) > 0 and check_fn(results)
+
+    return __checker__
+
+
+_OPERATOR_FUNCTION_ = {
+    "OR": __make_checker__(any),
+    "AND": __make_checker__(all)
+}
+def check(spec: str, privileges: tuple[str, ...]) -> bool:
+    """Check that the sequence of `privileges` satisfies `spec`."""
+    _spec = parse(spec)
+    return _OPERATOR_FUNCTION_[_spec[0]](privileges, *_spec[1:])
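
The specification language implied by the parser is a parenthesised prefix expression whose first token names the operator. A sketch with hypothetical privilege names:

    from gn_libs import privileges

    SPEC = "(AND system:admin (OR group:edit group:view))"

    privileges.check(SPEC, ("system:admin", "group:view"))  # True
    privileges.check(SPEC, ("group:view",))  # False: system:admin is missing
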
diff --git a/gn_libs/protocols/__init__.py b/gn_libs/protocols/__init__.py
index e71f1ce..83a31a8 100644
--- a/gn_libs/protocols/__init__.py
+++ b/gn_libs/protocols/__init__.py
@@ -1 +1,2 @@
+"""This package is a collection of major Protocols/Interfaces definitions."""
 from .db import DbCursor, DbConnection
diff --git a/gn_libs/sqlite3.py b/gn_libs/sqlite3.py
index 1dcdf29..78e1c41 100644
--- a/gn_libs/sqlite3.py
+++ b/gn_libs/sqlite3.py
@@ -1,7 +1,8 @@
+"""This module deals with connections to a(n) SQLite3 database."""
 import logging
 import traceback
 import contextlib
-from typing import Any, Protocol, Callable, Iterator
+from typing import Callable, Iterator
 
 import sqlite3
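
The connection and cursor context managers from this module are used throughout the change-set; a minimal sketch with a hypothetical database path:

    from gn_libs import sqlite3

    with (sqlite3.connection("/tmp/app.db") as conn,
          sqlite3.cursor(conn) as cur):
        cur.execute("SELECT COUNT(*) FROM jobs")
        print(cur.fetchone())
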