aboutsummaryrefslogtreecommitdiff
path: root/quality_control/parsing.py
blob: 70a85ed4fc20c740744bfc4c26f94cc3e7f04429 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""Module handling the high-level parsing of the files"""

import csv
import collections
from collections.abc import Sequence
from enum import Enum
from functools import reduce, partial
from itertools import islice
from typing import Iterable, Generator

import quality_control.average as avg
import quality_control.standard_error as se
from quality_control.headers import valid_header
from quality_control.headers import (
    invalid_header, invalid_headings, duplicate_headings)
from quality_control.errors import (
    ParseError, DuplicateHeader, InvalidCellValue, InvalidHeaderValue)

class FileType(Enum):
    """The kinds of input file this module knows how to parse.

    The member chosen selects which row validator (see ``LINE_PARSERS``)
    is applied to the data lines of a file.
    """
    AVERAGE = 1         # rows validated with quality_control.average
    STANDARD_ERROR = 2  # rows validated with quality_control.standard_error

def __parse_header(line, strains):
    """Validate a tab-separated header line against the known strains.

    Each tab-separated heading is stripped of surrounding whitespace before
    the headings and the set of strain names are handed to
    ``valid_header``, whose result is returned unchanged.
    """
    known_strains = set(strains)
    headings = tuple(map(str.strip, line.split("\t")))
    return valid_header(known_strains, headings)

def __parse_average_line(line):
    """Validate the value cells of an already-split averages row.

    The first cell is kept unchanged; every following cell is replaced by
    the result of ``avg.valid_value``, which may raise on invalid input.
    """
    first_cell = line[0]
    checked_values = tuple(map(avg.valid_value, line[1:]))
    return (first_cell, *checked_values)

def __parse_standard_error_line(line):
    """Validate the value cells of an already-split standard-error row.

    The first cell is kept unchanged; every following cell is replaced by
    the result of ``se.valid_value``, which may raise on invalid input.
    """
    first_cell = line[0]
    checked_values = tuple(map(se.valid_value, line[1:]))
    return (first_cell, *checked_values)

# Dispatch table: the row validator applied to data lines of each file type.
LINE_PARSERS = dict((
    (FileType.AVERAGE, __parse_average_line),
    (FileType.STANDARD_ERROR, __parse_standard_error_line),
))

def strain_names(filepath):
    """Retrieve the strain names from the given file.

    The first line is treated as a header and skipped. Every other line is
    split on whitespace. Its second and third columns are taken as strain
    names. When the line has at least six columns, the sixth column is also
    taken as an alias, unless it is empty, "P" or the backslash-N marker.
    Blank lines are ignored.

    Args:
        filepath: path to the strains file (read as UTF-8).

    Returns:
        set: the unique strain names and aliases found in the file.

    Raises:
        IndexError: if a non-blank data line has fewer than three columns.
    """
    ignored_aliases = ("P", "\\N")
    strains = set()
    with open(filepath, encoding="utf8") as strains_file:
        next(strains_file, None)  # skip the header line
        for line in strains_file:
            parts = line.split()
            if not parts:
                # Tolerate blank lines (e.g. a trailing empty line) instead
                # of failing with IndexError on parts[1].
                continue
            strains.add(parts[1].strip())
            strains.add(parts[2].strip())
            # The alias only needs checking once per line, not once per name.
            if len(parts) >= 6:
                alias = parts[5].strip()
                if alias != "" and alias not in ignored_aliases:
                    strains.add(alias)

    return strains

def parse_file(filepath: str, filetype: FileType, strains: list):
    """Parse the given file, stopping at the first invalid line.

    The first line is validated as the header against ``strains``. Each
    following line is split on tabs, its fields are stripped, and it is
    validated by the row parser registered for ``filetype`` in
    ``LINE_PARSERS``.

    Args:
        filepath: path to the UTF-8 encoded, tab-separated file.
        filetype: the kind of file; selects the row parser.
        strains: known strain names that the header is checked against.

    Yields:
        tuple: ``(parsed_line, position)``, where ``position`` is the byte
        offset in the file just past the line that was parsed.

    Raises:
        ParseError: wrapping the first ``DuplicateHeader``,
        ``InvalidCellValue`` or ``InvalidHeaderValue`` raised while parsing.
        Its payload holds the file path, the file type, the byte offset of
        the start of the offending line, its 0-based line number and the
        original error.
    """
    seek_pos = 0
    line_number = 0
    try:
        # newline="" keeps line endings untranslated so that the byte count
        # below matches what is on disk (e.g. "\r\n" in CRLF files). Fields
        # and headings are stripped, so the parsed content is unaffected.
        with open(filepath, encoding="utf-8", newline="") as input_file:
            for line_number, line in enumerate(input_file):
                # Count encoded bytes, not characters, so positions line up
                # with byte offsets in the file (what seek() is given);
                # character counts drift on any non-ASCII content.
                line_end = seek_pos + len(line.encode("utf-8"))
                if line_number == 0:
                    parsed = __parse_header(line, strains)
                else:
                    parsed = LINE_PARSERS[filetype](
                        tuple(field.strip() for field in line.split("\t")))
                yield parsed, line_end
                seek_pos = line_end
    except (DuplicateHeader, InvalidCellValue, InvalidHeaderValue) as err:
        raise ParseError({
            "filepath": filepath,
            "filetype": filetype,
            "position": seek_pos,
            "line_number": line_number,
            "error": err
        }) from err

def parse_errors(filepath: str, filetype: FileType, strains: list,
                 seek_pos: int = 0) -> Generator:
    """Retrieve ALL the parse errors in the file.

    Unlike ``parse_file``, checking does not stop at the first bad line:
    every line is validated and one error record is produced per failing
    line.

    Args:
        filepath: path to the UTF-8 encoded, tab-separated file.
        filetype: the kind of file; selects the row parser.
        strains: known strain names that the header is checked against.
        seek_pos: byte offset to start reading from. When 0, the first line
            is validated as the header. Otherwise reading starts at that
            offset, which should be the start of a line, and every line read
            is treated as a data line.

    Returns:
        A lazy generator of dicts with keys ``filepath``, ``filetype``,
        ``position`` (byte offset of the start of the bad line),
        ``line_number`` (0-based, counted from ``seek_pos``), ``error`` (a
        readable error category) and ``message`` (the exception's args).
    """
    assert seek_pos >= 0, "The seek position must be at least zero (0)"

    def _error_type(error):
        """Return a nicer string representation for the error type."""
        if isinstance(error, DuplicateHeader):
            return "Duplicated Headers"
        if isinstance(error, InvalidCellValue):
            return "Invalid Value"
        if isinstance(error, InvalidHeaderValue):
            return "Invalid Strain"
        return None

    def _errors():
        """Yield an error record for every line that fails validation."""
        position = seek_pos
        # newline="" keeps line endings untranslated so byte counts match
        # the file on disk.
        with open(filepath, encoding="utf-8", newline="") as input_file:
            if position > 0:
                # Seek once, before iterating, so the first line read is the
                # one at `seek_pos` (not the header already buffered at 0).
                input_file.seek(position, 0)
            for line_number, line in enumerate(input_file):
                line_start = position
                # Byte length, so positions stay valid as seek offsets.
                position += len(line.encode("utf-8"))
                try:
                    if seek_pos == 0 and line_number == 0:
                        __parse_header(line, strains)
                    else:
                        LINE_PARSERS[filetype](
                            tuple(field.strip() for field in line.split("\t")))
                except (DuplicateHeader, InvalidCellValue, InvalidHeaderValue) as err:
                    yield {
                        "filepath": filepath,
                        "filetype": filetype,
                        "position": line_start,
                        "line_number": line_number,
                        "error": _error_type(err),
                        "message": err.args
                    }

    return _errors()

def header_errors(line_number, fields, strains):
    """Collect every problem found in a header line.

    Args:
        line_number: the line number of the header, as given by the caller.
        fields: the header's fields.
        strains: known strain names; the fields after the first are checked
            against them.

    Returns:
        tuple: the ``invalid_header`` result, followed by the tuples from
        ``invalid_headings`` and ``duplicate_headings``. Entries are
        presumably ``None`` for checks that found nothing (``collect_errors``
        drops ``None`` entries) — confirm against the headers module.
    """
    checks = (invalid_header(line_number, fields),)
    checks = checks + invalid_headings(line_number, strains, fields[1:])
    return checks + duplicate_headings(line_number, fields)

def empty_value(line_number, column_number, value):
    """Flag an empty cell.

    Args:
        line_number: line number of the cell, as given by the caller.
        column_number: column number of the cell (callers pass 1 for the
            first column).
        value: the cell content.

    Returns:
        An ``InvalidValue`` record when ``value`` is the empty string,
        otherwise ``None``.
    """
    if value != "":
        return None
    # `InvalidValue` is not among this module's top-level imports; import it
    # here, on the error path only, so non-empty cells never depend on it.
    # NOTE(review): assumes quality_control.errors defines InvalidValue with
    # (line, column, value, message) fields matching the call below — confirm.
    from quality_control.errors import InvalidValue  # pylint: disable=import-outside-toplevel
    return InvalidValue(
        line_number, column_number, value, "Empty value for column")

def _row_errors(line_number, fields, check_value):
    """Check one data row.

    The first cell (column 1) must be non-empty. Every remaining cell is
    passed to ``check_value`` along with its column number, counting from 2.
    Returns a tuple of the individual check results.
    """
    first_cell_result = (empty_value(line_number, 1, fields[0]),)
    value_results = tuple(
        check_value(line_number, column, cell)
        for column, cell in enumerate(fields[1:], start=2))
    return first_cell_result + value_results

def average_errors(line_number, fields):
    """Check a data row of an averages file; returns a tuple of results."""
    return _row_errors(line_number, fields, avg.invalid_value)

def se_errors(line_number, fields):
    """Check a data row of a standard-error file; returns a tuple of results."""
    return _row_errors(line_number, fields, se.invalid_value)

def collect_errors(
        filepath: str, filetype: FileType, strains: list,
        count: int = 10) -> tuple:
    """Run checks against file and collect all the errors.

    The first line is checked with ``header_errors``. Every other line is
    checked with ``average_errors`` or ``se_errors``, depending on
    ``filetype``. Line numbers passed to the checkers are 1-based, and each
    line is split on tabs with its fields stripped.

    Args:
        filepath: path to the UTF-8 encoded, tab-separated file.
        filetype: the kind of file; selects the data-row checker.
        strains: known strain names passed to ``header_errors``.
        count: maximum number of errors to return; reading stops once this
            many have been collected. Zero or less means "no limit": the
            whole file is checked and every error is returned.

    Returns:
        tuple: the errors found (``None`` results from the checkers are
        dropped), at most ``count`` of them when ``count > 0``.
    """
    def _process_errors(line_number, line, error_checker_fn, errors):
        """Run one checker on `line` and append its non-None results."""
        errs = error_checker_fn(
            line_number,
            tuple(field.strip() for field in line.split("\t")))
        if errs is None:
            return errors
        # collections.abc.Sequence: the bare `collections.Sequence` alias no
        # longer exists as of Python 3.10.
        if isinstance(errs, Sequence):
            return errors + tuple(error for error in errs if error is not None)
        return errors + (errs,)

    header_checker = partial(header_errors, strains=strains)
    row_checker = (
        average_errors if filetype == FileType.AVERAGE else se_errors)

    errors = tuple()
    with open(filepath, encoding="utf-8") as input_file:
        for line_number, line in enumerate(input_file, start=1):
            checker = header_checker if line_number == 1 else row_checker
            errors = _process_errors(line_number, line, checker, errors)
            if 0 < count <= len(errors):
                break

    # Only truncate when a positive limit was given; slicing with count <= 0
    # would wrongly discard errors in the "no limit" case.
    return errors[:count] if count > 0 else errors

def take(iterable: Iterable, num: int) -> list:
    """Take at most `num` items from `iterable`.

    Args:
        iterable: any iterable; if it is an iterator, only the items
            returned are consumed from it.
        num: maximum number of items to take; zero or less yields [].

    Returns:
        list: the first ``min(num, available)`` items, in order.
    """
    # islice() raises ValueError for a negative stop; clamp so a negative
    # `num` returns an empty list, as an empty range() loop would.
    return list(islice(iterable, max(num, 0)))