From 8e9abe8eccd1a95d34ab9a6bc7b92d1e660dcae7 Mon Sep 17 00:00:00 2001
From: Frederick Muriuki Muriithi
Date: Thu, 14 Dec 2023 19:21:14 +0300
Subject: Pass connection to `species_by_id` function.

To make `species_by_id` function reusable even outside of the
application context, pass in the database connection instead of
creating the connection inside the function.
---
 qc_app/dbinsert.py |  25 +++++------
 qc_app/samples.py  | 120 +++++++++++++++++++++++++++--------------------------
 2 files changed, 74 insertions(+), 71 deletions(-)

(limited to 'qc_app')

diff --git a/qc_app/dbinsert.py b/qc_app/dbinsert.py
index ab1c350..2282c8d 100644
--- a/qc_app/dbinsert.py
+++ b/qc_app/dbinsert.py
@@ -5,6 +5,7 @@ from typing import Union
 from functools import reduce
 from datetime import datetime
 
+import MySQLdb as mdb
 from redis import Redis
 from MySQLdb.cursors import DictCursor
 from flask import (
@@ -12,7 +13,7 @@ from flask import (
     current_app as app)
 
 from . import jobs
-from .db_utils import database_connection
+from .db_utils import with_db_connection, database_connection
 
 dbinsertbp = Blueprint("dbinsert", __name__)
 
@@ -41,17 +42,16 @@ def species() -> tuple:
 
     return tuple()
 
-def species_by_id(speciesid) -> Union[dict, None]:
+def species_by_id(conn: mdb.Connection, speciesid) -> Union[dict, None]:
     "Retrieve the species from the database by id."
-    with database_connection() as conn:
-        with conn.cursor(cursorclass=DictCursor) as cursor:
-            cursor.execute(
-                (
-                    "SELECT "
-                    "SpeciesId, SpeciesName, LOWER(Name) AS Name, MenuName "
-                    "FROM Species WHERE SpeciesId=%s"),
-                (speciesid,))
-            return cursor.fetchone()
+    with conn.cursor(cursorclass=DictCursor) as cursor:
+        cursor.execute(
+            (
+                "SELECT "
+                "SpeciesId, SpeciesName, LOWER(Name) AS Name, MenuName "
+                "FROM Species WHERE SpeciesId=%s"),
+            (speciesid,))
+        return cursor.fetchone()
 
 def genechips():
     "Retrieve the genechip information from the database"
@@ -362,7 +362,8 @@ def final_confirmation():
             filetype=form["filetype"], totallines=form["totallines"],
             species=speciesid, genechipid=genechipid, studyid=studyid,
             datasetid=datasetid, the_species=selected_keys(
-                species_by_id(speciesid), ("SpeciesName", "Name", "MenuName")),
+                with_db_connection(lambda conn: species_by_id(conn, speciesid)),
+                ("SpeciesName", "Name", "MenuName")),
             platform=selected_keys(
                 platform_by_id(genechipid),
                 ("GeneChipName", "Name", "GeoPlatform", "Title", "GO_tree_value")),
