"""Module handling the high-level parsing of the files""" import csv import collections from enum import Enum from functools import reduce, partial from typing import Iterable, Generator import quality_control.average as avg import quality_control.standard_error as se from quality_control.headers import valid_header from quality_control.headers import ( invalid_header, invalid_headings, duplicate_headings) from quality_control.errors import ( ParseError, DuplicateHeader, InvalidCellValue, InvalidHeaderValue) class FileType(Enum): """Enumerate the expected file types""" AVERAGE = 1 STANDARD_ERROR = 2 def __parse_header(line, strains): return valid_header( set(strains), tuple(header.strip() for header in line.split("\t"))) def __parse_average_line(line): return (line[0],) + tuple(avg.valid_value(field) for field in line[1:]) def __parse_standard_error_line(line): return (line[0],) + tuple(se.valid_value(field) for field in line[1:]) LINE_PARSERS = { FileType.AVERAGE: __parse_average_line, FileType.STANDARD_ERROR: __parse_standard_error_line } def strain_names(filepath): """Retrieve the strains names from given file""" strains = set() with open(filepath, encoding="utf8") as strains_file: for idx, line in enumerate(strains_file.readlines()): if idx > 0: parts = line.split() for name in (parts[1], parts[2]): strains.add(name.strip()) if len(parts) >= 6: alias = parts[5].strip() if alias != "" and alias not in ("P", "\\N"): strains.add(alias) return strains def parse_file(filepath: str, filetype: FileType, strains: list): """Parse the given file""" seek_pos = 0 try: with open(filepath, encoding="utf-8") as input_file: for line_number, line in enumerate(input_file): if line_number == 0: yield __parse_header(line, strains), seek_pos + len(line) seek_pos = seek_pos + len(line) continue yield ( LINE_PARSERS[filetype]( tuple(field.strip() for field in line.split("\t"))), seek_pos + len(line)) seek_pos = seek_pos + len(line) except (DuplicateHeader, InvalidCellValue, InvalidHeaderValue) as err: raise ParseError({ "filepath": filepath, "filetype": filetype, "position": seek_pos, "line_number": line_number, "error": err }) from err def parse_errors(filepath: str, filetype: FileType, strains: list, seek_pos: int = 0) -> Generator: """Retrieve ALL the parse errors""" assert seek_pos >= 0, "The seek position must be at least zero (0)" def __error_type(error): """Return a nicer string representatiton for the error type.""" if isinstance(error, DuplicateHeader): return "Duplicated Headers" if isinstance(error, InvalidCellValue): return "Invalid Value" if isinstance(error, InvalidHeaderValue): return "Invalid Strain" def __errors(filepath, filetype, strains, seek_pos): """Return only the errors as values""" with open(filepath, encoding="utf-8") as input_file: ## TODO: Seek the file to the given seek position for line_number, line in enumerate(input_file): if seek_pos > 0: input_file.seek(seek_pos, 0) try: if seek_pos == 0 and line_number == 0: header = __parse_header(line, strains) yield None seek_pos = seek_pos + len(line) continue parsed_line = LINE_PARSERS[filetype]( tuple(field.strip() for field in line.split("\t"))) yield None seek_pos = seek_pos + len(line) except (DuplicateHeader, InvalidCellValue, InvalidHeaderValue) as err: yield { "filepath": filepath, "filetype": filetype, "position": seek_pos, "line_number": line_number, "error": __error_type(err), "message": err.args } seek_pos = seek_pos + len(line) return ( error for error in __errors(filepath, filetype, strains, seek_pos) if error is not None) 
def header_errors(line_number, fields, strains):
    """Gather all header errors for the given `fields`."""
    return (
        (invalid_header(line_number, fields),) +
        invalid_headings(line_number, strains, fields[1:]) +
        duplicate_headings(line_number, fields))

def empty_value(line_number, column_number, value):
    """Return an `InvalidValue` if `value` is empty, else `None`."""
    if value == "":
        return InvalidValue(
            line_number, column_number, value, "Empty value for column")

def average_errors(line_number, fields):
    """Gather all errors for a line of average values."""
    return (
        (empty_value(line_number, 1, fields[0]),) +
        tuple(
            avg.invalid_value(line_number, *field)
            for field in enumerate(fields[1:], start=2)))

def se_errors(line_number, fields):
    """Gather all errors for a line of standard-error values."""
    return (
        (empty_value(line_number, 1, fields[0]),) +
        tuple(
            se.invalid_value(line_number, *field)
            for field in enumerate(fields[1:], start=2)))

def collect_errors(
        filepath: str, filetype: FileType, strains: list,
        count: int = 10) -> tuple:
    """Run checks against the file and collect up to `count` errors."""
    errors = tuple()

    def __process_errors__(line_number, line, error_checker_fn,
                           errors=tuple()):
        errs = error_checker_fn(
            line_number,
            tuple(field.strip() for field in line.split("\t")))
        if errs is None:
            return errors
        if isinstance(errs, collections.abc.Sequence):
            return errors + tuple(
                error for error in errs if error is not None)
        return errors + (errs,)

    with open(filepath, encoding="utf-8") as input_file:
        for line_number, line in enumerate(input_file, start=1):
            if line_number == 1:
                errors = __process_errors__(
                    line_number, line,
                    partial(header_errors, strains=strains), errors)
            else:
                errors = __process_errors__(
                    line_number, line,
                    (average_errors if filetype == FileType.AVERAGE
                     else se_errors),
                    errors)
            if count > 0 and len(errors) >= count:
                break

    return errors[0:count]

def take(iterable: Iterable, num: int) -> list:
    """Take at most `num` items from `iterable`."""
    iterator = iter(iterable)
    items = []
    try:
        for _ in range(num):
            items.append(next(iterator))
        return items
    except StopIteration:
        return items
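# A minimal sketch of bounding the errors reported, assuming a hypothetical
# "se.tsv" file: `take` caps the lazy generator from `parse_errors`, while
# `collect_errors` eagerly gathers at most `count` errors:
#
#     first_ten = take(
#         parse_errors("se.tsv", FileType.STANDARD_ERROR, strains), 10)
#     bounded = collect_errors("se.tsv", FileType.STANDARD_ERROR, strains, 10)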