about summary refs log tree commit diff
path: root/quality_control
diff options
context:
space:
mode:
Diffstat (limited to 'quality_control')
-rw-r--r-- quality_control/average.py | 15
-rw-r--r-- quality_control/headers.py | 5
-rw-r--r-- quality_control/parsing.py | 12
-rw-r--r-- quality_control/standard_error.py | 15
-rw-r--r-- quality_control/utils.py | 12
5 files changed, 31 insertions, 28 deletions
diff --git a/quality_control/average.py b/quality_control/average.py
index 06b0a47..2b098db 100644
--- a/quality_control/average.py
+++ b/quality_control/average.py
@@ -1,19 +1,16 @@
 """Contain logic for checking average files"""
-import re
 from typing import Union
 
+from .utils import cell_error
 from .errors import InvalidValue
 
 def invalid_value(line_number: int, column_number: int, val: str) -> Union[
         InvalidValue, None]:
     """Return an `InvalidValue` object if `val` is not a valid "averages"
     value."""
-    if re.search(r"^[0-9]+\.[0-9]{3}$", val):
-        return None
-    if re.search(r"^0\.0+$", val) or re.search("^0+$", val):
-        return None
-    return InvalidValue(
-        line_number, column_number, val, (
+    return cell_error(
+        r"^[0-9]+\.[0-9]{3}$", val, line=line_number, column=column_number,
+        value=val, message=(
             f"Invalid value '{val}'. "
-            "Expected string representing a number with exactly three decimal "
-            "places."))
+            "Expected string representing a number with exactly three "
+            "decimal places."))
diff --git a/quality_control/headers.py b/quality_control/headers.py
index 79d7e43..f4f4dad 100644
--- a/quality_control/headers.py
+++ b/quality_control/headers.py
@@ -27,14 +27,15 @@ def invalid_headings(
         enumerate(headings, start=2) if header not in strains)
 
 def duplicate_headings(
-        line_number: int, headers: Sequence[str]) -> Union[InvalidValue, None]:
+        line_number: int,
+        headers: Sequence[str]) -> Tuple[DuplicateHeading, ...]:
     """Return a tuple of `DuplicateHeading` objects for each column heading that
     is a duplicate of another column heading."""
     def __update_columns__(acc, item):
         if item[1] in acc.keys():
             return {**acc, item[1]: acc[item[1]] + (item[0],)}
         return {**acc, item[1]: (item[0],)}
-    repeated = {
+    repeated = {# type: ignore[var-annotated]
         heading: columns for heading, columns in
         reduce(__update_columns__, enumerate(headers, start=1), {}).items()
         if len(columns) > 1
diff --git a/quality_control/parsing.py b/quality_control/parsing.py
index f1f4f79..ba22e0c 100644
--- a/quality_control/parsing.py
+++ b/quality_control/parsing.py
@@ -1,15 +1,13 @@
 """Module handling the high-level parsing of the files"""
-
-import os
 import collections
 from enum import Enum
 from functools import partial
 from zipfile import ZipFile, is_zipfile
-from typing import Iterable, Generator, Callable, Optional
+from typing import Tuple, Union, Iterable, Generator, Callable, Optional
 
 import quality_control.average as avg
 import quality_control.standard_error as se
-from quality_control.errors import InvalidValue
+from quality_control.errors import InvalidValue, DuplicateHeading
 from quality_control.headers import (
     invalid_header, invalid_headings, duplicate_headings)
 
@@ -67,16 +65,16 @@ def se_errors(line_number, fields):
 def collect_errors(
         filepath: str, filetype: FileType, strains: list,
         update_progress: Optional[Callable] = None,
-        user_aborted: Optional[Callable] = lambda: False) -> Generator:
+        user_aborted: Callable = lambda: False) -> Generator:
     """Run checks against file and collect all the errors"""
-    errors = tuple()
+    errors:Tuple[Union[InvalidValue, DuplicateHeading], ...] = tuple()
     def __process_errors__(line_number, line, error_checker_fn, errors = tuple()):
         errs = error_checker_fn(
             line_number,
             tuple(field.strip() for field in line.split("\t")))
         if errs is None:
             return errors
-        if isinstance(errs, collections.Sequence):
+        if isinstance(errs, collections.abc.Sequence):
             return errors + tuple(error for error in errs if error is not None)
         return errors + (errs,)
 
diff --git a/quality_control/standard_error.py b/quality_control/standard_error.py
index aa7df3c..7e059ad 100644
--- a/quality_control/standard_error.py
+++ b/quality_control/standard_error.py
@@ -1,7 +1,7 @@
 """Contain logic for checking standard error files"""
-import re
 from typing import Union
 
+from .utils import cell_error
 from .errors import InvalidValue
 
 def invalid_value(
@@ -12,12 +12,9 @@ def invalid_value(
     `val` is not a valid input for standard error files, otherwise, it returns
     `None`.
     """
-    if re.search(r"^[0-9]+\.[0-9]{6,}$", val):
-        return None
-    if re.search(r"^0\.0+$", val) or re.search("^0+$", val):
-        return None
-    return InvalidValue(
-        line_number, column_number, val, (
+    return cell_error(
+        r"^[0-9]+\.[0-9]{6,}$", val, line=line_number, column=column_number,
+        value=val, message=(
             f"Invalid value '{val}'. "
-            "Expected string representing a number with at least six decimal "
-            "places."))
+            "Expected string representing a number with at least six "
+            "decimal places."))
diff --git a/quality_control/utils.py b/quality_control/utils.py
index 0072608..67301d6 100644
--- a/quality_control/utils.py
+++ b/quality_control/utils.py
@@ -1,7 +1,9 @@
 """Utilities that might be useful elsewhere."""
-
+import re
 from collections import namedtuple
 
+from .errors import InvalidValue
+
 ProgressIndicator = namedtuple(
     "ProgressIndicator", ("filesize", "processedsize", "currentline", "percent"))
 
@@ -19,3 +21,11 @@ def make_progress_calculator(filesize: int):
             ((processedsize/filesize) * 100))
 
     return __calculator__
+
+def cell_error(pattern, val, **error_kwargs):
+    "Return the error in the cell"
+    if re.search(pattern, val):
+        return None
+    if re.search(r"^0\.0+$", val) or re.search("^0+$", val):
+        return None
+    return InvalidValue(**error_kwargs)