aboutsummaryrefslogtreecommitdiff
path: root/quality_control/parsing.py
blob: 28a311ec033691ff46368983cbe5974b4a1d5917 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Module handling the high-level parsing of the files"""
import collections
from enum import Enum
from functools import partial
from itertools import islice
from zipfile import ZipFile, is_zipfile
from typing import Tuple, Union, Iterable, Generator, Callable, Optional

import quality_control.average as avg
import quality_control.standard_error as se
from quality_control.errors import (
    InvalidValue, DuplicateHeading, InconsistentColumns)
from quality_control.headers import (
    invalid_header, invalid_headings, duplicate_headings)

class FileType(Enum):
    """Kinds of data file that `collect_errors` knows how to check."""
    # Rows are validated with `average_errors` (quality_control.average).
    AVERAGE = 1
    # Rows are validated with `se_errors` (quality_control.standard_error).
    STANDARD_ERROR = 2

def strain_names(filepath):
    """Retrieve the strain names (and aliases) listed in the file at `filepath`.

    The first line is treated as a header and skipped. Every following line
    is split on whitespace:

    * columns 2 and 3 (1-based) are always added as names;
    * column 6, when present, is added as an alias unless it is the string
      "P" or the two-character string backslash-N.

    Lines with fewer than three columns (e.g. blank lines) carry no names and
    are skipped instead of raising ``IndexError``.

    Returns a ``set`` of strings.
    """
    strains = set()
    with open(filepath, encoding="utf8") as strains_file:
        next(strains_file, None)  # skip the header row (safe on empty files)
        for line in strains_file:
            # str.split() with no argument never yields empty or padded parts,
            # so no further stripping/emptiness checks are needed.
            parts = line.split()
            if len(parts) < 3:
                continue
            strains.update(parts[1:3])
            if len(parts) >= 6 and parts[5] not in ("P", "\\N"):
                strains.add(parts[5])

    return strains

def header_errors(line_number, fields, strains):
    """Collect every problem reported for the header row.

    The whole row `fields` goes to `invalid_header` and `duplicate_headings`;
    only the headings after the first column are checked against `strains`
    by `invalid_headings`. The three results are concatenated into a single
    tuple, with the `invalid_header` result first.

    NOTE(review): entries may be `None` placeholders; `collect_errors` filters
    those out before yielding.
    """
    row_error = invalid_header(line_number, fields)
    heading_errors = invalid_headings(line_number, strains, fields[1:])
    duplicate_errors = duplicate_headings(line_number, fields)
    return (row_error,) + heading_errors + duplicate_errors

def empty_value(line_number, column_number, value):
    """Flag a field whose value is the empty string.

    Returns an `InvalidValue` describing the empty field, or `None` when
    `value` is non-empty.
    """
    if value != "":
        return None
    return InvalidValue(
        line_number, column_number, value, "Empty value for column")

def _value_row_errors(line_number, fields, invalid_value):
    """Gather the errors for one data row, validating values with `invalid_value`.

    The first column is checked only for emptiness. Every remaining column is
    passed to `invalid_value(line_number, column_number, value)`, with columns
    numbered from 1 (so the first value column is 2). Returns a tuple whose
    entries may be `None` for checks that found nothing.
    """
    return (
        (empty_value(line_number, 1, fields[0]),) +
        tuple(
            invalid_value(line_number, column_number, value)
            for column_number, value in enumerate(fields[1:], start=2)))

def average_errors(line_number, fields):
    """Gather all errors for a line in an averages file."""
    return _value_row_errors(line_number, fields, avg.invalid_value)

def se_errors(line_number, fields):
    """Gather all errors for a line in a standard-errors file."""
    return _value_row_errors(line_number, fields, se.invalid_value)

def make_column_consistency_checker(header_row):
    """Build a function that checks a row has as many columns as `header_row`.

    Columns are counted by splitting on tab characters. The returned function
    takes `(line_number, contents_row)` and returns an `InconsistentColumns`
    error when the counts differ, otherwise `None`.
    """
    expected = len(header_row.split("\t"))

    def check_row(line_number, contents_row):
        found = len(contents_row.split("\t"))
        if found == expected:
            return None
        return InconsistentColumns(
            line_number, expected, found,
            (f"Header row has {expected} columns while row "
             f"{line_number} has {found} columns"))

    return check_row

def _read_lines(filepath):
    """Yield the lines of `filepath` as UTF-8 decoded text.

    If `filepath` is a zip archive, the lines of its first member (first entry
    of `ZipFile.infolist()`) are yielded instead. Both the archive and the
    member stay open until the generator is exhausted or closed.
    """
    if is_zipfile(filepath):
        with ZipFile(filepath, "r") as zfile:
            with zfile.open(zfile.infolist()[0], "r") as member:
                # Zip members are opened in binary mode: decode each line.
                for raw_line in member:
                    yield raw_line.decode("utf-8")
    else:
        with open(filepath, encoding="utf-8") as input_file:
            yield from input_file

def _line_errors(line_number, line, error_checker_fn):
    """Run `error_checker_fn` on the stripped, tab-separated fields of `line`.

    The checker may return `None`, a single error, or a sequence of errors
    (possibly containing `None` entries); the result is normalised to a tuple
    of the non-`None` errors.
    """
    errs = error_checker_fn(
        line_number,
        tuple(field.strip() for field in line.split("\t")))
    if errs is None:
        return tuple()
    if isinstance(errs, collections.abc.Sequence):
        return tuple(error for error in errs if error is not None)
    return (errs,)

def collect_errors(
        filepath: str, filetype: FileType, strains: list,
        update_progress: Optional[Callable] = None,
        user_aborted: Callable = lambda: False) -> Generator:
    """Run checks against file and collect all the errors.

    Errors are yielded lazily while the file is read line by line.

    :param filepath: path to a tab-separated file, or to a zip archive whose
        first member is such a file.
    :param filetype: selects the per-row value checks (`average_errors` for
        `FileType.AVERAGE`, `se_errors` otherwise).
    :param strains: passed to `header_errors` for checking the header row.
    :param update_progress: optional callback, called as
        `update_progress(line_number, line)` after each line is checked.
    :param user_aborted: polled before each line is checked; when it returns a
        truthy value, checking stops.
    """
    # The file type cannot change while reading, so pick the checkers once.
    row_checker = (
        average_errors if filetype == FileType.AVERAGE else se_errors)
    header_checker = partial(header_errors, strains=strains)

    lines = _read_lines(filepath)
    try:
        for line_number, line in enumerate(lines, start=1):
            if user_aborted():
                break

            if line_number == 1:
                # Line 1 is always seen first, so the checker is bound before
                # any data row needs it.
                consistent_columns_checker = make_column_consistency_checker(
                    line)
                yield from _line_errors(line_number, line, header_checker)
            else:
                col_consistency_error = consistent_columns_checker(
                    line_number, line)
                if col_consistency_error:
                    yield col_consistency_error
                yield from _line_errors(line_number, line, row_checker)

            if update_progress:
                update_progress(line_number, line)
    finally:
        # Close the file/archive promptly on abort or early consumer exit.
        lines.close()

def take(iterable: Iterable, num: int) -> list:
    """Take at most `num` items from `iterable`.

    Exactly `num` items are pulled from the underlying iterator (fewer if it
    is exhausted first), so a partially consumed generator can be resumed
    afterwards. A `num` of zero or less yields an empty list.
    """
    # islice rejects negative stop values; clamp so a negative `num` still
    # yields [] instead of raising ValueError.
    return list(islice(iterable, max(num, 0)))