"""Module handling the high-level parsing of the files"""
import collections.abc
from enum import Enum
from pathlib import Path
from functools import partial
from typing import Tuple, Union, Generator, Callable, Optional

import MySQLdb as mdb
from MySQLdb.cursors import DictCursor

import quality_control.average as avg
from quality_control.file_utils import open_file
import quality_control.standard_error as se
from quality_control.errors import (
    InvalidValue, DuplicateHeading, InconsistentColumns)
from quality_control.headers import (
    invalid_header, invalid_headings, duplicate_headings)


class FileType(Enum):
    """Enumerate the expected file types"""
    AVERAGE = 1
    STANDARD_ERROR = 2


def strain_names(dbconn: mdb.Connection, speciesid: int) -> tuple[str, ...]:
    """Retrieve samples/cases from database.

    Reads the `Name` and `Name2` columns of every `Strain` row whose
    `SpeciesId` equals `speciesid`, discards NULL and whitespace-only
    values, and removes duplicates.

    :param dbconn: an open MySQLdb connection.
    :param speciesid: the species whose strain names are wanted.
    :return: the distinct names, sorted so the result is deterministic
        (iterating a bare `set` of strings gives a hash-randomised order).
    """
    with dbconn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute("SELECT * FROM Strain WHERE SpeciesId=%s", (speciesid,))
        samplenames = (
            (row["Name"], row["Name2"]) for row in cursor.fetchall())
        return tuple(sorted({
            name
            for names in samplenames
            for name in names
            # Keep only non-NULL names that are not blank.
            if name is not None and name.strip()}))


def header_errors(filename, line_number, fields, strains):
    """Gather all header row errors.

    :param filename: file name used when building error objects.
    :param line_number: 1-based line number of the header row.
    :param fields: the stripped, tab-separated fields of the header row;
        `fields[1:]` are the headings checked against `strains`.
    :param strains: the acceptable heading (sample/strain) names.
    :return: a tuple concatenating the results of the header checks.
        Entries may be `None` (presumably for checks that passed);
        `collect_errors` filters those out.
    """
    return (
        (invalid_header(filename, line_number, fields),) +
        invalid_headings(filename, line_number, strains, fields[1:]) +
        duplicate_headings(filename, line_number, fields))


def empty_value(filename, line_number, column_number, value):
    """Check for empty field values.

    :return: an `InvalidValue` error if `value` is the empty string,
        otherwise `None`.
    """
    if value == "":
        return InvalidValue(
            filename, line_number, column_number, value,
            "Empty value for column")
    return None


def average_errors(filename, line_number, fields):
    """Gather all errors for a line in an averages file.

    Column 1 (the row identifier) must not be empty; columns 2 onward are
    validated with `quality_control.average.invalid_value`. Column numbers
    passed to the validators are 1-based.

    :return: a tuple of results; entries may be `None` for valid fields.
    """
    return (
        (empty_value(filename, line_number, 1, fields[0]),) +
        tuple(
            avg.invalid_value(filename, line_number, *field)
            for field in enumerate(fields[1:], start=2)))


def se_errors(filename, line_number, fields):
    """Gather all errors for a line in a standard-errors file.

    Column 1 (the row identifier) must not be empty; columns 2 onward are
    validated with `quality_control.standard_error.invalid_value`. Column
    numbers passed to the validators are 1-based.

    :return: a tuple of results; entries may be `None` for valid fields.
    """
    return (
        (empty_value(filename, line_number, 1, fields[0]),) +
        tuple(
            se.invalid_value(filename, line_number, *field)
            for field in enumerate(fields[1:], start=2)))


def make_column_consistency_checker(filename, header_row):
    """Build function to check for column consistency.

    :param filename: file name used when building error objects.
    :param header_row: the raw header line; its tab-separated field count
        is the reference column count.
    :return: a function `(line_number, contents_row)` that returns an
        `InconsistentColumns` error when `contents_row` has a different
        number of tab-separated fields than the header, else `None`.
    """
    headers = tuple(field.strip() for field in header_row.split("\t"))

    def __checker__(line_number, contents_row):
        contents = tuple(field.strip() for field in contents_row.split("\t"))
        if len(contents) != len(headers):
            return InconsistentColumns(
                filename, line_number, len(headers), len(contents),
                (f"Header row has {len(headers)} columns while row "
                 f"{line_number} has {len(contents)} columns"))
        return None

    return __checker__


def collect_errors(
        filepath: str,
        filetype: FileType,
        strains: list,
        update_progress: Optional[Callable] = None,
        user_aborted: Callable = lambda: False) -> Generator:
    """Run checks against file and collect all the errors.

    Line 1 is treated as the header row and checked with `header_errors`.
    Every later line is checked for column-count consistency against the
    header, then with `average_errors` or `se_errors` depending on
    `filetype`.

    :param filepath: path of the tab-separated file to check.
    :param filetype: selects the per-row value checks.
    :param strains: acceptable header names, forwarded to `header_errors`.
    :param update_progress: optional callback invoked as
        `update_progress(line_number, line)` after each line is checked.
    :param user_aborted: polled before each line; checking stops as soon
        as it returns a truthy value.
    :return: a generator yielding error objects in file order.
    """
    def __process_errors__(
            filename, line_number, line, error_checker_fn
    ) -> Tuple[Union[InvalidValue, DuplicateHeading, InconsistentColumns], ...]:
        """Run `error_checker_fn` on one line; return only real errors."""
        errs = error_checker_fn(
            filename, line_number,
            tuple(field.strip() for field in line.split("\t")))
        if errs is None:
            return tuple()
        if isinstance(errs, collections.abc.Sequence):
            # Checkers return one slot per check; drop the passing ones.
            return tuple(error for error in errs if error is not None)
        return (errs,)

    # The per-row checker depends only on the file type: choose it once.
    row_checker = (
        average_errors if filetype == FileType.AVERAGE else se_errors)
    header_checker = partial(header_errors, strains=strains)

    with open_file(filepath) as input_file:
        filename = Path(filepath).name
        for line_number, line in enumerate(input_file, start=1):
            if user_aborted():
                break

            if isinstance(line, bytes):
                line = line.decode("utf-8")

            if line_number == 1:
                # Always the first iteration, so the checker is bound before
                # any data row uses it.
                consistent_columns_checker = make_column_consistency_checker(
                    filename, line)
                yield from __process_errors__(
                    filename, line_number, line, header_checker)
            else:
                col_consistency_error = consistent_columns_checker(
                    line_number, line)
                if col_consistency_error:
                    yield col_consistency_error

                yield from __process_errors__(
                    filename, line_number, line, row_checker)

            if update_progress:
                update_progress(line_number, line)