about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--r_qtl/r_qtl2_qc.py64
1 files changed, 39 insertions, 25 deletions
diff --git a/r_qtl/r_qtl2_qc.py b/r_qtl/r_qtl2_qc.py
index 8d4fc19..4b3e184 100644
--- a/r_qtl/r_qtl2_qc.py
+++ b/r_qtl/r_qtl2_qc.py
@@ -2,7 +2,7 @@
 import re
 from zipfile import ZipFile
 from functools import reduce
-from typing import Union, Sequence, Iterator
+from typing import Union, Sequence, Iterator, Optional, Callable
 
 from r_qtl import errors as rqe
 from r_qtl import r_qtl2 as rqtl2
@@ -59,40 +59,54 @@ def validate_bundle(zfile: ZipFile):
                         "The following files do not exist in the bundle: " +
                         ", ".join(missing))
 
+def make_genocode_checker(genocode: dict) -> Callable[[int, str, str], Optional[InvalidValue]]:
+    """Make a checker from the genotypes in the control data"""
+    def __checker__(lineno: int, field: str, value: str) -> Optional[InvalidValue]:
+        genotypes = tuple(genocode.keys())
+        if value not in genotypes:
+            return InvalidValue(lineno, field, value, (
+                f"Invalid value '{value}'. Expected one of {genotypes}."))
+        return None
+    return __checker__
+
 def geno_errors(zfile: ZipFile) -> Iterator[Union[InvalidValue, MissingFile]]:
     """Check for and retrieve geno errors."""
     cdata = rqtl2.control_data(zfile)
-    genotypes = tuple(cdata.get("genotypes", {}).keys())
-    try:
-        for lineno, row in enumerate(
-                rqtl2.file_data(zfile, "geno", cdata), start=1):
-            for field, value in row.items():
-                if field == "id":
-                    continue
-                if value is not None and value not in genotypes:
-                    yield InvalidValue(lineno, field, value, (
-                        f"Invalid value '{value}'. Expected one of "
-                        f"{genotypes}."))
-    except rqe.MissingFileError:
-        fname = cdata.get("geno")
-        yield MissingFile("geno", fname, f"Missing 'geno' file '{fname}'.")
+    return (
+        error for error in retrieve_errors(
+            zfile, "geno", (make_genocode_checker(cdata.get("genotypes", {})),))
+        if error is not None)
 
 def pheno_errors(zfile: ZipFile) -> Iterator[Union[InvalidValue, MissingFile]]:
     """Check for and retrieve pheno errors."""
+    def __min_3_decimal_places__(
+            lineno: int, field: str, value: str) -> Optional[InvalidValue]:
+        if not (re.search(r"^([0-9]+\.[0-9]{3,}|[0-9]+\.?0*)$", value)
+                or re.search(r"^0\.0+$", value)
+                or re.search("^0+$", value)):
+            return InvalidValue(lineno, field, value, (
+                f"Invalid value '{value}'. Expected numerical value "
+                "with at least 3 decimal places."))
+        return None
+    return (
+        error for error in retrieve_errors(
+            zfile, "pheno", (__min_3_decimal_places__,))
+        if error is not None)
+
+def retrieve_errors(zfile: ZipFile, filetype: str, checkers: tuple[Callable]) -> Iterator[
+        Union[InvalidValue, MissingFile]]:
+    """Check for and retrieve errors from files of type `filetype`."""
+    assert filetype in __FILE_TYPES__, f"Invalid file type {filetype}."
     cdata = rqtl2.control_data(zfile)
     try:
         for lineno, row in enumerate(
-                rqtl2.file_data(zfile, "pheno", cdata), start=1):
+                rqtl2.file_data(zfile, filetype, cdata), start=1):
             for field, value in row.items():
                 if field == "id":
                     continue
-                if value is not None and not(
-                        re.search(r"^([0-9]+\.[0-9]{3,}|[0-9]+\.?0*)$", value)
-                        or re.search(r"^0\.0+$", value)
-                        or re.search("^0+$", value)):
-                    yield InvalidValue(lineno, field, value, (
-                        f"Invalid value '{value}'. Expected numerical value "
-                        "with at least 3 decimal places."))
+                if value is not None:
+                    for checker in checkers:
+                        yield checker(lineno, field, value)
     except rqe.MissingFileError:
-        fname = cdata.get("pheno")
-        yield MissingFile("pheno", fname, f"Missing 'pheno' file '{fname}'.")
+        fname = cdata.get(filetype)
+        yield MissingFile(filetype, fname, f"Missing '{filetype}' file '{fname}'.")