diff options
-rw-r--r-- | quality_control/parsing.py | 14 | ||||
-rw-r--r-- | tests/qc/test_error_collection.py | 16 |
2 files changed, 26 insertions, 4 deletions
diff --git a/quality_control/parsing.py b/quality_control/parsing.py index ac53642..e9bd5f7 100644 --- a/quality_control/parsing.py +++ b/quality_control/parsing.py @@ -3,7 +3,7 @@ import csv from enum import Enum from functools import reduce -from typing import Iterator, Generator +from typing import Iterable, Generator import quality_control.average as avg import quality_control.standard_error as se @@ -124,3 +124,15 @@ def parse_errors(filepath: str, filetype: FileType, strains: list, return ( error for error in __errors(filepath, filetype, strains, seek_pos) if error is not None) + +def take(iterable: Iterable, num: int) -> list: + """Take at most `num` items from `iterable`.""" + iterator = iter(iterable) + items = [] + try: + for i in range(0, num): + items.append(next(iterator)) + + return items + except StopIteration: + return items diff --git a/tests/qc/test_error_collection.py b/tests/qc/test_error_collection.py index c45803a..f1bd8b9 100644 --- a/tests/qc/test_error_collection.py +++ b/tests/qc/test_error_collection.py @@ -1,6 +1,6 @@ import pytest -from quality_control.parsing import FileType, parse_errors +from quality_control.parsing import take, FileType, parse_errors @pytest.mark.slow @pytest.mark.parametrize( @@ -14,8 +14,7 @@ from quality_control.parsing import FileType, parse_errors FileType.STANDARD_ERROR, 0), ("tests/test_data/standarderror.tsv", FileType.STANDARD_ERROR, 0), ("tests/test_data/duplicated_headers_no_data_errors.tsv", - FileType.AVERAGE), - )) + FileType.AVERAGE, 0))) def test_parse_errors(filepath, filetype, strains, seek_pos): """ Check that only errors are returned, and that certain properties hold for @@ -28,3 +27,14 @@ def test_parse_errors(filepath, filetype, strains, seek_pos): assert "position" in error assert "error" in error and isinstance(error["error"], str) assert "message" in error + + +@pytest.mark.parametrize( + "sample,num,expected", + ((range(0,25), 5, [0, 1, 2, 3, 4]), + ([0, 1, 2, 3], 200, [0, 1, 2, 3]), + (("he", "is", "a", "lovely", "boy"), 3, ["he", "is", "a"]))) +def test_take(sample, num, expected): + taken = take(sample, num) + assert len(taken) <= num + assert taken == expected |