about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- gn3/app.py 34
-rw-r--r-- gn3/case_attributes.py 22
-rw-r--r-- gn3/db_utils.py 75
-rw-r--r-- gn3/settings.py 35
-rw-r--r-- tests/unit/test_db_utils.py 69
5 files changed, 166 insertions, 69 deletions
diff --git a/gn3/app.py b/gn3/app.py
index e9a2bbe..eb9e5d3 100644
--- a/gn3/app.py
+++ b/gn3/app.py
@@ -1,4 +1,5 @@
 """Entry point from spinning up flask"""
+
 import os
 import sys
 import logging
@@ -29,6 +30,32 @@ from gn3.api.llm import gnqa
 from gn3.case_attributes import caseattr
 
 
+class ConfigurationError(Exception):
+    """Raised in case of a configuration error."""
+
+
+def verify_app_config(app: Flask) -> None:
+    """Verify that the mandatory configuration variables are set.
+    This check:
+        1. raises ConfigurationError if a mandatory setting is missing or empty
+        2. reports an example value for each such setting (helps local dev)
+    """
+    app_config = {
+        "AUTH_SERVER_URL": """AUTH_SERVER_URL is used for api requests that need login.
+        For local dev, use the running auth server url, which defaults to http://127.0.0.1:8081
+        """,
+    }
+    error_message = []
+
+    for setting, err in app_config.items():
+        print(f"{setting}: {app.config.get(setting)}")
+        if setting in app.config and bool(app.config[setting]):
+            continue
+        error_message.append(err)
+    if error_message:
+        raise ConfigurationError("\n".join(error_message))
+
+
 def create_app(config: Union[Dict, str, None] = None) -> Flask:
     """Create a new flask object"""
     app = Flask(__name__)
@@ -37,7 +64,7 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask:
 
     # Load environment configuration
     if "GN3_CONF" in os.environ:
-        app.config.from_envvar('GN3_CONF')
+        app.config.from_envvar("GN3_CONF")
 
     # Load app specified configuration
     if config is not None:
@@ -51,6 +78,7 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask:
     if secrets_file and Path(secrets_file).exists():
         app.config.from_envvar("GN3_SECRETS")
     # END: SECRETS
+    verify_app_config(app)
     setup_app_handlers(app)
     # DO NOT log anything before this point
     logging.info("Guix Profile: '%s'.", os.environ.get("GUIX_PROFILE"))
@@ -60,7 +88,9 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask:
         app,
         origins=app.config["CORS_ORIGINS"],
         allow_headers=app.config["CORS_HEADERS"],
-        supports_credentials=True, intercept_exceptions=False)
+        supports_credentials=True,
+        intercept_exceptions=False,
+    )
 
     app.register_blueprint(general, url_prefix="/api/")
     app.register_blueprint(gemma, url_prefix="/api/gemma")
