about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/auth/authorisation/data/views.py8
-rw-r--r--gn3/auth/authorisation/groups/views.py4
-rw-r--r--gn3/db_utils.py10
3 files changed, 8 insertions, 14 deletions
diff --git a/gn3/auth/authorisation/data/views.py b/gn3/auth/authorisation/data/views.py
index 1a4c031..e00df66 100644
--- a/gn3/auth/authorisation/data/views.py
+++ b/gn3/auth/authorisation/data/views.py
@@ -49,7 +49,7 @@ data = Blueprint("data", __name__)
 @data.route("species")
 def list_species() -> Response:
     """List all available species information."""
-    with (gn3db.database_connection() as gn3conn,
+    with (gn3db.database_connection(app.config["SQL_URI"]) as gn3conn,
           gn3conn.cursor(DictCursor) as cursor):
         cursor.execute("SELECT * FROM Species")
         return jsonify(tuple(dict(row) for row in cursor.fetchall()))
@@ -280,7 +280,7 @@ def migrate_users_data() -> Response:
         with (require_oauth.acquire("migrate-data") as the_token,
               db.connection(db_uri) as authconn,
               redis.Redis(decode_responses=True) as rconn,
-              gn3db.database_connection() as gn3conn):
+              gn3db.database_connection(app.config["SQL_URI"]) as gn3conn):
             if the_token.client.client_id in authorised_clients:
                 by_user: dict[uuid.UUID, tuple[dict[str, str], ...]] = reduce(
                     __org_by_user_id__, redis_resources(rconn), {})
@@ -315,7 +315,7 @@ def __search_mrna__():
     query = __request_key__("query", "")
     limit = int(__request_key__("limit", 10000))
     offset = int(__request_key__("offset", 0))
-    with gn3db.database_connection() as gn3conn:
+    with gn3db.database_connection(app.config["SQL_URI"]) as gn3conn:
         __ungrouped__ = partial(
             ungrouped_mrna_data, gn3conn=gn3conn, search_query=query,
             selected=__request_key_list__("selected"),
@@ -340,7 +340,7 @@ def __search_genotypes__():
     query = __request_key__("query", "")
     limit = int(__request_key__("limit", 10000))
     offset = int(__request_key__("offset", 0))
-    with gn3db.database_connection() as gn3conn:
+    with gn3db.database_connection(app.config["SQL_URI"]) as gn3conn:
         __ungrouped__ = partial(
             ungrouped_genotype_data, gn3conn=gn3conn, search_query=query,
             selected=__request_key_list__("selected"),
diff --git a/gn3/auth/authorisation/groups/views.py b/gn3/auth/authorisation/groups/views.py
index d7a46c0..e933bcd 100644
--- a/gn3/auth/authorisation/groups/views.py
+++ b/gn3/auth/authorisation/groups/views.py
@@ -207,7 +207,7 @@ def ungrouped_data(dataset_type: str) -> Response:
         raise AuthorisationError(f"Invalid dataset type {dataset_type}")
 
     with require_oauth.acquire("profile group resource") as _the_token:
-        with gn3dbutils.database_connection() as gn3conn:
+        with gn3dbutils.database_connection(current_app.config["SQL_URI"]) as gn3conn:
             return jsonify(with_db_connection(partial(
                 retrieve_ungrouped_data, gn3conn=gn3conn,
                 dataset_type=dataset_type,
@@ -226,7 +226,7 @@ def link_data() -> Response:
             raise InvalidData("Unexpected dataset type requested!")
         def __link__(conn: db.DbConnection):
             group = group_by_id(conn, group_id)
-            with gn3dbutils.database_connection() as gn3conn:
+            with gn3dbutils.database_connection(current_app.config["SQL_URI"]) as gn3conn:
                 return link_data_to_group(
                     conn, gn3conn, dataset_type, dataset_ids, group)
 
diff --git a/gn3/db_utils.py b/gn3/db_utils.py
index e9db10f..7d6a445 100644
--- a/gn3/db_utils.py
+++ b/gn3/db_utils.py
@@ -4,7 +4,6 @@ from typing import Any, Iterator, Protocol, Tuple
 from urllib.parse import urlparse
 import MySQLdb as mdb
 import xapian
-from flask import current_app
 
 
 def parse_db_url(sql_uri: str) -> Tuple:
@@ -25,15 +24,10 @@ class Connection(Protocol):
         ...
 
 
-## We need to decouple current_app from this module and function, but since this
-## function is used throughout the code, that will require careful work to update
-## all the code to pass the `sql_uri` argument, and make it a compulsory argument
-## rather than its current optional state.
 @contextlib.contextmanager
-def database_connection(sql_uri: str = "") -> Iterator[Connection]:
+def database_connection(sql_uri) -> Iterator[Connection]:
     """Connect to MySQL database."""
-    host, user, passwd, db_name, port = parse_db_url(
-        sql_uri or current_app.config["SQL_URI"])
+    host, user, passwd, db_name, port = parse_db_url(sql_uri)
     connection = mdb.connect(db=db_name,
                              user=user,
                              passwd=passwd or '',