about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- qc_app/samples.py 25
-rw-r--r-- scripts/insert_samples.py 147
2 files changed, 160 insertions, 12 deletions
diff --git a/qc_app/samples.py b/qc_app/samples.py
index dee08e5..88a0fde 100644
--- a/qc_app/samples.py
+++ b/qc_app/samples.py
@@ -159,6 +159,10 @@ def cross_reference_samples(conn: mdb.Connection,
                             strain_names: Iterator[str]):
     """Link samples to their population."""
     with conn.cursor(cursorclass=DictCursor) as cursor:
+        cursor.execute(
+            "SELECT MAX(OrderId) AS loid FROM StrainXRef WHERE InbredSetId=%s",
+            (population_id,))
+        last_order_id = cursor.fetchone()["loid"]
         while True:
             batch = take(strain_names, 5000)
             if len(batch) == 0:
@@ -171,10 +175,13 @@ def cross_reference_samples(conn: mdb.Connection,
                 f"({params_str}) AND sx.StrainId IS NULL",
                 (species_id,) + tuple(batch))
             strain_ids = (sid["Id"] for sid in cursor.fetchall())
-            cursor.execute(
-                "SELECT MAX(OrderId) AS loid FROM StrainXRef WHERE InbredSetId=%s",
-                (population_id,))
-            last_order_id = cursor.fetchone()["loid"]
+            params = tuple({
+                "pop_id": population_id,
+                "strain_id": strain_id,
+                "order_id": last_order_id + (order_id * 10),
+                "mapping": "N",
+                "pedigree": None
+            } for order_id, strain_id in enumerate(strain_ids, start=1))
             cursor.executemany(
                 "INSERT INTO StrainXRef( "
                 "  InbredSetId, StrainId, OrderId, Used_for_mapping, PedigreeStatus"
@@ -183,14 +190,8 @@ def cross_reference_samples(conn: mdb.Connection,
                 "  %(pop_id)s, %(strain_id)s, %(order_id)s, %(mapping)s, "
                 "  %(pedigree)s"
                 ")",
-                tuple({
-                    "pop_id": population_id,
-                    "strain_id": strain_id,
-                    "order_id": order_id,
-                    "mapping": "N",
-                    "pedigree": None
-                } for order_id, strain_id in
-                      enumerate(strain_ids, start=(last_order_id+10))))
+                params)
+            last_order_id += (len(params) * 10)
 
 @samples.route("/upload/samples", methods=["POST"])
 def upload_samples():
