aboutsummaryrefslogtreecommitdiff
path: root/quality_control/parsing.py
blob: c545937bb83b167f6288d6d8206a261fec7f5c91 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Module handling the high-level parsing of the files"""
import collections
from enum import Enum
from functools import partial
from typing import Tuple, Union, Generator, Callable, Optional

import quality_control.average as avg
from quality_control.file_utils import open_file
import quality_control.standard_error as se
from quality_control.errors import (
    InvalidValue, DuplicateHeading, InconsistentColumns)
from quality_control.headers import (
    invalid_header, invalid_headings, duplicate_headings)

class FileType(Enum):
    """The kinds of data files this module knows how to check."""
    AVERAGE = 1         # file of average values
    STANDARD_ERROR = 2  # file of standard-error values

def strain_names(filepath):
    """Retrieve the strain names from the given file.

    The file is whitespace-separated with a header line. Columns 2 and 3
    (indices 1 and 2) hold strain names; column 6 (index 5), when present,
    holds an alias which is recorded unless it is empty, "P" or "\\N".

    :param filepath: path to the strains mapping file
    :returns: a set of strain names (including valid aliases)
    """
    strains = set()
    with open(filepath, encoding="utf8") as strains_file:
        next(strains_file, None)  # skip the header line
        for line in strains_file:
            parts = line.split()
            strains.update(name.strip() for name in (parts[1], parts[2]))
            # The alias check belongs at line level: the original nested it
            # inside the per-name loop, running it twice per line.
            if len(parts) >= 6:
                alias = parts[5].strip()
                if alias != "" and alias not in ("P", "\\N"):
                    strains.add(alias)

    return strains

def header_errors(line_number, fields, strains):
    """Collect every error present in the header row.

    Combines the whole-header check with per-heading validation (against
    the known `strains`) and duplicate-heading detection.
    """
    heading_errs = invalid_headings(line_number, strains, fields[1:])
    duplicate_errs = duplicate_headings(line_number, fields)
    return (invalid_header(line_number, fields),) + heading_errs + duplicate_errs

def empty_value(line_number, column_number, value):
    """Return an `InvalidValue` error when `value` is empty, else `None`."""
    if value != "":
        return None
    return InvalidValue(
        line_number, column_number, value, "Empty value for column")

def average_errors(line_number, fields):
    """Gather all errors for a line in a averages file."""
    value_errors = tuple(
        avg.invalid_value(line_number, column, value)
        for column, value in enumerate(fields[1:], start=2))
    return (empty_value(line_number, 1, fields[0]),) + value_errors

def se_errors(line_number, fields):
    """Gather all errors for a line in a standard-errors file."""
    value_errors = tuple(
        se.invalid_value(line_number, column, value)
        for column, value in enumerate(fields[1:], start=2))
    return (empty_value(line_number, 1, fields[0]),) + value_errors

def make_column_consistency_checker(header_row):
    """Build a closure that verifies a content row has exactly as many
    tab-separated columns as the header row it was built from."""
    headers = tuple(field.strip() for field in header_row.split("\t"))
    expected = len(headers)
    def __checker__(line_number, contents_row):
        """Return an `InconsistentColumns` error for a mismatched row,
        or `None` when the column counts agree."""
        contents = tuple(field.strip() for field in contents_row.split("\t"))
        actual = len(contents)
        if actual == expected:
            return None
        return InconsistentColumns(
            line_number, expected, actual,
            (f"Header row has {expected} columns while row "
             f"{line_number} has {actual} columns"))
    return __checker__

def collect_errors(
        filepath: str, filetype: FileType, strains: list,
        update_progress: Optional[Callable] = None,
        user_aborted: Callable = lambda: False) -> Generator:
    """Run checks against the file and yield every error found.

    :param filepath: path to the file to check
    :param filetype: whether the file holds averages or standard errors
    :param strains: known strain names, used to validate the headings
    :param update_progress: optional callback invoked after each line with
        ``(line_number, line)`` for progress reporting
    :param user_aborted: polled once per line; return `True` to stop early
    :yields: error objects such as `InvalidValue`, `DuplicateHeading` and
        `InconsistentColumns`
    """
    # NOTE: the previous version threaded an `errors` accumulator through
    # the helper, but the outer variable was never rebound and stayed
    # empty forever — dead state, removed.
    def __line_errors__(line_number, line, error_checker_fn):
        """Run `error_checker_fn` on the tab-split, stripped fields of
        `line` and normalise its result into a tuple of errors."""
        errs = error_checker_fn(
            line_number,
            tuple(field.strip() for field in line.split("\t")))
        if errs is None:
            return tuple()
        if isinstance(errs, collections.abc.Sequence):
            return tuple(error for error in errs if error is not None)
        return (errs,)

    with open_file(filepath) as input_file:
        for line_number, line in enumerate(input_file, start=1):
            if user_aborted():
                break

            if isinstance(line, bytes):
                line = line.decode("utf-8")

            if line_number == 1:
                # The header row fixes the expected column count for all
                # subsequent rows.
                consistent_columns_checker = make_column_consistency_checker(line)
                yield from __line_errors__(
                    line_number, line, partial(header_errors, strains=strains))
            else:
                col_consistency_error = consistent_columns_checker(
                    line_number, line)
                if col_consistency_error:
                    yield col_consistency_error

                yield from __line_errors__(
                    line_number, line, (
                        average_errors if filetype == FileType.AVERAGE
                        else se_errors))

            if update_progress:
                update_progress(line_number, line)