about summary refs log tree commit diff
path: root/scripts/run_qtlreaper.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/run_qtlreaper.py')
-rw-r--r-- scripts/run_qtlreaper.py | 79
1 files changed, 53 insertions, 26 deletions
diff --git a/scripts/run_qtlreaper.py b/scripts/run_qtlreaper.py
index ab58203..2269ea6 100644
--- a/scripts/run_qtlreaper.py
+++ b/scripts/run_qtlreaper.py
@@ -1,14 +1,15 @@
 """Script to run rust-qtlreaper and update database with results."""
+import os
 import sys
 import csv
 import time
 import secrets
 import logging
-import traceback
 import subprocess
+import multiprocessing
 from pathlib import Path
-from typing import Union
 from functools import reduce
+from typing import Union, Iterator
 from argparse import Namespace, ArgumentParser
 
 from gn_libs import mysqldb
@@ -57,7 +58,7 @@ def reconcile_samples(
 def generate_qtlreaper_traits_file(
         outdir: Path,
         samples: tuple[str, ...],
-        traits_data: dict[str, Union[int, float]],
+        traits_data: tuple[dict[str, Union[int, float]], ...],
         filename_prefix: str = ""
 ) -> Path:
     """Generate a file for use with qtlreaper that contains the traits' data."""
@@ -66,7 +67,7 @@ def generate_qtlreaper_traits_file(
     _dialect.quoting=0
 
     _traitsfile = outdir.joinpath(
-        f"{filename_prefix}_{secrets.token_urlsafe(15)}.tsv")
+        f"{filename_prefix}_{secrets.token_urlsafe(15)}.tsv")#type: ignore[attr-defined]
     with _traitsfile.open(mode="w", encoding="utf-8") as outptr:
         writer = csv.DictWriter(
             outptr, fieldnames=("Trait",) + samples, dialect=_dialect)
@@ -80,14 +81,13 @@ def generate_qtlreaper_traits_file(
     return _traitsfile
 
 
-def parse_tsv_file(results_file: Path) -> list[dict]:
+def parse_tsv_file(results_file: Path) -> Iterator[dict]:
     """Parse the rust-qtlreaper output into usable python objects."""
     with results_file.open("r", encoding="utf-8") as readptr:
         _dialect = csv.unix_dialect()
         _dialect.delimiter = "\t"
         reader = csv.DictReader(readptr, dialect=_dialect)
-        for row in reader:
-            yield row
+        yield from reader
 
 
 def __qtls_by_trait__(qtls, current):
@@ -98,7 +98,8 @@ def __qtls_by_trait__(qtls, current):
     }
 
 
-def save_qtl_values_to_db(conn, qtls: dict):
+def save_qtl_values_to_db(conn, qtls: tuple[dict, ...]):
+    """Save computed QTLs to the database."""
     with conn.cursor() as cursor:
         cursor.executemany(
             "UPDATE PublishXRef SET "
@@ -132,11 +133,11 @@ def dispatch(args: Namespace) -> int:
                              ", ".join(_samples_not_in_genofile))
 
             # Fetch traits data: provided list, or all traits in db
-            _traitsdata = phenotypes_vector_data(
+            _traitsdata = tuple(phenotypes_vector_data(
                 conn,
                 args.species_id,
                 args.population_id,
-                xref_ids=tuple(args.xref_ids)).values()
+                xref_ids=tuple(args.xref_ids)).values())
             logger.debug("Successfully got traits data. Generating the QTLReaper's traits file…")
             _traitsfile = generate_qtlreaper_traits_file(
                 args.working_dir,
@@ -146,37 +147,63 @@ def dispatch(args: Namespace) -> int:
             logger.debug("QTLReaper's Traits file: %s", _traitsfile)
 
             _qtlreaper_main_output = args.working_dir.joinpath(
-                f"main-output-{secrets.token_urlsafe(15)}.tsv")
+                f"main-output-{secrets.token_urlsafe(15)}.tsv")#type: ignore[attr-defined]
+            _qtlreaper_permu_output = args.working_dir.joinpath(
+                f"permu-output-{secrets.token_urlsafe(15)}.tsv")
             logger.debug("Main output filename: %s", _qtlreaper_main_output)
             with subprocess.Popen(
                     ("qtlreaper",
                      "--n_permutations", "1000",
                      "--geno", _genofile,
                      "--traits", _traitsfile,
-                     "--main_output", _qtlreaper_main_output)) as _qtlreaper:
+                     "--main_output", _qtlreaper_main_output,
+                     "--permu_output", _qtlreaper_permu_output,
+                     "--threads", str(int(1+(multiprocessing.cpu_count()/2)))),
+                    env=({**os.environ, "RUST_BACKTRACE": "full"}
+                         if logger.getEffectiveLevel() == logging.DEBUG
+                         else dict(os.environ))) as _qtlreaper:
                 while _qtlreaper.poll() is None:
                     logger.debug("QTLReaper process running…")
                     time.sleep(1)
-                    results = tuple(max(qtls, key=lambda qtl: qtl["LRS"])
-                                    for qtls in
-                                    reduce(__qtls_by_trait__,
-                                           parse_tsv_file(_qtlreaper_main_output),
-                                           {}).values())
-            save_qtl_values_to_db(conn, results)
+                    results = (
+                        tuple(#type: ignore[var-annotated]
+                            max(qtls, key=lambda qtl: qtl["LRS"])
+                            for qtls in
+                            reduce(__qtls_by_trait__,
+                                   parse_tsv_file(_qtlreaper_main_output),
+                                   {}).values())
+                        if _qtlreaper_main_output.exists()
+                        else tuple())
             logger.debug("Cleaning up temporary files.")
-            _traitsfile.unlink()
-            _qtlreaper_main_output.unlink()
+
+            # short-circuits to delete file if exists
+            if _traitsfile.exists():
+                _traitsfile.unlink()
+                logger.info("Deleted generated traits' file for QTLReaper.")
+
+            if _qtlreaper_main_output.exists():
+                _qtlreaper_main_output.unlink()
+                logger.info("Deleted QTLReaper's main output file.")
+
+            if _qtlreaper_permu_output.exists():
+                _qtlreaper_permu_output.unlink()
+                logger.info("Deleted QTLReaper's permutations file.")
+
+            if _qtlreaper.returncode != 0:
+                return _qtlreaper.returncode
+
+            save_qtl_values_to_db(conn, results)
             logger.info("Successfully computed p values for %s traits.", len(_traitsdata))
-            exitcode = 0
+            return 0
         except FileNotFoundError as fnf:
-            logger.error(", ".join(fnf.args), exc_info=False)
+            logger.error(", ".join(str(arg) for arg in fnf.args), exc_info=False)
         except AssertionError as aserr:
             logger.error(", ".join(aserr.args), exc_info=False)
-        except Exception as _exc:
+        except Exception as _exc:# pylint: disable=[broad-exception-caught]
             logger.debug("Type of exception: %s", type(_exc))
             logger.error("General exception!", exc_info=True)
-        finally:
-            return exitcode
+
+        return exitcode
 
 
 if __name__ == "__main__":
@@ -205,7 +232,7 @@ if __name__ == "__main__":
                   "in the population."))
         args = parser.parse_args()
         setup_logging(logger, args.log_level)
-        
+
         return dispatch(args)
 
     sys.exit(main())