about summary refs log tree commit diff
path: root/quality_control/parsing.py
diff options
context:
space:
mode:
Diffstat (limited to 'quality_control/parsing.py')
-rw-r--r--  quality_control/parsing.py  141
1 file changed, 32 insertions, 109 deletions
diff --git a/quality_control/parsing.py b/quality_control/parsing.py
index 70a85ed..655b98a 100644
--- a/quality_control/parsing.py
+++ b/quality_control/parsing.py
@@ -1,40 +1,22 @@
 """Module handling the high-level parsing of the files"""
 
-import csv
+import os
 import collections
 from enum import Enum
-from functools import reduce, partial
-from typing import Iterable, Generator
+from functools import partial
+from typing import Union, Iterable, Generator, Callable
 
 import quality_control.average as avg
 import quality_control.standard_error as se
-from quality_control.headers import valid_header
+from quality_control.errors import InvalidValue
 from quality_control.headers import (
     invalid_header, invalid_headings, duplicate_headings)
-from quality_control.errors import (
-    ParseError, DuplicateHeader, InvalidCellValue, InvalidHeaderValue)
 
 class FileType(Enum):
     """Enumerate the expected file types"""
     AVERAGE = 1
     STANDARD_ERROR = 2
 
-def __parse_header(line, strains):
-    return valid_header(
-        set(strains),
-        tuple(header.strip() for header in line.split("\t")))
-
-def __parse_average_line(line):
-    return (line[0],) + tuple(avg.valid_value(field) for field in line[1:])
-
-def __parse_standard_error_line(line):
-    return (line[0],) + tuple(se.valid_value(field) for field in line[1:])
-
-LINE_PARSERS = {
-    FileType.AVERAGE: __parse_average_line,
-    FileType.STANDARD_ERROR: __parse_standard_error_line
-}
-
 def strain_names(filepath):
     """Retrieve the strains names from given file"""
     strains = set()
@@ -51,90 +33,22 @@ def strain_names(filepath):
 
     return strains
 
-def parse_file(filepath: str, filetype: FileType, strains: list):
-    """Parse the given file"""
-    seek_pos = 0
-    try:
-        with open(filepath, encoding="utf-8") as input_file:
-            for line_number, line in enumerate(input_file):
-                if line_number == 0:
-                    yield __parse_header(line, strains), seek_pos + len(line)
-                    seek_pos = seek_pos + len(line)
-                    continue
-
-                yield (
-                    LINE_PARSERS[filetype](
-                        tuple(field.strip() for field in line.split("\t"))),
-                    seek_pos + len(line))
-                seek_pos = seek_pos + len(line)
-    except (DuplicateHeader, InvalidCellValue, InvalidHeaderValue) as err:
-        raise ParseError({
-            "filepath": filepath,
-            "filetype": filetype,
-            "position": seek_pos,
-            "line_number": line_number,
-            "error": err
-        }) from err
-
-def parse_errors(filepath: str, filetype: FileType, strains: list,
-                 seek_pos: int = 0) -> Generator:
-    """Retrieve ALL the parse errors"""
-    assert seek_pos >= 0, "The seek position must be at least zero (0)"
-
-    def __error_type(error):
-        """Return a nicer string representatiton for the error type."""
-        if isinstance(error, DuplicateHeader):
-            return "Duplicated Headers"
-        if isinstance(error, InvalidCellValue):
-            return "Invalid Value"
-        if isinstance(error, InvalidHeaderValue):
-            return "Invalid Strain"
-
-    def __errors(filepath, filetype, strains, seek_pos):
-        """Return only the errors as values"""
-        with open(filepath, encoding="utf-8") as input_file:
-            ## TODO: Seek the file to the given seek position
-            for line_number, line in enumerate(input_file):
-                if seek_pos > 0:
-                    input_file.seek(seek_pos, 0)
-                try:
-                    if seek_pos == 0 and line_number == 0:
-                        header = __parse_header(line, strains)
-                        yield None
-                        seek_pos = seek_pos + len(line)
-                        continue
-
-                    parsed_line = LINE_PARSERS[filetype](
-                        tuple(field.strip() for field in line.split("\t")))
-                    yield None
-                    seek_pos = seek_pos + len(line)
-                except (DuplicateHeader, InvalidCellValue, InvalidHeaderValue) as err:
-                    yield {
-                        "filepath": filepath,
-                        "filetype": filetype,
-                        "position": seek_pos,
-                        "line_number": line_number,
-                        "error": __error_type(err),
-                        "message": err.args
-                    }
-                    seek_pos = seek_pos + len(line)
-
-    return (
-        error for error in __errors(filepath, filetype, strains, seek_pos)
-        if error is not None)
-
 def header_errors(line_number, fields, strains):
+    """Gather all header row errors."""
     return (
         (invalid_header(line_number, fields),) +
         invalid_headings(line_number, strains, fields[1:]) +
         duplicate_headings(line_number, fields))
 
 def empty_value(line_number, column_number, value):
+    """Check for empty field values."""
     if value == "":
         return InvalidValue(
             line_number, column_number, value, "Empty value for column")
+    return None
 
 def average_errors(line_number, fields):
+    """Gather all errors for a line in an averages file."""
     return (
         (empty_value(line_number, 1, fields[0]),) +
         tuple(
@@ -142,6 +56,7 @@ def average_errors(line_number, fields):
             for field in enumerate(fields[1:], start=2)))
 
 def se_errors(line_number, fields):
+    """Gather all errors for a line in a standard-errors file."""
     return (
         (empty_value(line_number, 1, fields[0]),) +
         tuple(
@@ -149,7 +64,8 @@ def se_errors(line_number, fields):
             for field in enumerate(fields[1:], start=2)))
 
 def collect_errors(
-        filepath: str, filetype: FileType, strains: list, count: int = 10) -> Generator:
+        filepath: str, filetype: FileType, strains: list,
+        updater: Union[Callable, None] = None) -> Generator:
     """Run checks against file and collect all the errors"""
     errors = tuple()
     def __process_errors__(line_number, line, error_checker_fn, errors = tuple()):
@@ -162,30 +78,37 @@ def collect_errors(
             return errors + tuple(error for error in errs if error is not None)
         return errors + (errs,)
 
+    filesize = os.stat(filepath).st_size
+    processed_size = 0
     with open(filepath, encoding="utf-8") as input_file:
         for line_number, line in enumerate(input_file, start=1):
             if line_number == 1:
-                errors = __process_errors__(
-                    line_number, line, partial(header_errors, strains=strains),
-                    errors)
-            if line_number != 1:
-                errors = __process_errors__(
-                    line_number, line, (
-                        average_errors if filetype == FileType.AVERAGE
-                        else se_errors),
-                    errors)
+                for error in __process_errors__(
+                        line_number, line, partial(header_errors, strains=strains),
+                        errors):
+                    yield error
 
-            if count > 0 and len(errors) >= count:
-                break
-
-    return errors[0:count]
+            if line_number != 1:
+                for error in __process_errors__(
+                        line_number, line, (
+                            average_errors if filetype == FileType.AVERAGE
+                            else se_errors),
+                        errors):
+                    yield error
+
+            processed_size = processed_size + len(line)
+            if updater:
+                updater({
+                    "line_number": line_number,
+                    "percent": (processed_size/filesize) * 100
+                })
 
 def take(iterable: Iterable, num: int) -> list:
     """Take at most `num` items from `iterable`."""
     iterator = iter(iterable)
     items = []
     try:
-        for i in range(0, num):
+        for i in range(0, num): # pylint: disable=[unused-variable]
             items.append(next(iterator))
 
         return items