diff --git a/qc_app/samples.py b/qc_app/samples.py
index 1063fb8..dee08e5 100644
--- a/qc_app/samples.py
+++ b/qc_app/samples.py
@@ -17,8 +17,8 @@ from flask import (
 from quality_control.parsing import take
 
 from .files import save_file
-from .db_utils import with_db_connection
 from .dbinsert import species_by_id, groups_by_species
+from .db_utils import with_db_connection, database_connection
 
 samples = Blueprint("samples", __name__)
 
@@ -29,7 +29,8 @@ def select_species():
     species_id = request.form.get("species_id")
     if bool(species_id):
         species_id = int(species_id)
-        species = species_by_id(species_id)
+        species = with_db_connection(
+            lambda conn: species_by_id(conn, species_id))
         if bool(species):
             return render_template(
                 "samples/select-population.html",
@@ -72,40 +73,42 @@ def population_by_id(conn: mdb.Connection, population_id: int) -> dict:
 def create_population():
     """Create new grouping/population."""
     species_page = redirect(url_for("samples.select_species"), code=307)
-    species = species_by_id(request.form.get("species_id"))
-    pop_name = request.form.get("inbredset_name").strip()
-    pop_fullname = request.form.get("inbredset_fullname").strip()
-
-    if not bool(species):
-        flash("Invalid species!", "alert-error error-create-population")
-        return species_page
-    if (not bool(pop_name)) or (not bool(pop_fullname)):
-        flash("You *MUST* provide a grouping/population name",
-              "alert-error error-create-population")
-        return species_page
-
-    pop_id = with_db_connection(lambda conn: save_population(conn, {
-        "SpeciesId": species["SpeciesId"],
-        "Name": pop_name,
-        "InbredSetName": pop_fullname,
-        "FullName": pop_fullname,
-        "Family": request.form.get("inbredset_family") or None,
-        "Description": request.form.get("description") or None
-    }))
-    flash("Grouping/Population created successfully.", "alert-success")
-    return render_template(
-        "samples/upload-samples.html",
-        species=species,
-        population=with_db_connection(
-            lambda conn: population_by_id(conn, pop_id)))
+    with database_connection(app.config["SQL_URI"]) as conn:
+        species = species_by_id(conn, request.form.get("species_id"))
+        pop_name = request.form.get("inbredset_name").strip()
+        pop_fullname = request.form.get("inbredset_fullname").strip()
+
+        if not bool(species):
+            flash("Invalid species!", "alert-error error-create-population")
+            return species_page
+        if (not bool(pop_name)) or (not bool(pop_fullname)):
+            flash("You *MUST* provide a grouping/population name",
+                  "alert-error error-create-population")
+            return species_page
+
+        pop_id = save_population(conn, {
+            "SpeciesId": species["SpeciesId"],
+            "Name": pop_name,
+            "InbredSetName": pop_fullname,
+            "FullName": pop_fullname,
+            "Family": request.form.get("inbredset_family") or None,
+            "Description": request.form.get("description") or None
+        })
+        flash("Grouping/Population created successfully.", "alert-success")
+        return render_template(
+            "samples/upload-samples.html",
+            species=species,
+            population=with_db_connection(
+                lambda conn: population_by_id(conn, pop_id)))
 
 @samples.route("/upload/select-population", methods=["POST"])
 def select_population():
     """Select from existing groupings/populations."""
     species_page = redirect(url_for("samples.select_species"), code=307)
-    species = species_by_id(request.form.get("species_id"))
-    pop_id = int(request.form.get("inbredset_id"))
-    population = with_db_connection(lambda conn: population_by_id(conn, pop_id))
+    with database_connection(app.config["SQL_URI"]) as conn:
+        species = species_by_id(conn, request.form.get("species_id"))
+        pop_id = int(request.form.get("inbredset_id"))
+        population = with_db_connection(lambda conn: population_by_id(conn, pop_id))
 
     if not bool(species):
         flash("Invalid species!", "alert-error error-select-population")
@@ -195,33 +198,33 @@ def upload_samples():
     samples_uploads_page = redirect(url_for("samples.select_population"),
                                     code=307)
 
-    species = species_by_id(request.form.get("species_id"))
-    if not bool(species):
-        flash("Invalid species!", "alert-error")
-        return samples_uploads_page
-
-    population = with_db_connection(
-        lambda conn: population_by_id(
-            conn, int(request.form.get("inbredset_id"))))
-    if not bool(population):
-        flash("Invalid grouping/population!", "alert-error")
-        return samples_uploads_page
-
-    samples_file = save_file(request.files["samples_file"], Path(app.config["UPLOAD_FOLDER"]))
-    if not bool(samples_file):
-        flash("You need to provide a file with the samples data.")
-        return samples_uploads_page
-
-    firstlineheading = (request.form.get("first_line_heading") == "on")
-
-    separator = request.form.get("separator")
-    if separator == "other":
-        separator = request.form.get("other_separator")
-    if not bool(separator):
-        flash("You need to provide a separator character.", "alert-error")
-        return samples_uploads_page
+    with database_connection(app.config["SQL_URI"]) as conn:
+        species = species_by_id(conn, request.form.get("species_id"))
+        if not bool(species):
+            flash("Invalid species!", "alert-error")
+            return samples_uploads_page
+
+        population = with_db_connection(
+            lambda conn: population_by_id(
+                conn, int(request.form.get("inbredset_id"))))
+        if not bool(population):
+            flash("Invalid grouping/population!", "alert-error")
+            return samples_uploads_page
+
+        samples_file = save_file(request.files["samples_file"], Path(app.config["UPLOAD_FOLDER"]))
+        if not bool(samples_file):
+            flash("You need to provide a file with the samples data.")
+            return samples_uploads_page
+
+        firstlineheading = (request.form.get("first_line_heading") == "on")
+
+        separator = request.form.get("separator")
+        if separator == "other":
+            separator = request.form.get("other_separator")
+        if not bool(separator):
+            flash("You need to provide a separator character.", "alert-error")
+            return samples_uploads_page
 
-    def __insert_samples__(conn: mdb.Connection):
         save_samples_data(
             conn,
             species["SpeciesId"],
@@ -232,5 +235,4 @@ def upload_samples():
             population["InbredSetId"],
             (row["Name"] for row in read_samples_file(samples_file, separator, firstlineheading)))
 
-    with_db_connection(__insert_samples__)
-    return "SUCCESS: Respond with a better UI than this."
+        return "SUCCESS: Respond with a better UI than this."
-- 
cgit v1.2.3