about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-04-25 08:08:08 +0300
committerFrederick Muriuki Muriithi2022-04-25 08:08:08 +0300
commit20c4028f567f0d4b5df1b80f1b6bea1bc99887b4 (patch)
tree49a2eaf24d6fb700de69f0cd6f264cacad62cd92
parent043d83d1cadd5b5abfa418ac02773c7a979c611d (diff)
downloadgn-uploader-20c4028f567f0d4b5df1b80f1b6bea1bc99887b4.tar.gz
`take`: function to select a few items from an iterable
To avoid processing all the items in an iterable, the `take` function
is added in this commit. It realised a limited number (specified at
call time) of items from the iterable given.
-rw-r--r--quality_control/parsing.py14
-rw-r--r--tests/qc/test_error_collection.py16
2 files changed, 26 insertions, 4 deletions
diff --git a/quality_control/parsing.py b/quality_control/parsing.py
index ac53642..e9bd5f7 100644
--- a/quality_control/parsing.py
+++ b/quality_control/parsing.py
@@ -3,7 +3,7 @@
 import csv
 from enum import Enum
 from functools import reduce
-from typing import Iterator, Generator
+from typing import Iterable, Generator
 
 import quality_control.average as avg
 import quality_control.standard_error as se
@@ -124,3 +124,15 @@ def parse_errors(filepath: str, filetype: FileType, strains: list,
     return (
         error for error in __errors(filepath, filetype, strains, seek_pos)
         if error is not None)
+
+def take(iterable: Iterable, num: int) -> list:
+    """Take at most `num` items from `iterable`."""
+    iterator = iter(iterable)
+    items = []
+    try:
+        for i in range(0, num):
+            items.append(next(iterator))
+
+        return items
+    except StopIteration:
+        return items
diff --git a/tests/qc/test_error_collection.py b/tests/qc/test_error_collection.py
index c45803a..f1bd8b9 100644
--- a/tests/qc/test_error_collection.py
+++ b/tests/qc/test_error_collection.py
@@ -1,6 +1,6 @@
 import pytest
 
-from quality_control.parsing import FileType, parse_errors
+from quality_control.parsing import take, FileType, parse_errors
 
 @pytest.mark.slow
 @pytest.mark.parametrize(
@@ -14,8 +14,7 @@ from quality_control.parsing import FileType, parse_errors
       FileType.STANDARD_ERROR, 0),
      ("tests/test_data/standarderror.tsv", FileType.STANDARD_ERROR, 0),
      ("tests/test_data/duplicated_headers_no_data_errors.tsv",
-      FileType.AVERAGE),
-     ))
+      FileType.AVERAGE, 0)))
 def test_parse_errors(filepath, filetype, strains, seek_pos):
     """
     Check that only errors are returned, and that certain properties hold for
@@ -28,3 +27,14 @@ def test_parse_errors(filepath, filetype, strains, seek_pos):
         assert "position" in error
         assert "error" in error and isinstance(error["error"], str)
         assert "message" in error
+
+
+@pytest.mark.parametrize(
+    "sample,num,expected",
+    ((range(0,25), 5, [0, 1, 2, 3, 4]),
+     ([0, 1, 2, 3], 200, [0, 1, 2, 3]),
+     (("he", "is", "a", "lovely", "boy"), 3, ["he", "is", "a"])))
+def test_take(sample, num, expected):
+    taken = take(sample, num)
+    assert len(taken) <= num
+    assert taken == expected