From dbf0f9f0d34c9969aa6ae76f556745a9eb122106 Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Sat, 15 Apr 2023 19:35:53 +0300 Subject: Decouple `gn3.db_utils` from `flask.current_app`. Decouple the `gn3.db_utils` module from the global `flask.current_app` object, ensuring that the database uri value is passed in as a required argument to the `gn3.db_utils.database_connection` function. --- gn3/auth/authorisation/data/views.py | 8 ++++---- gn3/auth/authorisation/groups/views.py | 4 ++-- gn3/db_utils.py | 10 ++-------- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/gn3/auth/authorisation/data/views.py b/gn3/auth/authorisation/data/views.py index 1a4c031..e00df66 100644 --- a/gn3/auth/authorisation/data/views.py +++ b/gn3/auth/authorisation/data/views.py @@ -49,7 +49,7 @@ data = Blueprint("data", __name__) @data.route("species") def list_species() -> Response: """List all available species information.""" - with (gn3db.database_connection() as gn3conn, + with (gn3db.database_connection(app.config["SQL_URI"]) as gn3conn, gn3conn.cursor(DictCursor) as cursor): cursor.execute("SELECT * FROM Species") return jsonify(tuple(dict(row) for row in cursor.fetchall())) @@ -280,7 +280,7 @@ def migrate_users_data() -> Response: with (require_oauth.acquire("migrate-data") as the_token, db.connection(db_uri) as authconn, redis.Redis(decode_responses=True) as rconn, - gn3db.database_connection() as gn3conn): + gn3db.database_connection(app.config["SQL_URI"]) as gn3conn): if the_token.client.client_id in authorised_clients: by_user: dict[uuid.UUID, tuple[dict[str, str], ...]] = reduce( __org_by_user_id__, redis_resources(rconn), {}) @@ -315,7 +315,7 @@ def __search_mrna__(): query = __request_key__("query", "") limit = int(__request_key__("limit", 10000)) offset = int(__request_key__("offset", 0)) - with gn3db.database_connection() as gn3conn: + with gn3db.database_connection(app.config["SQL_URI"]) as gn3conn: __ungrouped__ = partial( ungrouped_mrna_data, gn3conn=gn3conn, search_query=query, selected=__request_key_list__("selected"), @@ -340,7 +340,7 @@ def __search_genotypes__(): query = __request_key__("query", "") limit = int(__request_key__("limit", 10000)) offset = int(__request_key__("offset", 0)) - with gn3db.database_connection() as gn3conn: + with gn3db.database_connection(app.config["SQL_URI"]) as gn3conn: __ungrouped__ = partial( ungrouped_genotype_data, gn3conn=gn3conn, search_query=query, selected=__request_key_list__("selected"), diff --git a/gn3/auth/authorisation/groups/views.py b/gn3/auth/authorisation/groups/views.py index d7a46c0..e933bcd 100644 --- a/gn3/auth/authorisation/groups/views.py +++ b/gn3/auth/authorisation/groups/views.py @@ -207,7 +207,7 @@ def ungrouped_data(dataset_type: str) -> Response: raise AuthorisationError(f"Invalid dataset type {dataset_type}") with require_oauth.acquire("profile group resource") as _the_token: - with gn3dbutils.database_connection() as gn3conn: + with gn3dbutils.database_connection(current_app.config["SQL_URI"]) as gn3conn: return jsonify(with_db_connection(partial( retrieve_ungrouped_data, gn3conn=gn3conn, dataset_type=dataset_type, @@ -226,7 +226,7 @@ def link_data() -> Response: raise InvalidData("Unexpected dataset type requested!") def __link__(conn: db.DbConnection): group = group_by_id(conn, group_id) - with gn3dbutils.database_connection() as gn3conn: + with gn3dbutils.database_connection(current_app.config["SQL_URI"]) as gn3conn: return link_data_to_group( conn, gn3conn, dataset_type, dataset_ids, group) diff --git a/gn3/db_utils.py b/gn3/db_utils.py index e9db10f..7d6a445 100644 --- a/gn3/db_utils.py +++ b/gn3/db_utils.py @@ -4,7 +4,6 @@ from typing import Any, Iterator, Protocol, Tuple from urllib.parse import urlparse import MySQLdb as mdb import xapian -from flask import current_app def parse_db_url(sql_uri: str) -> Tuple: @@ -25,15 +24,10 @@ class Connection(Protocol): ... -## We need to decouple current_app from this module and function, but since this -## function is used throughout the code, that will require careful work to update -## all the code to pass the `sql_uri` argument, and make it a compulsory argument -## rather than its current optional state. @contextlib.contextmanager -def database_connection(sql_uri: str = "") -> Iterator[Connection]: +def database_connection(sql_uri) -> Iterator[Connection]: """Connect to MySQL database.""" - host, user, passwd, db_name, port = parse_db_url( - sql_uri or current_app.config["SQL_URI"]) + host, user, passwd, db_name, port = parse_db_url(sql_uri) connection = mdb.connect(db=db_name, user=user, passwd=passwd or '', -- cgit v1.2.3