diff options
| -rw-r--r-- | gn3/app.py | 34 | ||||
| -rw-r--r-- | gn3/case_attributes.py | 22 | ||||
| -rw-r--r-- | gn3/db_utils.py | 75 | ||||
| -rw-r--r-- | gn3/settings.py | 35 | ||||
| -rw-r--r-- | tests/unit/test_db_utils.py | 69 |
5 files changed, 166 insertions, 69 deletions
diff --git a/gn3/app.py b/gn3/app.py index e9a2bbe..eb9e5d3 100644 --- a/gn3/app.py +++ b/gn3/app.py @@ -1,4 +1,5 @@ """Entry point from spinning up flask""" + import os import sys import logging @@ -29,6 +30,32 @@ from gn3.api.llm import gnqa from gn3.case_attributes import caseattr +class ConfigurationError(Exception): + """Raised in case of a configuration error.""" + + +def verify_app_config(app: Flask) -> None: + """Verify that configuration variables are as expected + It includes: + 1. making sure mandatory settings are defined + 2. provides examples for what to set as config variables (helps local dev) + """ + app_config = { + "AUTH_SERVER_URL": """AUTH_SERVER_URL is used for api requests that need login. + For local dev, use the running auth server url, which defaults to http://127.0.0.1:8081 + """, + } + error_message = [] + + for setting, err in app_config.items(): + print(f"{setting}: {app.config.get(setting)}") + if setting in app.config and bool(app.config[setting]): + continue + error_message.append(err) + if error_message: + raise ConfigurationError("\n".join(error_message)) + + def create_app(config: Union[Dict, str, None] = None) -> Flask: """Create a new flask object""" app = Flask(__name__) @@ -37,7 +64,7 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask: # Load environment configuration if "GN3_CONF" in os.environ: - app.config.from_envvar('GN3_CONF') + app.config.from_envvar("GN3_CONF") # Load app specified configuration if config is not None: @@ -51,6 +78,7 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask: if secrets_file and Path(secrets_file).exists(): app.config.from_envvar("GN3_SECRETS") # END: SECRETS + verify_app_config(app) setup_app_handlers(app) # DO NOT log anything before this point logging.info("Guix Profile: '%s'.", os.environ.get("GUIX_PROFILE")) @@ -60,7 +88,9 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask: app, origins=app.config["CORS_ORIGINS"], 
allow_headers=app.config["CORS_HEADERS"], - supports_credentials=True, intercept_exceptions=False) + supports_credentials=True, + intercept_exceptions=False, + ) app.register_blueprint(general, url_prefix="/api/") app.register_blueprint(gemma, url_prefix="/api/gemma") diff --git a/gn3/case_attributes.py b/gn3/case_attributes.py index 9baff1e..2c878d2 100644 --- a/gn3/case_attributes.py +++ b/gn3/case_attributes.py @@ -77,7 +77,7 @@ def required_access( result = requests.get( # this section fetches the resource ID from the auth server urljoin(current_app.config["AUTH_SERVER_URL"], - "auth/resource/inbredset/resource-id" + "auth/resource/populations/resource-id" f"/{__species_id__(conn)}/{inbredset_id}")) if result.status_code == 200: resource_id = result.json()["resource-id"] @@ -186,7 +186,7 @@ def __case_attribute_values_by_inbred_set__( "ON ca.CaseAttributeId=caxrn.CaseAttributeId " "INNER JOIN Strain AS s " "ON caxrn.StrainId=s.Id " - "WHERE ca.InbredSetId=%(inbredset_id)s " + "WHERE caxrn.InbredSetId=%(inbredset_id)s " "ORDER BY StrainName", {"inbredset_id": inbredset_id}) return tuple( @@ -203,7 +203,7 @@ def __process_orig_data__(fieldnames, cadata, strains) -> tuple[dict, ...]: data = {item["StrainName"]: item for item in cadata} return tuple( { - "Strain": strain["Name"], + "Sample": strain["Name"], **{ key: data.get( strain["Name"], {}).get("case-attributes", {}).get(key, "") @@ -216,7 +216,7 @@ def __process_edit_data__(fieldnames, form_data) -> tuple[dict, ...]: def __process__(acc, strain_cattrs): strain, cattrs = strain_cattrs return acc + ({ - "Strain": strain, **{ + "Sample": strain, **{ field: cattrs["case-attributes"].get(field, "") for field in fieldnames[1:] } @@ -327,19 +327,19 @@ def __apply_additions__( def __apply_modifications__( cursor, inbredset_id: int, modifications_diff, fieldnames) -> None: """Apply modifications: changes values of existing case attributes.""" - cattrs = tuple(field for field in fieldnames if field != "Strain") + 
cattrs = tuple(field for field in fieldnames if field != "Sample") def __retrieve_changes__(acc, row): orig = dict(zip(fieldnames, row["Original"].split(","))) new = dict(zip(fieldnames, row["Current"].split(","))) return acc + tuple({ - "Strain": new["Strain"], + "Sample": new["Sample"], cattr: new[cattr] } for cattr in cattrs if new[cattr] != orig[cattr]) new_rows: tuple[dict, ...] = reduce( __retrieve_changes__, modifications_diff, tuple()) - strain_names = tuple({row["Strain"] for row in new_rows}) + strain_names = tuple({row["Sample"] for row in new_rows}) cursor.execute("SELECT Id AS StrainId, Name AS StrainName FROM Strain " f"WHERE Name IN ({', '.join(['%s'] * len(strain_names))})", strain_names) @@ -364,7 +364,7 @@ def __apply_modifications__( tuple( { "isetid": inbredset_id, - "strainid": strain_ids[row["Strain"]], + "strainid": strain_ids[row["Sample"]], "cattrid": cattr_ids[cattr], "value": row[cattr] } @@ -377,11 +377,11 @@ def __apply_modifications__( tuple( { "isetid": inbredset_id, - "strainid": strain_ids[row["Strain"]], + "strainid": strain_ids[row["Sample"]], "cattrid": cattr_ids[cattr] } for row in new_rows - for cattr in (key for key in row.keys() if key != "Strain") + for cattr in (key for key in row.keys() if key != "Sample") if not bool(row[cattr].strip()))) def __apply_deletions__( @@ -471,7 +471,7 @@ def edit_case_attributes(inbredset_id: int, auth_token = None) -> Response: required_access(auth_token, inbredset_id, ("system:inbredset:edit-case-attribute",)) - fieldnames = tuple(["Strain"] + sorted( + fieldnames = tuple(["Sample"] + sorted( attr["Name"] for attr in __case_attribute_labels_by_inbred_set__(conn, inbredset_id))) try: diff --git a/gn3/db_utils.py b/gn3/db_utils.py index 0d9bd0a..7004590 100644 --- a/gn3/db_utils.py +++ b/gn3/db_utils.py @@ -1,23 +1,71 @@ """module contains all db related stuff""" -import contextlib import logging -from typing import Any, Iterator, Protocol, Tuple +import contextlib from urllib.parse import 
urlparse -import MySQLdb as mdb +from typing import Any, Iterator, Protocol, Callable + import xapian +import MySQLdb as mdb LOGGER = logging.getLogger(__file__) -def parse_db_url(sql_uri: str) -> Tuple: - """function to parse SQL_URI env variable note:there\ - is a default value for SQL_URI so a tuple result is\ - always expected""" +def __check_true__(val: str) -> bool: + """Check whether the variable 'val' has the string value `true`.""" + return val.strip().lower() == "true" + + +def __parse_db_opts__(opts: str) -> dict: + """Parse database options into their appropriate values. + + This assumes use of python-mysqlclient library.""" + allowed_opts = ( + "unix_socket", "connect_timeout", "compress", "named_pipe", + "init_command", "read_default_file", "read_default_group", + "cursorclass", "use_unicode", "charset", "collation", "auth_plugin", + "sql_mode", "client_flag", "multi_statements", "ssl_mode", "ssl", + "local_infile", "autocommit", "binary_prefix") + conversion_fns: dict[str, Callable] = { + **{opt: str for opt in allowed_opts}, + "connect_timeout": int, + "compress": __check_true__, + "use_unicode": __check_true__, + # "cursorclass": __load_cursor_class__ + "client_flag": int, + "multi_statements": __check_true__, + # "ssl": __parse_ssl_options__, + "local_infile": __check_true__, + "autocommit": __check_true__, + "binary_prefix": __check_true__ + } + queries = tuple(filter(bool, opts.split("&"))) + if len(queries) > 0: + keyvals: tuple[tuple[str, ...], ...] 
= tuple( + tuple(item.strip() for item in query.split("=")) + for query in queries) + def __check_opt__(opt): + assert opt in allowed_opts, ( + f"Invalid database connection option ({opt}) provided.") + return opt + return { + __check_opt__(key): conversion_fns[key](val) + for key, val in keyvals + } + return {} + + +def parse_db_url(sql_uri: str) -> dict: + """Parse the `sql_uri` variable into a dict of connection parameters.""" parsed_db = urlparse(sql_uri) - return ( - parsed_db.hostname, parsed_db.username, parsed_db.password, - parsed_db.path[1:], parsed_db.port) + return { + "host": parsed_db.hostname, + "port": parsed_db.port or 3306, + "user": parsed_db.username, + "password": parsed_db.password, + "database": parsed_db.path.strip("/").strip(), + **__parse_db_opts__(parsed_db.query) + } # pylint: disable=missing-class-docstring, missing-function-docstring, too-few-public-methods @@ -30,12 +78,7 @@ class Connection(Protocol): @contextlib.contextmanager def database_connection(sql_uri: str, logger: logging.Logger = LOGGER) -> Iterator[Connection]: """Connect to MySQL database.""" - host, user, passwd, db_name, port = parse_db_url(sql_uri) - connection = mdb.connect(db=db_name, - user=user, - passwd=passwd or '', - host=host, - port=port or 3306) + connection = mdb.connect(**parse_db_url(sql_uri)) try: yield connection except mdb.Error as _mbde: diff --git a/gn3/settings.py b/gn3/settings.py index 439d88c..9a3f7eb 100644 --- a/gn3/settings.py +++ b/gn3/settings.py @@ -5,10 +5,8 @@ DO NOT import from this file, use `flask.current_app.config` instead to get the application settings. """ import os -import uuid import tempfile -BCRYPT_SALT = "$2b$12$mxLvu9XRLlIaaSeDxt8Sle" # Change this! 
DATA_DIR = "" GEMMA_WRAPPER_CMD = os.environ.get("GEMMA_WRAPPER", "gemma-wrapper") CACHEDIR = "" @@ -29,14 +27,10 @@ LMDB_PATH = os.environ.get( SQL_URI = os.environ.get( "SQL_URI", "mysql://webqtlout:webqtlout@localhost/db_webqtl") SECRET_KEY = "password" -# gn2 results only used in fetching dataset info - # FAHAMU API TOKEN FAHAMU_AUTH_TOKEN = "" -GN2_BASE_URL = "http://www.genenetwork.org/" - # wgcna script WGCNA_RSCRIPT = "wgcna_analysis.R" # qtlreaper command @@ -83,31 +77,4 @@ ROUND_TO = 10 MULTIPROCESSOR_PROCS = 6 # Number of processes to spawn -AUTH_SERVER_URL = "https://auth.genenetwork.org" -AUTH_MIGRATIONS = "migrations/auth" -OAUTH2_SCOPE = ( - "profile", "group", "role", "resource", "user", "masquerade", - "introspect") - - -try: - # *** SECURITY CONCERN *** - # Clients with access to this privileges create a security concern. - # Be careful when adding to this configuration - OAUTH2_CLIENTS_WITH_INTROSPECTION_PRIVILEGE = tuple( - uuid.UUID(client_id) for client_id in - os.environ.get( - "OAUTH2_CLIENTS_WITH_INTROSPECTION_PRIVILEGE", "").split(",")) -except ValueError as _valerr: - OAUTH2_CLIENTS_WITH_INTROSPECTION_PRIVILEGE = tuple() - -try: - # *** SECURITY CONCERN *** - # Clients with access to this privileges create a security concern. 
- # Be careful when adding to this configuration - OAUTH2_CLIENTS_WITH_DATA_MIGRATION_PRIVILEGE = tuple( - uuid.UUID(client_id) for client_id in - os.environ.get( - "OAUTH2_CLIENTS_WITH_DATA_MIGRATION_PRIVILEGE", "").split(",")) -except ValueError as _valerr: - OAUTH2_CLIENTS_WITH_DATA_MIGRATION_PRIVILEGE = tuple() +AUTH_SERVER_URL = "" diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py index beb7169..3c7ce59 100644 --- a/tests/unit/test_db_utils.py +++ b/tests/unit/test_db_utils.py @@ -10,16 +10,73 @@ from gn3.db_utils import parse_db_url, database_connection @mock.patch("gn3.db_utils.parse_db_url") def test_database_connection(mock_db_parser, mock_sql): """test for creating database connection""" - mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None) + mock_db_parser.return_value = { + "host": "localhost", + "user": "guest", + "password": "4321", + "database": "users", + "port": 3306 + } + mock_sql.Error = Exception with database_connection("mysql://guest:4321@localhost/users") as _conn: mock_sql.connect.assert_called_with( db="users", user="guest", passwd="4321", host="localhost", port=3306) + @pytest.mark.unit_test -def test_parse_db_url(): - """test for parsing db_uri env variable""" - results = parse_db_url("mysql://username:4321@localhost/test") - expected_results = ("localhost", "username", "4321", "test", None) - assert results == expected_results +@pytest.mark.parametrize( + "sql_uri,expected", + (("mysql://theuser:passwd@thehost:3306/thedb", + { + "host": "thehost", + "port": 3306, + "user": "theuser", + "password": "passwd", + "database": "thedb" + }), + (("mysql://auser:passwd@somehost:3307/thedb?" 
+ "unix_socket=/run/mysqld/mysqld.sock&connect_timeout=30"), + { + "host": "somehost", + "port": 3307, + "user": "auser", + "password": "passwd", + "database": "thedb", + "unix_socket": "/run/mysqld/mysqld.sock", + "connect_timeout": 30 + }), + ("mysql://guest:4321@localhost/users", + { + "host": "localhost", + "port": 3306, + "user": "guest", + "password": "4321", + "database": "users" + }), + ("mysql://localhost/users", + { + "host": "localhost", + "port": 3306, + "user": None, + "password": None, + "database": "users" + }))) +def test_parse_db_url(sql_uri, expected): + """Test that valid URIs are passed into valid connection dicts""" + assert parse_db_url(sql_uri) == expected + + +@pytest.mark.unit_test +@pytest.mark.parametrize( + "sql_uri,invalidopt", + (("mysql://localhost/users?socket=/run/mysqld/mysqld.sock", "socket"), + ("mysql://localhost/users?connect_timeout=30&notavalidoption=value", + "notavalidoption"))) +def test_parse_db_url_with_invalid_options(sql_uri, invalidopt): + """Test that invalid options cause the function to raise an exception.""" + with pytest.raises(AssertionError) as exc_info: + parse_db_url(sql_uri) + + assert exc_info.value.args[0] == f"Invalid database connection option ({invalidopt}) provided." |
