"""Module handling the high-level parsing of the files"""
import collections
import collections.abc
from contextlib import contextmanager
from enum import Enum
from functools import partial
from itertools import islice
from zipfile import ZipFile, is_zipfile
from typing import Tuple, Union, Iterable, Generator, Callable, Optional

import quality_control.average as avg
import quality_control.standard_error as se
from quality_control.errors import (
    InvalidValue, DuplicateHeading, InconsistentColumns)
from quality_control.headers import (
    invalid_header, invalid_headings, duplicate_headings)


class FileType(Enum):
    """Enumerate the expected file types"""
    # Rows are validated with `average_errors`.
    AVERAGE = 1
    # Rows are validated with `se_errors`.
    STANDARD_ERROR = 2


def strain_names(filepath):
    """Retrieve the strains names from given file.

    The first line of the file is treated as a header and ignored. Every
    other non-blank line is split on whitespace; its second and third
    fields are always collected, and its sixth field (when present) is
    collected too unless it is one of the placeholder values `P` or `\\N`.

    Parameters:
        filepath: Path to the UTF-8 encoded strains file.

    Returns:
        A `set` of the collected names.

    Raises:
        IndexError: If a non-blank data line has fewer than three fields.
    """
    strains = set()
    with open(filepath, encoding="utf8") as strains_file:
        # Iterate the file lazily rather than loading every line up front.
        for idx, line in enumerate(strains_file):
            if idx == 0:
                continue
            parts = line.split()
            if not parts:
                # A blank line carries no strain data; skip it instead of
                # failing with an IndexError on `parts[1]`.
                continue
            for name in (parts[1], parts[2]):
                strains.add(name.strip())
            if len(parts) >= 6:
                alias = parts[5].strip()
                if alias not in ("", "P", "\\N"):
                    strains.add(alias)
    return strains


def header_errors(line_number, fields, strains):
    """Gather all header row errors.

    Combines, in order, the result of `invalid_header` for the whole row,
    the results of `invalid_headings` for every field after the first
    (checked against `strains`) and the results of `duplicate_headings`.

    Returns:
        A tuple of check results. The first item may be `None` when
        `invalid_header` reports nothing; callers are expected to filter
        such entries out.
    """
    return (
        (invalid_header(line_number, fields),) +
        invalid_headings(line_number, strains, fields[1:]) +
        duplicate_headings(line_number, fields))


def empty_value(line_number, column_number, value):
    """Check for empty field values.

    Returns:
        An `InvalidValue` when `value` is the empty string, else `None`.
    """
    if value == "":
        return InvalidValue(
            line_number, column_number, value, "Empty value for column")
    return None


def _data_row_errors(line_number, fields, invalid_value):
    """Check one data row: first field non-empty, others via `invalid_value`.

    `invalid_value` is called as `invalid_value(line_number, column, value)`
    for every field after the first, with columns numbered from 2.
    """
    return (
        (empty_value(line_number, 1, fields[0]),) +
        tuple(
            invalid_value(line_number, column_number, value)
            for column_number, value in enumerate(fields[1:], start=2)))


def average_errors(line_number, fields):
    """Gather all errors for a line in an averages file.

    Returns:
        A tuple with one entry per field, holding whatever the checks
        returned for that field (the first-column check yields `None`
        when the value is not empty).
    """
    return _data_row_errors(line_number, fields, avg.invalid_value)


def se_errors(line_number, fields):
    """Gather all errors for a line in a standard-errors file.

    Returns:
        A tuple with one entry per field, holding whatever the checks
        returned for that field (the first-column check yields `None`
        when the value is not empty).
    """
    return _data_row_errors(line_number, fields, se.invalid_value)


def make_column_consistency_checker(header_row):
    """Build function to check for column consistency.

    Parameters:
        header_row: The raw, tab-separated header line of the file.

    Returns:
        A function `(line_number, contents_row)` that returns an
        `InconsistentColumns` error when the tab-separated `contents_row`
        does not have as many fields as the header, else `None`.
    """
    headers = tuple(field.strip() for field in header_row.split("\t"))

    def __checker__(line_number, contents_row):
        contents = tuple(field.strip() for field in contents_row.split("\t"))
        if len(contents) != len(headers):
            return InconsistentColumns(
                line_number, len(headers), len(contents),
                (f"Header row has {len(headers)} columns while row "
                 f"{line_number} has {len(contents)} columns"))
        return None

    return __checker__


def collect_errors(
        filepath: str, filetype: FileType, strains: list,
        update_progress: Optional[Callable] = None,
        user_aborted: Callable = lambda: False) -> Generator:
    """Run checks against file and collect all the errors.

    The file is read line by line. Line 1 is checked as the header row;
    every later line is checked for column-count consistency against the
    header and then for invalid values, using the checker that matches
    `filetype`.

    Parameters:
        filepath: Path to a plain-text file, or to a zip archive whose
            first member is the file to check.
        filetype: Selects `average_errors` (for `FileType.AVERAGE`) or
            `se_errors` (for anything else) as the row checker.
        strains: The strain names handed to the header checks.
        update_progress: Optional callable invoked as
            `update_progress(line_number, line)` after each line has been
            processed.
        user_aborted: Callable polled before each line; processing stops
            as soon as it returns a truthy value.

    Yields:
        The errors found, lazily and in file order.
    """
    errors:Tuple[Union[InvalidValue, DuplicateHeading], ...] = tuple()

    def __process_errors__(line_number, line, error_checker_fn, errors = tuple()):
        errs = error_checker_fn(
            line_number,
            tuple(field.strip() for field in line.split("\t")))
        if errs is None:
            return errors
        if isinstance(errs, collections.abc.Sequence):
            # Checkers report "no error" for a field as `None`; drop those.
            return errors + tuple(error for error in errs if error is not None)
        return errors + (errs,)

    @contextmanager
    def __open_file__(filepath):
        if not is_zipfile(filepath):
            with open(filepath, encoding="utf-8") as text_file:
                yield text_file
            return
        # Keep the archive open for as long as its member is being read,
        # and close both once the caller is done.
        with ZipFile(filepath, "r") as zfile:
            with zfile.open(zfile.infolist()[0], "r") as member:
                yield member

    row_errors = (
        average_errors if filetype == FileType.AVERAGE else se_errors)

    with __open_file__(filepath) as input_file:
        for line_number, line in enumerate(input_file, start=1):
            if user_aborted():
                break

            if isinstance(line, bytes):
                # Zip members are read in binary mode.
                line = line.decode("utf-8")

            if line_number == 1:
                consistent_columns_checker = (
                    make_column_consistency_checker(line))
                for error in __process_errors__(
                        line_number, line,
                        partial(header_errors, strains=strains),
                        errors):
                    yield error
            else:
                col_consistency_error = consistent_columns_checker(
                    line_number, line)
                if col_consistency_error:
                    yield col_consistency_error

                for error in __process_errors__(
                        line_number, line, row_errors, errors):
                    yield error

            if update_progress:
                update_progress(line_number, line)


def take(iterable: Iterable, num: int) -> list:
    """Take at most `num` items from `iterable`.

    Returns:
        A list with the first `num` items of `iterable`, or all of them
        when it has fewer. A `num` of zero or less gives an empty list.
    """
    # `islice` rejects negative stops, so clamp to keep returning `[]`.
    return list(islice(iterable, max(num, 0)))