about summary refs log tree commit diff
path: root/scripts/rqtl2/install_genotypes.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/rqtl2/install_genotypes.py')
-rw-r--r--scripts/rqtl2/install_genotypes.py80
1 files changed, 18 insertions, 62 deletions
diff --git a/scripts/rqtl2/install_genotypes.py b/scripts/rqtl2/install_genotypes.py
index a1609a0..354bff0 100644
--- a/scripts/rqtl2/install_genotypes.py
+++ b/scripts/rqtl2/install_genotypes.py
@@ -1,26 +1,22 @@
 """Load genotypes from R/qtl2 bundle into the database."""
 import sys
-import uuid
 import logging
 import traceback
 from pathlib import Path
 from zipfile import ZipFile
 from functools import reduce
-from typing import Union, Iterator
-from argparse import ArgumentParser
+from typing import Iterator, Optional
 
 import MySQLdb as mdb
-from redis import Redis
 from MySQLdb.cursors import DictCursor
 
 from r_qtl import r_qtl2 as rqtl2
 
 from quality_control.parsing import take
 
-from qc_app.db_utils import database_connection
-from qc_app.check_connections import check_db, check_redis
-
-from scripts.redis_logger import RedisLogger
+from scripts.rqtl2.entry import build_main
+from scripts.cli_parser import init_cli_parser
+from scripts.rqtl2.cli_parser import add_common_arguments
 
 stderr_handler = logging.StreamHandler(stream=sys.stderr)
 logger = logging.getLogger("install_genotypes")
@@ -29,7 +25,7 @@ logger.addHandler(stderr_handler)
 def insert_markers(dbconn: mdb.Connection,
                    speciesid: int,
                    markers: tuple[str, ...],
-                   pmapdata: Union[Iterator[dict], None]) -> int:
+                   pmapdata: Optional[Iterator[dict]]) -> int:
     """Insert genotype and genotype values into the database."""
     mdata = reduce(#type: ignore[var-annotated]
         lambda acc, row: ({#type: ignore[arg-type, return-value]
@@ -129,7 +125,7 @@ def cross_reference_genotypes(dbconn: mdb.Connection,
                               speciesid: int,
                               datasetid: int,
                               dataids: tuple[dict, ...],
-                              gmapdata: Union[Iterator[dict], None]) -> int:
+                              gmapdata: Optional[Iterator[dict]]) -> int:
     """Cross-reference the data to the relevant dataset."""
     _rows, markers, mdata = reduce(#type: ignore[var-annotated]
         lambda acc, row: (#type: ignore[return-value,arg-type]
@@ -221,59 +217,19 @@ if __name__ == "__main__":
 
     def cli_args():
         """Process command-line arguments for install_genotypes"""
-        parser = ArgumentParser(
-            prog="install_genotypes",
-            description="Parse genotypes from R/qtl2 bundle into the database.")
-
-        parser.add_argument("databaseuri", help="URL to MariaDB")
-        parser.add_argument("redisuri", help="URL to Redis")
-        parser.add_argument("jobid",
-                            help="Job ID that this belongs to.",
-                            type=uuid.UUID)
-
-        parser.add_argument("speciesid",
-                            help="Species to which bundle relates.")
-        parser.add_argument("populationid",
-                            help="Population to group data under")
-        parser.add_argument("datasetid",
-                            help="The dataset to which the data belongs.")
-        parser.add_argument("rqtl2bundle",
-                            help="Path to R/qtl2 bundle zip file.",
-                            type=Path)
-
-        parser.add_argument("--redisexpiry",
-                            help="How long to keep any redis keys around.",
-                            type=int,
-                            default=86400)
+        parser = add_common_arguments(init_cli_parser(
+            "install_genotypes",
+            "Parse genotypes from R/qtl2 bundle into the database."))
 
         return parser.parse_args()
 
-    def main():
-        """Run `install_genotypes` scripts."""
-        args = cli_args()
-        check_db(args.databaseuri)
-        check_redis(args.redisuri)
-        if not args.rqtl2bundle.exists():
-            logging.error("File not found: '%s'.", args.rqtl2bundle)
-            return 2
-
-        with (Redis.from_url(args.redisuri, decode_responses=True) as rconn,
-              database_connection(args.databaseuri) as dbconn):
-            formatter = logging.Formatter(
-                "%(asctime)s - %(name)s - %(levelname)s: %(message)s")
-            job_messagelist = f"{str(args.jobid)}:log-messages"
-            rconn.hset(name=str(args.jobid),
-                       key="log-messagelist",
-                       value=job_messagelist)
-            redislogger = RedisLogger(
-                rconn, args.jobid, expiry=args.redisexpiry)
-            redislogger.setFormatter(formatter)
-            logger.addHandler(redislogger)
-            logger.setLevel("INFO")
-            return install_genotypes(dbconn,
-                                     args.speciesid,
-                                     args.populationid,
-                                     args.datasetid,
-                                     args.rqtl2bundle)
-
+    main = build_main(
+        cli_args,
+        lambda dbconn, args: install_genotypes(dbconn,
+                                               args.speciesid,
+                                               args.populationid,
+                                               args.datasetid,
+                                               args.rqtl2bundle),
+        logger,
+        "INFO")
     sys.exit(main())