aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gn3/auth/authorisation/data/views.py8
-rw-r--r--gn3/auth/authorisation/groups/views.py4
-rw-r--r--gn3/db_utils.py10
3 files changed, 8 insertions, 14 deletions
diff --git a/gn3/auth/authorisation/data/views.py b/gn3/auth/authorisation/data/views.py
index 1a4c031..e00df66 100644
--- a/gn3/auth/authorisation/data/views.py
+++ b/gn3/auth/authorisation/data/views.py
@@ -49,7 +49,7 @@ data = Blueprint("data", __name__)
@data.route("species")
def list_species() -> Response:
"""List all available species information."""
- with (gn3db.database_connection() as gn3conn,
+ with (gn3db.database_connection(app.config["SQL_URI"]) as gn3conn,
gn3conn.cursor(DictCursor) as cursor):
cursor.execute("SELECT * FROM Species")
return jsonify(tuple(dict(row) for row in cursor.fetchall()))
@@ -280,7 +280,7 @@ def migrate_users_data() -> Response:
with (require_oauth.acquire("migrate-data") as the_token,
db.connection(db_uri) as authconn,
redis.Redis(decode_responses=True) as rconn,
- gn3db.database_connection() as gn3conn):
+ gn3db.database_connection(app.config["SQL_URI"]) as gn3conn):
if the_token.client.client_id in authorised_clients:
by_user: dict[uuid.UUID, tuple[dict[str, str], ...]] = reduce(
__org_by_user_id__, redis_resources(rconn), {})
@@ -315,7 +315,7 @@ def __search_mrna__():
query = __request_key__("query", "")
limit = int(__request_key__("limit", 10000))
offset = int(__request_key__("offset", 0))
- with gn3db.database_connection() as gn3conn:
+ with gn3db.database_connection(app.config["SQL_URI"]) as gn3conn:
__ungrouped__ = partial(
ungrouped_mrna_data, gn3conn=gn3conn, search_query=query,
selected=__request_key_list__("selected"),
@@ -340,7 +340,7 @@ def __search_genotypes__():
query = __request_key__("query", "")
limit = int(__request_key__("limit", 10000))
offset = int(__request_key__("offset", 0))
- with gn3db.database_connection() as gn3conn:
+ with gn3db.database_connection(app.config["SQL_URI"]) as gn3conn:
__ungrouped__ = partial(
ungrouped_genotype_data, gn3conn=gn3conn, search_query=query,
selected=__request_key_list__("selected"),
diff --git a/gn3/auth/authorisation/groups/views.py b/gn3/auth/authorisation/groups/views.py
index d7a46c0..e933bcd 100644
--- a/gn3/auth/authorisation/groups/views.py
+++ b/gn3/auth/authorisation/groups/views.py
@@ -207,7 +207,7 @@ def ungrouped_data(dataset_type: str) -> Response:
raise AuthorisationError(f"Invalid dataset type {dataset_type}")
with require_oauth.acquire("profile group resource") as _the_token:
- with gn3dbutils.database_connection() as gn3conn:
+ with gn3dbutils.database_connection(current_app.config["SQL_URI"]) as gn3conn:
return jsonify(with_db_connection(partial(
retrieve_ungrouped_data, gn3conn=gn3conn,
dataset_type=dataset_type,
@@ -226,7 +226,7 @@ def link_data() -> Response:
raise InvalidData("Unexpected dataset type requested!")
def __link__(conn: db.DbConnection):
group = group_by_id(conn, group_id)
- with gn3dbutils.database_connection() as gn3conn:
+ with gn3dbutils.database_connection(current_app.config["SQL_URI"]) as gn3conn:
return link_data_to_group(
conn, gn3conn, dataset_type, dataset_ids, group)
diff --git a/gn3/db_utils.py b/gn3/db_utils.py
index e9db10f..7d6a445 100644
--- a/gn3/db_utils.py
+++ b/gn3/db_utils.py
@@ -4,7 +4,6 @@ from typing import Any, Iterator, Protocol, Tuple
from urllib.parse import urlparse
import MySQLdb as mdb
import xapian
-from flask import current_app
def parse_db_url(sql_uri: str) -> Tuple:
@@ -25,15 +24,10 @@ class Connection(Protocol):
...
-## We need to decouple current_app from this module and function, but since this
-## function is used throughout the code, that will require careful work to update
-## all the code to pass the `sql_uri` argument, and make it a compulsory argument
-## rather than its current optional state.
@contextlib.contextmanager
-def database_connection(sql_uri: str = "") -> Iterator[Connection]:
+def database_connection(sql_uri) -> Iterator[Connection]:
"""Connect to MySQL database."""
- host, user, passwd, db_name, port = parse_db_url(
- sql_uri or current_app.config["SQL_URI"])
+ host, user, passwd, db_name, port = parse_db_url(sql_uri)
connection = mdb.connect(db=db_name,
user=user,
passwd=passwd or '',