diff --git a/scripts/insert_samples.py b/scripts/insert_samples.py
new file mode 100644
index 0000000..43c6a38
--- /dev/null
+++ b/scripts/insert_samples.py
@@ -0,0 +1,147 @@
+"""Insert samples into the database."""
+import sys
+import logging
+import pathlib
+import argparse
+
+import MySQLdb as mdb
+from redis import Redis
+
+from qc_app.db_utils import database_connection
+from qc_app.check_connections import check_db, check_redis
+from qc_app.samples import (
+    species_by_id,
+    population_by_id,
+    save_samples_data,
+    read_samples_file,
+    cross_reference_samples)
+
+root_logger = logging.getLogger()
+stderr_handler = logging.StreamHandler(stream=sys.stderr)
+root_logger.setLevel(logging.INFO)  # same level as the name "INFO"
+root_logger.addHandler(stderr_handler)  # root-logger records go to stderr
+
+class SeparatorAction(argparse.Action):
+    """Store the separator, turning a literal backslash-t into a TAB."""
+    def __init__(self, option_strings, dest, nargs=None, **kwargs):
+        """Reject any explicit `nargs`, then defer to `argparse.Action`."""
+        if nargs is not None:
+            raise ValueError("nargs not allowed.")
+        super().__init__(option_strings, dest, nargs=None, **kwargs)
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        """Save `values` on `namespace`, decoding backslash-t to chr(9)."""
+        setattr(namespace, self.dest, "\t" if values == r"\t" else values)
+
+def insert_samples(conn: mdb.Connection,# pylint: disable=[too-many-arguments]
+                   rconn: Redis,# pylint: disable=[unused-argument]
+                   speciesid: int,
+                   populationid: int,
+                   samplesfile: pathlib.Path,
+                   separator: str,
+                   firstlineheading: bool,
+                   quotechar: str):
+    """Insert the samples in `samplesfile` and link them to the population.
+
+    The file is read twice (once per step) with identical parsing options.
+    Logs an error and returns 1 when `species_by_id` or `population_by_id`
+    comes back empty; returns 0 after the save and cross-reference steps.
+    `rconn` is accepted but unused.
+    """
+    if not species_by_id(conn, speciesid):
+        logging.error("Species with id '%s' does not exist.", str(speciesid))
+        return 1
+    if not population_by_id(conn, populationid):
+        logging.error("Population with id '%s' does not exist.",
+                      str(populationid))
+        return 1
+
+    def _rows():
+        # Both passes must parse the file identically, quotechar included.
+        return read_samples_file(
+            samplesfile, separator, firstlineheading, quotechar=quotechar)
+
+    logging.info("Inserting samples ...")
+    save_samples_data(conn, speciesid, _rows())
+    logging.info("Cross-referencing samples with their populations.")
+    cross_reference_samples(
+        conn, speciesid, populationid, (row["Name"] for row in _rows()))
+
+    return 0
+
+if __name__ == "__main__":
+
+    def cli_args():
+        """Build the argument parser and return the parsed command line."""
+        argparser = argparse.ArgumentParser(
+            prog="insert_samples",
+            description=(
+                "Script to parse and insert sample data from a file "
+                "into the database."))
+
+        # Positional arguments, consumed in this order.
+        argparser.add_argument(
+            "databaseuri",
+            help="URL to be used to initialise the connection to the database")
+        argparser.add_argument(
+            "speciesid",
+            type=int,
+            help="The species identifier in the database.")
+        argparser.add_argument(
+            "populationid",
+            type=int,
+            help="The grouping/population identifier in the database.")
+        argparser.add_argument(
+            "samplesfile",
+            type=pathlib.Path,
+            help="Path to the CSV file containing the samples data.")
+        argparser.add_argument(
+            "separator",
+            default=chr(9),
+            action=SeparatorAction,
+            help="The 'character' in the CSV file that separates the fields.")
+
+        # Flags that tune how the samples file is parsed.
+        argparser.add_argument(
+            "--firstlineheading",
+            action="store_true",
+            help=("If the first line of the file is a header row, invoke the "
+                  "program with this flag."))
+        argparser.add_argument(
+            "--quotechar",
+            default='"',
+            help=("The character used to delimit (surround?) the value in "
+                  "each column."))
+
+        # Connection setting that only this script needs.
+        argparser.add_argument(
+            "--redisuri",
+            default="redis:///",
+            help="URL to initialise connection to redis")
+
+        return argparser.parse_args()
+
+    def main():
+        """Parse the CLI, call `check_db`/`check_redis`, insert the samples.
+
+        Returns 2 if the samples file is missing, else `insert_samples`'s code.
+        """
+        args = cli_args()
+        check_db(args.databaseuri)
+        check_redis(args.redisuri)
+        if not args.samplesfile.exists():
+            logging.error("File not found: '%s'.", args.samplesfile)
+            return 2
+        # Deliberately no dump of `args`: it carries the database URI.
+        with (Redis.from_url(args.redisuri, decode_responses=True) as rconn,
+              database_connection(args.databaseuri) as dbconn):
+            return insert_samples(dbconn,
+                                  rconn,
+                                  args.speciesid,
+                                  args.populationid,
+                                  args.samplesfile,
+                                  args.separator,
+                                  args.firstlineheading,
+                                  args.quotechar)
+
+    sys.exit(main())