aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gn3/api/async_commands.py4
-rw-r--r--gn3/api/correlation.py4
-rw-r--r--gn3/api/heatmaps.py6
-rw-r--r--gn3/api/menu.py6
-rw-r--r--gn3/commands.py5
-rw-r--r--gn3/db_utils.py9
-rw-r--r--scripts/partial_correlations.py8
-rw-r--r--tests/integration/conftest.py9
-rw-r--r--tests/integration/test_correlation.py10
-rw-r--r--tests/performance/perf_query.py14
-rw-r--r--tests/unit/test_db_utils.py63
11 files changed, 59 insertions, 79 deletions
diff --git a/gn3/api/async_commands.py b/gn3/api/async_commands.py
index c0cf4bb..da146b8 100644
--- a/gn3/api/async_commands.py
+++ b/gn3/api/async_commands.py
@@ -11,6 +11,6 @@ def command_state(command_id):
state = rconn.hgetall(name=command_id)
if not state:
return jsonify(
- status=404,
- error="The command id provided does not exist.")
+ error="The command id provided does not exist.",
+ status="error"), 404
return jsonify(dict(state.items()))
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index 6e70899..eb4cc7d 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -9,7 +9,7 @@ from flask import request
from flask import current_app
from gn3.settings import SQL_URI
-from gn3.db_utils import database_connector
+from gn3.db_utils import database_connection
from gn3.commands import run_sample_corr_cmd
from gn3.responses.pcorrs_responses import build_response
from gn3.commands import run_async_cmd, compose_pcorrs_command
@@ -64,7 +64,7 @@ def compute_lit_corr(species=None, gene_id=None):
might be needed for actual computing of the correlation results
"""
- with database_connector() as conn:
+ with database_connection(current_app.config["SQL_URI"]) as conn:
target_traits_gene_ids = request.get_json()
target_trait_gene_list = list(target_traits_gene_ids.items())
diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py
index 80c8ca8..26c165f 100644
--- a/gn3/api/heatmaps.py
+++ b/gn3/api/heatmaps.py
@@ -5,9 +5,9 @@ Module to hold the entrypoint functions that generate heatmaps
import io
from flask import jsonify
from flask import request
-from flask import Blueprint
+from flask import Blueprint, current_app
from gn3.heatmaps import build_heatmap
-from gn3.db_utils import database_connector
+from gn3.db_utils import database_connection
heatmaps = Blueprint("heatmaps", __name__)
@@ -24,7 +24,7 @@ def clustered_heatmaps():
return jsonify({
"message": "You need to provide at least two trait names."
}), 400
- with database_connector() as conn:
+ with database_connection(current_app.config["SQL_URI"]) as conn:
def parse_trait_fullname(trait):
name_parts = trait.split(":")
return f"{name_parts[1]}::{name_parts[0]}"
diff --git a/gn3/api/menu.py b/gn3/api/menu.py
index cc77ab8..58b761e 100644
--- a/gn3/api/menu.py
+++ b/gn3/api/menu.py
@@ -1,14 +1,14 @@
"""API for data used to generate menus"""
-from flask import jsonify, Blueprint
+from flask import jsonify, Blueprint, current_app
from gn3.db.menu import gen_dropdown_json
-from gn3.db_utils import database_connector
+from gn3.db_utils import database_connection
menu = Blueprint("menu", __name__)
@menu.route("/generate/json")
def generate_json():
"""Get the menu in the JSON format"""
- with database_connector() as conn:
+ with database_connection(current_app.config["SQL_URI"]) as conn:
return jsonify(gen_dropdown_json(conn))
diff --git a/gn3/commands.py b/gn3/commands.py
index 0e78fd2..a90e895 100644
--- a/gn3/commands.py
+++ b/gn3/commands.py
@@ -14,6 +14,8 @@ from typing import Tuple
from typing import Union
from typing import Sequence
from uuid import uuid4
+
+from flask import current_app
from redis.client import Redis # Used only in type hinting
from gn3.chancy import random_string
@@ -80,7 +82,8 @@ def compose_pcorrs_command(
prefix_cmd = (
f"{sys.executable}", "-m", "scripts.partial_correlations",
- primary_trait, ",".join(control_traits), __parse_method__(method))
+ primary_trait, ",".join(control_traits), __parse_method__(method),
+ current_app.config["SQL_URI"])
if (
kwargs.get("target_database") is not None
and kwargs.get("target_traits") is None):
diff --git a/gn3/db_utils.py b/gn3/db_utils.py
index 4827358..e9db10f 100644
--- a/gn3/db_utils.py
+++ b/gn3/db_utils.py
@@ -17,15 +17,6 @@ def parse_db_url(sql_uri: str) -> Tuple:
parsed_db.path[1:], parsed_db.port)
-# This function is deprecated. Use database_connection instead.
-def database_connector(sql_uri: str = "") -> mdb.Connection:
- """function to create db connector"""
- host, user, passwd, db_name, db_port = parse_db_url(
- sql_uri or current_app.config["SQL_URI"])
-
- return mdb.connect(host, user, passwd, db_name, port=(db_port or 3306))
-
-
# pylint: disable=missing-class-docstring, missing-function-docstring, too-few-public-methods
class Connection(Protocol):
"""Type Annotation for MySQLdb's connection object"""
diff --git a/scripts/partial_correlations.py b/scripts/partial_correlations.py
index 1fbab78..aab8f08 100644
--- a/scripts/partial_correlations.py
+++ b/scripts/partial_correlations.py
@@ -4,7 +4,7 @@ import json
import traceback
from argparse import ArgumentParser
-from gn3.db_utils import database_connector
+from gn3.db_utils import database_connection
from gn3.responses.pcorrs_responses import OutputEncoder
from gn3.computations.partial_correlations import (
partial_correlations_with_target_db,
@@ -108,6 +108,10 @@ def process_cli_arguments():
help="The correlation method to use",
type=str,
choices=("pearsons", "spearmans"))
+ parser.add_argument(
+ "sql_uri",
+ help="The uri to use to connect to the database",
+ type=str)
against_db_parser(against_traits_parser(
parser.add_subparsers(
title="subcommands",
@@ -119,7 +123,7 @@ def main():
"""Entry point for the script"""
args = process_cli_arguments()
- with database_connector() as conn:
+ with database_connection(args.sql_uri) as conn:
print(json.dumps(run_pcorrs(conn, args), cls=OutputEncoder))
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index c953d84..795c42d 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -4,7 +4,7 @@ import MySQLdb
from gn3.app import create_app
from gn3.chancy import random_string
-from gn3.db_utils import parse_db_url, database_connector
+from gn3.db_utils import parse_db_url, database_connection
@pytest.fixture(scope="session")
def client():
@@ -18,17 +18,18 @@ def client():
@pytest.fixture(scope="session")
-def db_conn():
+def db_conn(client):
"""Create a db connection fixture for tests"""
# 01) Generate random string to append to all test db artifacts for the session
+ live_db_uri = client.application.config["SQL_URI"]
rand_str = random_string(15)
- live_db_details = parse_db_url()
+ live_db_details = parse_db_url(live_db_uri)
test_db_name = f"test_{live_db_details[3]}_{rand_str}"
#
# 02) Create new test db
# Use context manager to ensure the live connection is automatically
# closed on exit
- with database_connector() as live_db_conn:
+ with database_connection(live_db_uri) as live_db_conn:
with live_db_conn.cursor() as live_db_cursor:
live_db_cursor.execute(f"CREATE DATABASE {test_db_name}")
#
diff --git a/tests/integration/test_correlation.py b/tests/integration/test_correlation.py
index c1d518d..51769b5 100644
--- a/tests/integration/test_correlation.py
+++ b/tests/integration/test_correlation.py
@@ -65,15 +65,15 @@ class CorrelationIntegrationTest(TestCase):
@pytest.mark.integration_test
@mock.patch("gn3.api.correlation.compute_all_lit_correlation")
- @mock.patch("gn3.api.correlation.database_connector")
- def test_lit_correlation(self, database_connector, mock_compute_corr):
+ @mock.patch("gn3.api.correlation.database_connection")
+ def test_lit_correlation(self, database_connection, mock_compute_corr):
"""Test api/correlation/lit_corr/{species}/{gene_id}"""
mock_compute_corr.return_value = []
- database_connector.return_value = mock.Mock()
- database_connector.return_value.__enter__ = mock.Mock()
- database_connector.return_value.__exit__ = mock.Mock()
+ database_connection.return_value = mock.Mock()
+ database_connection.return_value.__enter__ = mock.Mock()
+ database_connection.return_value.__exit__ = mock.Mock()
post_data = {"1426678_at": "68031",
"1426679_at": "68036",
diff --git a/tests/performance/perf_query.py b/tests/performance/perf_query.py
index e534e9b..cdac9a9 100644
--- a/tests/performance/perf_query.py
+++ b/tests/performance/perf_query.py
@@ -7,7 +7,7 @@ from inspect import getmembers
from inspect import isfunction
from functools import wraps
-from gn3.db_utils import database_connector
+from gn3.db_utils import database_connection
def timer(func):
@@ -26,9 +26,10 @@ def timer(func):
def query_executor(query: str,
+ sql_uri: str,
fetch_all: bool = True):
"""function to execute a query"""
- with database_connector() as conn:
+ with database_connection(sql_uri) as conn:
with conn.cursor() as cursor:
cursor.execute(query)
@@ -58,22 +59,22 @@ def fetch_probeset_query(dataset_name: str):
@timer
-def perf_hc_m2_dataset():
+def perf_hc_m2_dataset(sql_uri: str):
"""test the default dataset HC_M2_0606_P"""
dataset_name = "HC_M2_0606_P"
print(f"Performance test for {dataset_name}")
- query_executor(fetch_probeset_query(dataset_name=dataset_name))
+ query_executor(fetch_probeset_query(dataset_name=dataset_name), sql_uri)
@timer
-def perf_umutaffyexon_dataset():
+def perf_umutaffyexon_dataset(sql_uri):
"""largest dataset in gn"""
dataset_name = "UMUTAffyExon_0209_RMA"
print(f"Performance test for {dataset_name}")
- query_executor(fetch_probeset_query(dataset_name=dataset_name))
+ query_executor(fetch_probeset_query(dataset_name=dataset_name), sql_uri)
def fetch_perf_functions():
@@ -102,6 +103,7 @@ def fetch_cmd_args():
if __name__ == '__main__':
+ # Figure out how to pass the database uri here... Maybe use click.
func_list = fetch_cmd_args()
for func_obj in func_list:
func_obj()
diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py
index a319692..0211107 100644
--- a/tests/unit/test_db_utils.py
+++ b/tests/unit/test_db_utils.py
@@ -1,49 +1,28 @@
"""module contains test for db_utils"""
-
-from unittest import TestCase
from unittest import mock
-from types import SimpleNamespace
import pytest
import gn3
-from gn3.db_utils import database_connector
-from gn3.db_utils import parse_db_url
-
-@pytest.fixture(scope="class")
-def setup_app(request, fxtr_app):
- """Setup the fixtures for the class."""
- request.cls.app = fxtr_app
-
-class TestDatabase(TestCase):
- """class contains testd for db connection functions"""
-
- @pytest.mark.unit_test
- @mock.patch("gn3.db_utils.mdb")
- @mock.patch("gn3.db_utils.parse_db_url")
- def test_database_connector(self, mock_db_parser, mock_sql):
- """test for creating database connection"""
- mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None)
- callable_cursor = lambda: SimpleNamespace(execute=3)
- cursor_object = SimpleNamespace(cursor=callable_cursor)
- mock_sql.connect.return_value = cursor_object
- mock_sql.close.return_value = None
- results = database_connector()
-
+from gn3.db_utils import parse_db_url, database_connection
+
+@pytest.mark.unit_test
+@mock.patch("gn3.db_utils.mdb")
+@mock.patch("gn3.db_utils.parse_db_url")
+def test_database_connection(mock_db_parser, mock_sql):
+ """test for creating database connection"""
+ print(f"MOCK SQL: {mock_sql}")
+ print(f"MOCK DB PARSER: {mock_db_parser}")
+ mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None)
+
+ with database_connection("mysql://guest:4321@localhost/users") as conn:
mock_sql.connect.assert_called_with(
- "localhost", "guest", "4321", "users", port=3306)
- self.assertIsInstance(
- results, SimpleNamespace, "database not created successfully")
-
- @pytest.mark.unit_test
- @pytest.mark.usefixtures("setup_app")
- def test_parse_db_url(self):
- """test for parsing db_uri env variable"""
- print(self.__dict__)
- with self.app.app_context(), mock.patch.dict(# pylint: disable=[no-member]
- gn3.db_utils.current_app.config,
- {"SQL_URI": "mysql://username:4321@localhost/test"},
- clear=True):
- results = parse_db_url()
- expected_results = ("localhost", "username", "4321", "test", None)
- self.assertEqual(results, expected_results)
+ db="users", user="guest", passwd="4321", host="localhost",
+ port=3306)
+
+@pytest.mark.unit_test
+def test_parse_db_url():
+ """test for parsing db_uri env variable"""
+ results = parse_db_url("mysql://username:4321@localhost/test")
+ expected_results = ("localhost", "username", "4321", "test", None)
+ assert results == expected_results