diff --git a/gn3/case_attributes.py b/gn3/case_attributes.py
index 9baff1e..2c878d2 100644
--- a/gn3/case_attributes.py
+++ b/gn3/case_attributes.py
@@ -77,7 +77,7 @@ def required_access(
             result = requests.get(
                 # this section fetches the resource ID from the auth server
                 urljoin(current_app.config["AUTH_SERVER_URL"],
-                        "auth/resource/inbredset/resource-id"
+                        "auth/resource/populations/resource-id"
                         f"/{__species_id__(conn)}/{inbredset_id}"))
             if result.status_code == 200:
                 resource_id = result.json()["resource-id"]
@@ -186,7 +186,7 @@ def __case_attribute_values_by_inbred_set__(
             "ON ca.CaseAttributeId=caxrn.CaseAttributeId "
             "INNER JOIN Strain AS s "
             "ON caxrn.StrainId=s.Id "
-            "WHERE ca.InbredSetId=%(inbredset_id)s "
+            "WHERE caxrn.InbredSetId=%(inbredset_id)s "
             "ORDER BY StrainName",
             {"inbredset_id": inbredset_id})
         return tuple(
@@ -203,7 +203,7 @@ def __process_orig_data__(fieldnames, cadata, strains) -> tuple[dict, ...]:
     data = {item["StrainName"]: item for item in cadata}
     return tuple(
         {
-            "Strain": strain["Name"],
+            "Sample": strain["Name"],
             **{
                 key: data.get(
                     strain["Name"], {}).get("case-attributes", {}).get(key, "")
@@ -216,7 +216,7 @@ def __process_edit_data__(fieldnames, form_data) -> tuple[dict, ...]:
     def __process__(acc, strain_cattrs):
         strain, cattrs = strain_cattrs
         return acc + ({
-            "Strain": strain, **{
+            "Sample": strain, **{
             field: cattrs["case-attributes"].get(field, "")
             for field in fieldnames[1:]
             }
@@ -327,19 +327,19 @@ def __apply_additions__(
 def __apply_modifications__(
         cursor, inbredset_id: int, modifications_diff, fieldnames) -> None:
     """Apply modifications: changes values of existing case attributes."""
-    cattrs = tuple(field for field in fieldnames if field != "Strain")
+    cattrs = tuple(field for field in fieldnames if field != "Sample")
 
     def __retrieve_changes__(acc, row):
         orig = dict(zip(fieldnames, row["Original"].split(",")))
         new = dict(zip(fieldnames, row["Current"].split(",")))
         return acc + tuple({
-            "Strain": new["Strain"],
+            "Sample": new["Sample"],
             cattr: new[cattr]
         } for cattr in cattrs if new[cattr] != orig[cattr])
 
     new_rows: tuple[dict, ...] = reduce(
         __retrieve_changes__, modifications_diff, tuple())
-    strain_names = tuple({row["Strain"] for row in new_rows})
+    strain_names = tuple({row["Sample"] for row in new_rows})
     cursor.execute("SELECT Id AS StrainId, Name AS StrainName FROM Strain "
                    f"WHERE Name IN ({', '.join(['%s'] * len(strain_names))})",
                    strain_names)
@@ -364,7 +364,7 @@ def __apply_modifications__(
         tuple(
             {
                 "isetid": inbredset_id,
-                "strainid": strain_ids[row["Strain"]],
+                "strainid": strain_ids[row["Sample"]],
                 "cattrid": cattr_ids[cattr],
                 "value": row[cattr]
             }
@@ -377,11 +377,11 @@ def __apply_modifications__(
         tuple(
             {
                 "isetid": inbredset_id,
-                "strainid": strain_ids[row["Strain"]],
+                "strainid": strain_ids[row["Sample"]],
                 "cattrid": cattr_ids[cattr]
             }
             for row in new_rows
-            for cattr in (key for key in row.keys() if key != "Strain")
+            for cattr in (key for key in row.keys() if key != "Sample")
             if not bool(row[cattr].strip())))
 
 def __apply_deletions__(
@@ -471,7 +471,7 @@ def edit_case_attributes(inbredset_id: int, auth_token = None) -> Response:
         required_access(auth_token,
                         inbredset_id,
                         ("system:inbredset:edit-case-attribute",))
-        fieldnames = tuple(["Strain"] + sorted(
+        fieldnames = tuple(["Sample"] + sorted(
             attr["Name"] for attr in
             __case_attribute_labels_by_inbred_set__(conn, inbredset_id)))
         try:
diff --git a/gn3/db_utils.py b/gn3/db_utils.py
index 0d9bd0a..7004590 100644
--- a/gn3/db_utils.py
+++ b/gn3/db_utils.py
@@ -1,23 +1,71 @@
 """module contains all db related stuff"""
-import contextlib
 import logging
-from typing import Any, Iterator, Protocol, Tuple
+import contextlib
 from urllib.parse import urlparse
-import MySQLdb as mdb
+from typing import Any, Iterator, Protocol, Callable
+
 import xapian
+import MySQLdb as mdb
 
 
 LOGGER = logging.getLogger(__file__)
 
 
-def parse_db_url(sql_uri: str) -> Tuple:
-    """function to parse SQL_URI env variable note:there\
-    is a default value for SQL_URI so a tuple result is\
-    always expected"""
+def __check_true__(val: str) -> bool:
+    """Check whether 'val', ignoring case and surrounding whitespace, is `true`."""
+    return val.strip().lower() == "true"
+
+
+def __parse_db_opts__(opts: str) -> dict:
+    """Parse database options into their appropriate values.
+
+    This assumes use of python-mysqlclient library."""
+    allowed_opts = (
+        "unix_socket", "connect_timeout", "compress", "named_pipe",
+        "init_command", "read_default_file", "read_default_group",
+        "cursorclass", "use_unicode", "charset", "collation", "auth_plugin",
+        "sql_mode", "client_flag", "multi_statements", "ssl_mode", "ssl",
+        "local_infile", "autocommit", "binary_prefix")
+    conversion_fns: dict[str, Callable] = {
+        **{opt: str for opt in allowed_opts},
+        "connect_timeout": int,
+        "compress": __check_true__,
+        "use_unicode": __check_true__,
+        # "cursorclass": __load_cursor_class__
+        "client_flag": int,
+        "multi_statements": __check_true__,
+        # "ssl": __parse_ssl_options__,
+        "local_infile": __check_true__,
+        "autocommit": __check_true__,
+        "binary_prefix": __check_true__
+    }
+    queries = tuple(filter(bool, opts.split("&")))
+    if len(queries) > 0:
+        keyvals: tuple[tuple[str, ...], ...] = tuple(
+            tuple(item.strip() for item in query.split("="))
+            for query in queries)
+        def __check_opt__(opt):
+            assert opt in allowed_opts, (
+                f"Invalid database connection option ({opt}) provided.")
+            return opt
+        return {
+            __check_opt__(key): conversion_fns[key](val)
+            for key, val in keyvals
+        }
+    return {}
+
+
+def parse_db_url(sql_uri: str) -> dict:
+    """Parse the `sql_uri` variable into a dict of connection parameters."""
     parsed_db = urlparse(sql_uri)
-    return (
-        parsed_db.hostname, parsed_db.username, parsed_db.password,
-        parsed_db.path[1:], parsed_db.port)
+    return {
+        "host": parsed_db.hostname,
+        "port": parsed_db.port or 3306,
+        "user": parsed_db.username,
+        "password": parsed_db.password,
+        "database": parsed_db.path.strip("/").strip(),
+        **__parse_db_opts__(parsed_db.query)
+    }
 
 
 # pylint: disable=missing-class-docstring, missing-function-docstring, too-few-public-methods
@@ -30,12 +78,7 @@ class Connection(Protocol):
 @contextlib.contextmanager
 def database_connection(sql_uri: str, logger: logging.Logger = LOGGER) -> Iterator[Connection]:
     """Connect to MySQL database."""
-    host, user, passwd, db_name, port = parse_db_url(sql_uri)
-    connection = mdb.connect(db=db_name,
-                             user=user,
-                             passwd=passwd or '',
-                             host=host,
-                             port=port or 3306)
+    connection = mdb.connect(**parse_db_url(sql_uri))
     try:
         yield connection
     except mdb.Error as _mbde:
diff --git a/gn3/settings.py b/gn3/settings.py
index 439d88c..9a3f7eb 100644
--- a/gn3/settings.py
+++ b/gn3/settings.py
@@ -5,10 +5,8 @@ DO NOT import from this file, use `flask.current_app.config` instead to get the
 application settings.
 """
 import os
-import uuid
 import tempfile
 
-BCRYPT_SALT = "$2b$12$mxLvu9XRLlIaaSeDxt8Sle"  # Change this!
 DATA_DIR = ""
 GEMMA_WRAPPER_CMD = os.environ.get("GEMMA_WRAPPER", "gemma-wrapper")
 CACHEDIR = ""
@@ -29,14 +27,10 @@ LMDB_PATH = os.environ.get(
 SQL_URI = os.environ.get(
     "SQL_URI", "mysql://webqtlout:webqtlout@localhost/db_webqtl")
 SECRET_KEY = "password"
-# gn2 results only used in fetching dataset info
-
 
 # FAHAMU API TOKEN
 FAHAMU_AUTH_TOKEN = ""
 
-GN2_BASE_URL = "http://www.genenetwork.org/"
-
 # wgcna script
 WGCNA_RSCRIPT = "wgcna_analysis.R"
 # qtlreaper command
@@ -83,31 +77,4 @@ ROUND_TO = 10
 
 MULTIPROCESSOR_PROCS = 6  # Number of processes to spawn
 
-AUTH_SERVER_URL = "https://auth.genenetwork.org"
-AUTH_MIGRATIONS = "migrations/auth"
-OAUTH2_SCOPE = (
-    "profile", "group", "role", "resource", "user", "masquerade",
-    "introspect")
-
-
-try:
-    # *** SECURITY CONCERN ***
-    # Clients with access to this privileges create a security concern.
-    # Be careful when adding to this configuration
-    OAUTH2_CLIENTS_WITH_INTROSPECTION_PRIVILEGE = tuple(
-        uuid.UUID(client_id) for client_id in
-        os.environ.get(
-            "OAUTH2_CLIENTS_WITH_INTROSPECTION_PRIVILEGE", "").split(","))
-except ValueError as _valerr:
-    OAUTH2_CLIENTS_WITH_INTROSPECTION_PRIVILEGE = tuple()
-
-try:
-    # *** SECURITY CONCERN ***
-    # Clients with access to this privileges create a security concern.
-    # Be careful when adding to this configuration
-    OAUTH2_CLIENTS_WITH_DATA_MIGRATION_PRIVILEGE = tuple(
-        uuid.UUID(client_id) for client_id in
-        os.environ.get(
-            "OAUTH2_CLIENTS_WITH_DATA_MIGRATION_PRIVILEGE", "").split(","))
-except ValueError as _valerr:
-    OAUTH2_CLIENTS_WITH_DATA_MIGRATION_PRIVILEGE = tuple()
+AUTH_SERVER_URL = ""
diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py
index beb7169..3c7ce59 100644
--- a/tests/unit/test_db_utils.py
+++ b/tests/unit/test_db_utils.py
@@ -10,16 +10,73 @@ from gn3.db_utils import parse_db_url, database_connection
 @mock.patch("gn3.db_utils.parse_db_url")
 def test_database_connection(mock_db_parser, mock_sql):
     """test for creating database connection"""
-    mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None)
+    mock_db_parser.return_value = {
+        "host": "localhost",
+        "user": "guest",
+        "password": "4321",
+        "database": "users",
+        "port": 3306
+    }
 
+    mock_sql.Error = Exception
     with database_connection("mysql://guest:4321@localhost/users") as _conn:
         mock_sql.connect.assert_called_with(
             db="users", user="guest", passwd="4321", host="localhost",
             port=3306)
 
+
 @pytest.mark.unit_test
-def test_parse_db_url():
-    """test for parsing db_uri env variable"""
-    results = parse_db_url("mysql://username:4321@localhost/test")
-    expected_results = ("localhost", "username", "4321", "test", None)
-    assert results == expected_results
+@pytest.mark.parametrize(
+    "sql_uri,expected",
+    (("mysql://theuser:passwd@thehost:3306/thedb",
+      {
+          "host": "thehost",
+          "port": 3306,
+          "user": "theuser",
+          "password": "passwd",
+          "database": "thedb"
+      }),
+     (("mysql://auser:passwd@somehost:3307/thedb?"
+       "unix_socket=/run/mysqld/mysqld.sock&connect_timeout=30"),
+      {
+          "host": "somehost",
+          "port": 3307,
+          "user": "auser",
+          "password": "passwd",
+          "database": "thedb",
+          "unix_socket": "/run/mysqld/mysqld.sock",
+          "connect_timeout": 30
+      }),
+     ("mysql://guest:4321@localhost/users",
+      {
+          "host": "localhost",
+          "port": 3306,
+          "user": "guest",
+          "password": "4321",
+          "database": "users"
+      }),
+     ("mysql://localhost/users",
+      {
+          "host": "localhost",
+          "port": 3306,
+          "user": None,
+          "password": None,
+          "database": "users"
+      })))
+def test_parse_db_url(sql_uri, expected):
+    """Test that valid URIs are parsed into valid connection dicts"""
+    assert parse_db_url(sql_uri) == expected
+
+
+@pytest.mark.unit_test
+@pytest.mark.parametrize(
+    "sql_uri,invalidopt",
+    (("mysql://localhost/users?socket=/run/mysqld/mysqld.sock", "socket"),
+     ("mysql://localhost/users?connect_timeout=30&notavalidoption=value",
+      "notavalidoption")))
+def test_parse_db_url_with_invalid_options(sql_uri, invalidopt):
+    """Test that invalid options cause the function to raise an exception."""
+    with pytest.raises(AssertionError) as exc_info:
+        parse_db_url(sql_uri)
+
+    assert exc_info.value.args[0] == f"Invalid database connection option ({invalidopt}) provided."