diff options
-rw-r--r-- | quality_control/parsing.py | 62 | ||||
-rw-r--r-- | tests/qc/test_error_collection.py | 19 |
2 files changed, 80 insertions, 1 deletions
"""Error-collection helpers added by this patch to quality_control/parsing.py.

Reconstructed from the rendered diff (the `+` lines of the hunk), with
these fixes applied:

* ``collections.Sequence`` -> ``collections.abc.Sequence``: the bare
  ``collections`` alias was removed in Python 3.10 and raises
  AttributeError there.
* ``collect_errors`` is annotated ``-> tuple``: it returns a tuple slice,
  not a generator, despite the original ``-> Generator`` annotation.
* ``count <= 0`` now disables the cap entirely; previously the early-stop
  guard treated ``count <= 0`` as "no limit" but the final unconditional
  ``errors[0:count]`` slice returned an empty/truncated result anyway.
"""

def header_errors(line_number, fields, strains):
    """Collect all header-row errors for the tab-split ``fields``.

    Combines the single whole-header check with per-heading validity and
    duplicate-heading checks; entries may be None and are filtered by the
    caller (``__process_errors__``).
    """
    return (
        (invalid_header(line_number, fields),)
        + invalid_headings(line_number, strains, fields[1:])
        + duplicate_headings(line_number, fields))

def empty_value(line_number, column_number, value):
    """Return an error object if ``value`` is empty, else None.

    NOTE(review): ``InvalidValue`` is not among the names this patch
    imports from ``quality_control.errors`` (only InvalidCellValue /
    InvalidHeaderValue are) — confirm it is in scope in the full module,
    otherwise an empty first cell raises NameError here.
    """
    if value == "":
        return InvalidValue(
            line_number, column_number, value, "Empty value for column")
    return None

def average_errors(line_number, fields):
    """Collect value errors for one data row of an AVERAGE file.

    Column 1 must be non-empty; remaining columns are validated by
    ``avg.invalid_value`` (column numbers are 1-based, data starts at 2).
    """
    return (
        (empty_value(line_number, 1, fields[0]),)
        + tuple(
            avg.invalid_value(line_number, *field)
            for field in enumerate(fields[1:], start=2)))

def se_errors(line_number, fields):
    """Collect value errors for one data row of a STANDARD_ERROR file.

    Mirrors ``average_errors`` but validates with ``se.invalid_value``.
    """
    return (
        (empty_value(line_number, 1, fields[0]),)
        + tuple(
            se.invalid_value(line_number, *field)
            for field in enumerate(fields[1:], start=2)))

def collect_errors(
        filepath: str, filetype: FileType, strains: list,
        count: int = 10) -> tuple:
    """Run checks against the file at ``filepath`` and collect errors.

    Line 1 is checked as a header; every other line is checked as data,
    with the checker selected by ``filetype``. Stops reading as soon as
    ``count`` errors have accumulated; pass ``count <= 0`` to collect
    every error in the file. Returns a tuple of error objects.
    """
    errors = tuple()

    def __process_errors__(line_number, line, error_checker_fn,
                           errors=tuple()):
        # Tab-split and strip the line, run the checker, and fold any
        # non-None result(s) into the accumulated `errors` tuple. The
        # checker may return None, a single error, or a sequence of
        # (possibly-None) errors.
        errs = error_checker_fn(
            line_number,
            tuple(field.strip() for field in line.split("\t")))
        if errs is None:
            return errors
        # `collections.abc.Sequence`, not the pre-3.10 `collections.Sequence`
        # alias (removed in Python 3.10).
        if isinstance(errs, collections.abc.Sequence):
            return errors + tuple(
                error for error in errs if error is not None)
        return errors + (errs,)

    with open(filepath, encoding="utf-8") as input_file:
        for line_number, line in enumerate(input_file, start=1):
            if line_number == 1:
                errors = __process_errors__(
                    line_number, line,
                    partial(header_errors, strains=strains), errors)
            else:
                errors = __process_errors__(
                    line_number, line,
                    (average_errors if filetype == FileType.AVERAGE
                     else se_errors),
                    errors)
            # Stop reading early once the cap is reached (cap disabled
            # when count <= 0).
            if count > 0 and len(errors) >= count:
                break

    # Only truncate when a positive cap was requested; the previous
    # unconditional errors[0:count] emptied the result for count <= 0,
    # contradicting the "no limit" guard above.
    return errors[0:count] if count > 0 else errors