diff options
-rw-r--r-- | gn3/api/async_commands.py | 4 | ||||
-rw-r--r-- | gn3/api/correlation.py | 4 | ||||
-rw-r--r-- | gn3/api/heatmaps.py | 6 | ||||
-rw-r--r-- | gn3/api/menu.py | 6 | ||||
-rw-r--r-- | gn3/commands.py | 5 | ||||
-rw-r--r-- | gn3/db_utils.py | 9 | ||||
-rw-r--r-- | scripts/partial_correlations.py | 8 | ||||
-rw-r--r-- | tests/integration/conftest.py | 9 | ||||
-rw-r--r-- | tests/integration/test_correlation.py | 10 | ||||
-rw-r--r-- | tests/performance/perf_query.py | 14 | ||||
-rw-r--r-- | tests/unit/test_db_utils.py | 63 |
11 files changed, 59 insertions, 79 deletions
diff --git a/gn3/api/async_commands.py b/gn3/api/async_commands.py index c0cf4bb..da146b8 100644 --- a/gn3/api/async_commands.py +++ b/gn3/api/async_commands.py @@ -11,6 +11,6 @@ def command_state(command_id): state = rconn.hgetall(name=command_id) if not state: return jsonify( - status=404, - error="The command id provided does not exist.") + error="The command id provided does not exist.", + status="error"), 404 return jsonify(dict(state.items())) diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py index 6e70899..eb4cc7d 100644 --- a/gn3/api/correlation.py +++ b/gn3/api/correlation.py @@ -9,7 +9,7 @@ from flask import request from flask import current_app from gn3.settings import SQL_URI -from gn3.db_utils import database_connector +from gn3.db_utils import database_connection from gn3.commands import run_sample_corr_cmd from gn3.responses.pcorrs_responses import build_response from gn3.commands import run_async_cmd, compose_pcorrs_command @@ -64,7 +64,7 @@ def compute_lit_corr(species=None, gene_id=None): might be needed for actual computing of the correlation results """ - with database_connector() as conn: + with database_connection(current_app.config["SQL_URI"]) as conn: target_traits_gene_ids = request.get_json() target_trait_gene_list = list(target_traits_gene_ids.items()) diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py index 80c8ca8..26c165f 100644 --- a/gn3/api/heatmaps.py +++ b/gn3/api/heatmaps.py @@ -5,9 +5,9 @@ Module to hold the entrypoint functions that generate heatmaps import io from flask import jsonify from flask import request -from flask import Blueprint +from flask import Blueprint, current_app from gn3.heatmaps import build_heatmap -from gn3.db_utils import database_connector +from gn3.db_utils import database_connection heatmaps = Blueprint("heatmaps", __name__) @@ -24,7 +24,7 @@ def clustered_heatmaps(): return jsonify({ "message": "You need to provide at least two trait names." }), 400 - with database_connector() as conn: + with database_connection(current_app.config["SQL_URI"]) as conn: def parse_trait_fullname(trait): name_parts = trait.split(":") return f"{name_parts[1]}::{name_parts[0]}" diff --git a/gn3/api/menu.py b/gn3/api/menu.py index cc77ab8..58b761e 100644 --- a/gn3/api/menu.py +++ b/gn3/api/menu.py @@ -1,14 +1,14 @@ """API for data used to generate menus""" -from flask import jsonify, Blueprint +from flask import jsonify, Blueprint, current_app from gn3.db.menu import gen_dropdown_json -from gn3.db_utils import database_connector +from gn3.db_utils import database_connection menu = Blueprint("menu", __name__) @menu.route("/generate/json") def generate_json(): """Get the menu in the JSON format""" - with database_connector() as conn: + with database_connection(current_app.config["SQL_URI"]) as conn: return jsonify(gen_dropdown_json(conn)) diff --git a/gn3/commands.py b/gn3/commands.py index 0e78fd2..a90e895 100644 --- a/gn3/commands.py +++ b/gn3/commands.py @@ -14,6 +14,8 @@ from typing import Tuple from typing import Union from typing import Sequence from uuid import uuid4 + +from flask import current_app from redis.client import Redis # Used only in type hinting from gn3.chancy import random_string @@ -80,7 +82,8 @@ def compose_pcorrs_command( prefix_cmd = ( f"{sys.executable}", "-m", "scripts.partial_correlations", - primary_trait, ",".join(control_traits), __parse_method__(method)) + primary_trait, ",".join(control_traits), __parse_method__(method), + current_app.config["SQL_URI"]) if ( kwargs.get("target_database") is not None and kwargs.get("target_traits") is None): diff --git a/gn3/db_utils.py b/gn3/db_utils.py index 4827358..e9db10f 100644 --- a/gn3/db_utils.py +++ b/gn3/db_utils.py @@ -17,15 +17,6 @@ def parse_db_url(sql_uri: str) -> Tuple: parsed_db.path[1:], parsed_db.port) -# This function is deprecated. Use database_connection instead. -def database_connector(sql_uri: str = "") -> mdb.Connection: - """function to create db connector""" - host, user, passwd, db_name, db_port = parse_db_url( - sql_uri or current_app.config["SQL_URI"]) - - return mdb.connect(host, user, passwd, db_name, port=(db_port or 3306)) - - # pylint: disable=missing-class-docstring, missing-function-docstring, too-few-public-methods class Connection(Protocol): """Type Annotation for MySQLdb's connection object""" diff --git a/scripts/partial_correlations.py b/scripts/partial_correlations.py index 1fbab78..aab8f08 100644 --- a/scripts/partial_correlations.py +++ b/scripts/partial_correlations.py @@ -4,7 +4,7 @@ import json import traceback from argparse import ArgumentParser -from gn3.db_utils import database_connector +from gn3.db_utils import database_connection from gn3.responses.pcorrs_responses import OutputEncoder from gn3.computations.partial_correlations import ( partial_correlations_with_target_db, @@ -108,6 +108,10 @@ def process_cli_arguments(): help="The correlation method to use", type=str, choices=("pearsons", "spearmans")) + parser.add_argument( + "sql_uri", + help="The uri to use to connect to the database", + type=str) against_db_parser(against_traits_parser( parser.add_subparsers( title="subcommands", @@ -119,7 +123,7 @@ def main(): """Entry point for the script""" args = process_cli_arguments() - with database_connector() as conn: + with database_connection(args.sql_uri) as conn: print(json.dumps(run_pcorrs(conn, args), cls=OutputEncoder)) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index c953d84..795c42d 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -4,7 +4,7 @@ import MySQLdb from gn3.app import create_app from gn3.chancy import random_string -from gn3.db_utils import parse_db_url, database_connector +from gn3.db_utils import parse_db_url, database_connection @pytest.fixture(scope="session") def client(): @@ -18,17 +18,18 @@ def client(): @pytest.fixture(scope="session") -def db_conn(): +def db_conn(client): """Create a db connection fixture for tests""" # 01) Generate random string to append to all test db artifacts for the session + live_db_uri = client.application.config["SQL_URI"] rand_str = random_string(15) - live_db_details = parse_db_url() + live_db_details = parse_db_url(live_db_uri) test_db_name = f"test_{live_db_details[3]}_{rand_str}" # # 02) Create new test db # Use context manager to ensure the live connection is automatically # closed on exit - with database_connector() as live_db_conn: + with database_connection(live_db_uri) as live_db_conn: with live_db_conn.cursor() as live_db_cursor: live_db_cursor.execute(f"CREATE DATABASE {test_db_name}") # diff --git a/tests/integration/test_correlation.py b/tests/integration/test_correlation.py index c1d518d..51769b5 100644 --- a/tests/integration/test_correlation.py +++ b/tests/integration/test_correlation.py @@ -65,15 +65,15 @@ class CorrelationIntegrationTest(TestCase): @pytest.mark.integration_test @mock.patch("gn3.api.correlation.compute_all_lit_correlation") - @mock.patch("gn3.api.correlation.database_connector") - def test_lit_correlation(self, database_connector, mock_compute_corr): + @mock.patch("gn3.api.correlation.database_connection") + def test_lit_correlation(self, database_connection, mock_compute_corr): """Test api/correlation/lit_corr/{species}/{gene_id}""" mock_compute_corr.return_value = [] - database_connector.return_value = mock.Mock() - database_connector.return_value.__enter__ = mock.Mock() - database_connector.return_value.__exit__ = mock.Mock() + database_connection.return_value = mock.Mock() + database_connection.return_value.__enter__ = mock.Mock() + database_connection.return_value.__exit__ = mock.Mock() post_data = {"1426678_at": "68031", "1426679_at": "68036", diff --git a/tests/performance/perf_query.py b/tests/performance/perf_query.py index e534e9b..cdac9a9 100644 --- a/tests/performance/perf_query.py +++ b/tests/performance/perf_query.py @@ -7,7 +7,7 @@ from inspect import getmembers from inspect import isfunction from functools import wraps -from gn3.db_utils import database_connector +from gn3.db_utils import database_connection def timer(func): @@ -26,9 +26,10 @@ def timer(func): def query_executor(query: str, + sql_uri: str, fetch_all: bool = True): """function to execute a query""" - with database_connector() as conn: + with database_connection(sql_uri) as conn: with conn.cursor() as cursor: cursor.execute(query) @@ -58,22 +59,22 @@ def fetch_probeset_query(dataset_name: str): @timer -def perf_hc_m2_dataset(): +def perf_hc_m2_dataset(sql_uri: str): """test the default dataset HC_M2_0606_P""" dataset_name = "HC_M2_0606_P" print(f"Performance test for {dataset_name}") - query_executor(fetch_probeset_query(dataset_name=dataset_name)) + query_executor(fetch_probeset_query(dataset_name=dataset_name), sql_uri) @timer -def perf_umutaffyexon_dataset(): +def perf_umutaffyexon_dataset(sql_uri): """largest dataset in gn""" dataset_name = "UMUTAffyExon_0209_RMA" print(f"Performance test for {dataset_name}") - query_executor(fetch_probeset_query(dataset_name=dataset_name)) + query_executor(fetch_probeset_query(dataset_name=dataset_name), sql_uri) def fetch_perf_functions(): @@ -102,6 +103,7 @@ def fetch_cmd_args(): if __name__ == '__main__': + # Figure out how to pass the database uri here... Maybe use click. func_list = fetch_cmd_args() for func_obj in func_list: func_obj() diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py index a319692..0211107 100644 --- a/tests/unit/test_db_utils.py +++ b/tests/unit/test_db_utils.py @@ -1,49 +1,28 @@ """module contains test for db_utils""" - -from unittest import TestCase from unittest import mock -from types import SimpleNamespace import pytest import gn3 -from gn3.db_utils import database_connector -from gn3.db_utils import parse_db_url - -@pytest.fixture(scope="class") -def setup_app(request, fxtr_app): - """Setup the fixtures for the class.""" - request.cls.app = fxtr_app - -class TestDatabase(TestCase): - """class contains testd for db connection functions""" - - @pytest.mark.unit_test - @mock.patch("gn3.db_utils.mdb") - @mock.patch("gn3.db_utils.parse_db_url") - def test_database_connector(self, mock_db_parser, mock_sql): - """test for creating database connection""" - mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None) - callable_cursor = lambda: SimpleNamespace(execute=3) - cursor_object = SimpleNamespace(cursor=callable_cursor) - mock_sql.connect.return_value = cursor_object - mock_sql.close.return_value = None - results = database_connector() - +from gn3.db_utils import parse_db_url, database_connection + +@pytest.mark.unit_test +@mock.patch("gn3.db_utils.mdb") +@mock.patch("gn3.db_utils.parse_db_url") +def test_database_connection(mock_db_parser, mock_sql): + """test for creating database connection""" + print(f"MOCK SQL: {mock_sql}") + print(f"MOCK DB PARSER: {mock_db_parser}") + mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None) + + with database_connection("mysql://guest:4321@localhost/users") as conn: mock_sql.connect.assert_called_with( - "localhost", "guest", "4321", "users", port=3306) - self.assertIsInstance( - results, SimpleNamespace, "database not created successfully") - - @pytest.mark.unit_test - @pytest.mark.usefixtures("setup_app") - def test_parse_db_url(self): - """test for parsing db_uri env variable""" - print(self.__dict__) - with self.app.app_context(), mock.patch.dict(# pylint: disable=[no-member] - gn3.db_utils.current_app.config, - {"SQL_URI": "mysql://username:4321@localhost/test"}, - clear=True): - results = parse_db_url() - expected_results = ("localhost", "username", "4321", "test", None) - self.assertEqual(results, expected_results) + db="users", user="guest", passwd="4321", host="localhost", + port=3306) + +@pytest.mark.unit_test +def test_parse_db_url(): + """test for parsing db_uri env variable""" + results = parse_db_url("mysql://username:4321@localhost/test") + expected_results = ("localhost", "username", "4321", "test", None) + assert results == expected_results |