about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- qc_app/files.py 12
-rw-r--r-- qc_app/samples.py 93
2 files changed, 69 insertions, 36 deletions
diff --git a/qc_app/files.py b/qc_app/files.py
index 6485b27..0304296 100644
--- a/qc_app/files.py
+++ b/qc_app/files.py
@@ -3,18 +3,16 @@ from pathlib import Path
 from typing import Union
 
 from werkzeug.utils import secure_filename
-from flask import (
-    request,
-    current_app as app)
+from werkzeug.datastructures import FileStorage
 
-def save_file(key: str, upload_dir: Path) -> Union[Path, bool]:
+def save_file(fileobj: FileStorage, upload_dir: Path) -> Union[Path, bool]:
     """Save the uploaded file and return the path."""
-    if not bool(request.files.get(key)):
+    if not bool(fileobj):
         return False
-    filename = Path(secure_filename(request.files[key].filename))
+    filename = Path(secure_filename(fileobj.filename)) # type: ignore[arg-type]
     if not upload_dir.exists():
         upload_dir.mkdir()
 
     filepath = Path(upload_dir, filename)
-    request.files["samples_file"].save(filepath)
+    fileobj.save(filepath)
     return filepath
diff --git a/qc_app/samples.py b/qc_app/samples.py
index 27fdad3..1063fb8 100644
--- a/qc_app/samples.py
+++ b/qc_app/samples.py
@@ -120,52 +120,74 @@ def select_population():
                            species=species,
                            population=population)
 
-def read_samples_file(filepath, separator: str, **kwargs) -> Iterator[dict]:
+def read_samples_file(filepath, separator: str, firstlineheading: bool, **kwargs) -> Iterator[dict]:
     """Read the samples file."""
     with open(filepath, "r", encoding="utf-8") as inputfile:
         reader = csv.DictReader(
             inputfile,
-            fieldnames=("Name", "Name2", "Symbol", "Alias"),
+            fieldnames=(
+                None if firstlineheading
+                else ("Name", "Name2", "Symbol", "Alias")),
             delimiter=separator,
-            quotechar=kwargs.get("quotechar"))
+            quotechar=kwargs.get("quotechar", '"'))
         for row in reader:
             yield row
 
-def save_samples_data(conn: mdb.Connection, file_data: Iterator[dict]):
+def save_samples_data(conn: mdb.Connection,
+                      speciesid: int,
+                      file_data: Iterator[dict]):
     """Save the samples to DB."""
+    data = ({**row, "SpeciesId": speciesid} for row in file_data)
     with conn.cursor() as cursor:
         while True:
+            batch = take(data, 5000)
+            if len(batch) == 0:
+                break
             cursor.executemany(
                 "INSERT INTO Strain(Name, Name2, SpeciesId, Symbol, Alias) "
                 "VALUES("
                 "    %(Name)s, %(Name2)s, %(SpeciesId)s, %(Symbol)s, %(Alias)s"
-                ")",
-                tuple(take(file_data, 10000)))
+                ") ON DUPLICATE KEY UPDATE Name=Name",
+                batch)
 
 def cross_reference_samples(conn: mdb.Connection,
+                            species_id: int,
                             population_id: int,
-                            strain_names: tuple[str, ...]):
+                            strain_names: Iterator[str]):
     """Link samples to their population."""
     with conn.cursor(cursorclass=DictCursor) as cursor:
-        params_str = ", ".join(["%s"] * len(strain_names))
-        cursor.execute(
-            "SELECT Id FROM Strain WHERE (Name, SpeciesId) IN "
-            f"{params_str}",
-            tuple((name, species["SpeciesId"]) for name in strain_names))
-        strain_ids = (sid for sid in cursor.fetchall())
-        cursor.execute(
-            "SELECT MAX(OrderId) AS loid FROM StrainXRef WHERE InbredSetId=%s",
-            (population_id,))
-        last_order_id = cursor.fetchone()["loid"]
-        cursor.executemany(
-            "INSERT INTO StrainXRef(InbredSetId, StrainId, OrderId) "
-            "VALUES (%(pop_id)s, %(strain_id)s, %(order_id)s)",
-            tuple({
-                "pop_id": population_id,
-                "strain_id": strain_id,
-                "order_id": order_id
-            } for order_id, strain_id in
-                  enumerate(strain_ids, start=(last_order_id+10))))
+        while True:
+            batch = take(strain_names, 5000)
+            if len(batch) == 0:
+                break
+            params_str = ", ".join(["%s"] * len(batch))
+            ## This query is slow -- investigate.
+            cursor.execute(
+                "SELECT s.Id FROM Strain AS s LEFT JOIN StrainXRef AS sx "
+                "ON s.Id = sx.StrainId WHERE s.SpeciesId=%s AND s.Name IN "
+                f"({params_str}) AND sx.StrainId IS NULL",
+                (species_id,) + tuple(batch))
+            strain_ids = (sid["Id"] for sid in cursor.fetchall())
+            cursor.execute(
+                "SELECT MAX(OrderId) AS loid FROM StrainXRef WHERE InbredSetId=%s",
+                (population_id,))
+            last_order_id = cursor.fetchone()["loid"]
+            cursor.executemany(
+                "INSERT INTO StrainXRef( "
+                "  InbredSetId, StrainId, OrderId, Used_for_mapping, PedigreeStatus"
+                ")"
+                "VALUES ("
+                "  %(pop_id)s, %(strain_id)s, %(order_id)s, %(mapping)s, "
+                "  %(pedigree)s"
+                ")",
+                tuple({
+                    "pop_id": population_id,
+                    "strain_id": strain_id,
+                    "order_id": order_id,
+                    "mapping": "N",
+                    "pedigree": None
+                } for order_id, strain_id in
+                      enumerate(strain_ids, start=(last_order_id+10))))
 
 @samples.route("/upload/samples", methods=["POST"])
 def upload_samples():
@@ -185,17 +207,30 @@ def upload_samples():
         flash("Invalid grouping/population!", "alert-error")
         return samples_uploads_page
 
-    samples_file = save_file("samples_file", Path(app.config["UPLOAD_FOLDER"]))
+    samples_file = save_file(request.files["samples_file"], Path(app.config["UPLOAD_FOLDER"]))
     if not bool(samples_file):
         flash("You need to provide a file with the samples data.")
         return samples_uploads_page
 
+    firstlineheading = (request.form.get("first_line_heading") == "on")
+
+    separator = request.form.get("separator")
+    if separator == "other":
+        separator = request.form.get("other_separator")
+    if not bool(separator):
+        flash("You need to provide a separator character.", "alert-error")
+        return samples_uploads_page
+
     def __insert_samples__(conn: mdb.Connection):
-        save_samples_data(conn, read_samples_file(samples_file))
+        save_samples_data(
+            conn,
+            species["SpeciesId"],
+            read_samples_file(samples_file, separator, firstlineheading))
         cross_reference_samples(
             conn,
+            species["SpeciesId"],
             population["InbredSetId"],
-            tuple(row["Name"] for row in read_samples_file(samples_file)))
+            (row["Name"] for row in read_samples_file(samples_file, separator, firstlineheading)))
 
     with_db_connection(__insert_samples__)
     return "SUCCESS: Respond with a better UI than this."