about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-03-08 08:00:16 +0300
committerFrederick Muriuki Muriithi2022-03-08 08:00:16 +0300
commiteae345ed252c01e541d64c7e5b60b488d84268c6 (patch)
tree819fe27db22b757da2eeafd438abe01ca8ab8cc8
parent84f51f48a59da93e287d793d983ace4d06ccb483 (diff)
downloadgenenetwork3-eae345ed252c01e541d64c7e5b60b488d84268c6.tar.gz
Create database connections within context managers
Use the `with` context manager to open database connections, so as to ensure
that those connections are closed once the call is completed. This hopefully
avoids the 'too many connections' error
-rw-r--r--gn3/api/correlation.py16
-rw-r--r--gn3/api/heatmaps.py20
-rw-r--r--gn3/db_utils.py4
-rwxr-xr-xscripts/partial_correlations.py26
4 files changed, 32 insertions, 34 deletions
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index 14c029c..f2ac4d7 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -68,17 +68,15 @@ def compute_lit_corr(species=None, gene_id=None):
     might be needed for actual computing of the correlation results
     """
 
-    conn, _cursor_object = database_connector()
-    target_traits_gene_ids = request.get_json()
-    target_trait_gene_list = list(target_traits_gene_ids.items())
+    with database_connector() as conn:
+        target_traits_gene_ids = request.get_json()
+        target_trait_gene_list = list(target_traits_gene_ids.items())
 
-    lit_corr_results = compute_all_lit_correlation(
-        conn=conn, trait_lists=target_trait_gene_list,
-        species=species, gene_id=gene_id)
+        lit_corr_results = compute_all_lit_correlation(
+            conn=conn, trait_lists=target_trait_gene_list,
+            species=species, gene_id=gene_id)
 
-    conn.close()
-
-    return jsonify(lit_corr_results)
+        return jsonify(lit_corr_results)
 
 
 @correlation.route("/tissue_corr/<string:corr_method>", methods=["POST"])
diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py
index b2511c3..80c8ca8 100644
--- a/gn3/api/heatmaps.py
+++ b/gn3/api/heatmaps.py
@@ -24,14 +24,14 @@ def clustered_heatmaps():
         return jsonify({
             "message": "You need to provide at least two trait names."
         }), 400
-    conn, _cursor = database_connector()
-    def parse_trait_fullname(trait):
-        name_parts = trait.split(":")
-        return f"{name_parts[1]}::{name_parts[0]}"
-    traits_fullnames = [parse_trait_fullname(trait) for trait in traits_names]
+    with database_connector() as conn:
+        def parse_trait_fullname(trait):
+            name_parts = trait.split(":")
+            return f"{name_parts[1]}::{name_parts[0]}"
+        traits_fullnames = [parse_trait_fullname(trait) for trait in traits_names]
 
-    with io.StringIO() as io_str:
-        figure = build_heatmap(traits_fullnames, conn, vertical=vertical)
-        figure.write_json(io_str)
-        fig_json = io_str.getvalue()
-    return fig_json, 200
+        with io.StringIO() as io_str:
+            figure = build_heatmap(traits_fullnames, conn, vertical=vertical)
+            figure.write_json(io_str)
+            fig_json = io_str.getvalue()
+        return fig_json, 200
diff --git a/gn3/db_utils.py b/gn3/db_utils.py
index 7263705..3703cbb 100644
--- a/gn3/db_utils.py
+++ b/gn3/db_utils.py
@@ -14,10 +14,10 @@ def parse_db_url() -> Tuple:
             parsed_db.password, parsed_db.path[1:])
 
 
-def database_connector() -> Tuple:
+def database_connector() -> mdb.Connection:
     """function to create db connector"""
     host, user, passwd, db_name = parse_db_url()
     conn = mdb.connect(host, user, passwd, db_name)
     cursor = conn.cursor()
 
-    return (conn, cursor)
+    return conn
diff --git a/scripts/partial_correlations.py b/scripts/partial_correlations.py
index ee442df..f203daa 100755
--- a/scripts/partial_correlations.py
+++ b/scripts/partial_correlations.py
@@ -35,19 +35,19 @@ def cleanup_string(the_str):
     return the_str.strip('"\t\n\r ')
 
 def run_partial_corrs(args):
-    try:
-        conn, _cursor_object = database_connector()
-        return partial_correlations_entry(
-            conn, cleanup_string(args.primary_trait),
-            tuple(cleanup_string(args.control_traits).split(",")),
-            cleanup_string(args.method), args.criteria,
-            cleanup_string(args.target_database))
-    except Exception as exc:
-        print(traceback.format_exc(), file=sys.stderr)
-        return {
-            "status": "exception",
-            "message": traceback.format_exc()
-        }
+    with database_connector() as conn:
+        try:
+            return partial_correlations_entry(
+                conn, cleanup_string(args.primary_trait),
+                tuple(cleanup_string(args.control_traits).split(",")),
+                cleanup_string(args.method), args.criteria,
+                cleanup_string(args.target_database))
+        except Exception as exc:
+            print(traceback.format_exc(), file=sys.stderr)
+            return {
+                "status": "exception",
+                "message": traceback.format_exc()
+            }
 
 def enter():
     args = process_cli_arguments()