author     Frederick Muriuki Muriithi  2022-05-18 11:03:29 +0300
committer  Frederick Muriuki Muriithi  2022-05-18 11:03:29 +0300
commit     ed7348ae2acefbb7806e26d5c13dfbd47ba1c9c0 (patch)
tree       be294ee5dc56e361b5d584cf941de638091f315a
parent     582686e030b660f218cb7091aaab3cafa103465d (diff)
download   gn-uploader-ed7348ae2acefbb7806e26d5c13dfbd47ba1c9c0.tar.gz
Parse files with new non-exception functions
Parse the files with the new functions that return error objects
instead of raising exceptions
-rw-r--r--  quality_control/parsing.py          62
-rw-r--r--  tests/qc/test_error_collection.py   19
2 files changed, 80 insertions(+), 1 deletion(-)
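
For context, the pattern this commit adopts can be sketched as follows. This is an illustrative example only: parse_cell is a hypothetical checker, and the InvalidValue namedtuple merely mirrors the four-argument error object constructed in the diff below; the real checkers are avg.invalid_value and se.invalid_value from quality_control.average and quality_control.standard_error.

# Sketch: return error objects instead of raising exceptions.
from collections import namedtuple

# Mirrors the error object used in the diff below; field names assumed.
InvalidValue = namedtuple("InvalidValue", ("line", "column", "value", "message"))

def parse_cell(line_number, column_number, value):
    """Hypothetical checker: return an error object, or None if the cell is fine."""
    try:
        float(value)
        return None
    except ValueError:
        return InvalidValue(
            line_number, column_number, value,
            f"'{value}' is not a valid decimal value")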
diff --git a/quality_control/parsing.py b/quality_control/parsing.py
index 436c90c..70a85ed 100644
--- a/quality_control/parsing.py
+++ b/quality_control/parsing.py
@@ -1,13 +1,16 @@
 """Module handling the high-level parsing of the files"""
 
 import csv
+import collections.abc
 from enum import Enum
-from functools import reduce
+from functools import reduce, partial
 from typing import Iterable, Generator
 
 import quality_control.average as avg
 import quality_control.standard_error as se
 from quality_control.headers import valid_header
+from quality_control.headers import (
+    invalid_header, invalid_headings, duplicate_headings)
 from quality_control.errors import (
     ParseError, DuplicateHeader, InvalidCellValue, InvalidHeaderValue)
 
@@ -120,6 +123,63 @@ def parse_errors(filepath: str, filetype: FileType, strains: list,
         error for error in __errors(filepath, filetype, strains, seek_pos)
         if error is not None)
 
+def header_errors(line_number, fields, strains):
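+    """Gather the errors, if any, in the header line."""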
+    return (
+        (invalid_header(line_number, fields),) +
+        invalid_headings(line_number, strains, fields[1:]) +
+        duplicate_headings(line_number, fields))
+
+def empty_value(line_number, column_number, value):
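+    """Return an InvalidValue error if `value` is empty, else None."""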
+    if value == "":
+        return InvalidValue(
+            line_number, column_number, value, "Empty value for column")
+
+def average_errors(line_number, fields):
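+    """Gather the errors, if any, in a line from an averages file."""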
+    return (
+        (empty_value(line_number, 1, fields[0]),) +
+        tuple(
+            avg.invalid_value(line_number, *field)
+            for field in enumerate(fields[1:], start=2)))
+
+def se_errors(line_number, fields):
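+    """Gather the errors, if any, in a line from a standard-errors file."""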
+    return (
+        (empty_value(line_number, 1, fields[0]),) +
+        tuple(
+            se.invalid_value(line_number, *field)
+            for field in enumerate(fields[1:], start=2)))
+
+def collect_errors(
+        filepath: str, filetype: FileType, strains: list, count: int = 10) -> tuple:
+    """Run checks against the file and collect up to `count` errors."""
+    errors = tuple()
+    def __process_errors__(line_number, line, error_checker_fn, errors=tuple()):
+        """Run `error_checker_fn` on `line` and append any errors it finds."""
+        errs = error_checker_fn(
+            line_number,
+            tuple(field.strip() for field in line.split("\t")))
+        if errs is None:
+            return errors
+        if isinstance(errs, collections.abc.Sequence):
+            return errors + tuple(error for error in errs if error is not None)
+        return errors + (errs,)
+
+    with open(filepath, encoding="utf-8") as input_file:
+        for line_number, line in enumerate(input_file, start=1):
+            if line_number == 1:
+                errors = __process_errors__(
+                    line_number, line, partial(header_errors, strains=strains),
+                    errors)
+            else:
+                errors = __process_errors__(
+                    line_number, line, (
+                        average_errors if filetype == FileType.AVERAGE
+                        else se_errors),
+                    errors)
+
+            if count > 0 and len(errors) >= count:
+                break
+
+    return errors[0:count]
+
 def take(iterable: Iterable, num: int) -> list:
     """Take at most `num` items from `iterable`."""
     iterator = iter(iterable)
diff --git a/tests/qc/test_error_collection.py b/tests/qc/test_error_collection.py
index 3a26d9c..466f455 100644
--- a/tests/qc/test_error_collection.py
+++ b/tests/qc/test_error_collection.py
@@ -1,6 +1,7 @@
 import pytest
 
 from quality_control.parsing import take, FileType, parse_errors
+from quality_control.parsing import collect_errors
 
 @pytest.mark.slow
 @pytest.mark.parametrize(
@@ -37,3 +38,21 @@ def test_take(sample, num, expected):
     taken = take(sample, num)
     assert len(taken) <= num
     assert taken == expected
+
+
+## ==================================================
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+    "filepath,filetype,count",
+    (("tests/test_data/average_crlf.tsv", FileType.AVERAGE, 10),
+     ("tests/test_data/average_error_at_end_200MB.tsv", FileType.AVERAGE,
+      20),
+     ("tests/test_data/average.tsv", FileType.AVERAGE, 5),
+     ("tests/test_data/standarderror_1_error_at_end.tsv",
+      FileType.STANDARD_ERROR, 13),
+     ("tests/test_data/standarderror.tsv", FileType.STANDARD_ERROR, 9),
+     ("tests/test_data/duplicated_headers_no_data_errors.tsv",
+      FileType.AVERAGE, 10)))
+def test_collect_errors(filepath, filetype, strains, count):
+    assert len(collect_errors(filepath, filetype, strains, count)) <= count
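
For reference, a minimal usage sketch of the new collect_errors function. The file path comes from the test data above; the strains list is a made-up placeholder, since the real one is supplied to the tests by a pytest fixture.

from quality_control.parsing import FileType, collect_errors

strains = ["BXD1", "BXD2"]  # placeholder strain names
errors = collect_errors(
    "tests/test_data/average.tsv", FileType.AVERAGE, strains, count=5)
for error in errors:
    print(error)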