aboutsummaryrefslogtreecommitdiff
path: root/quality_control/parsing.py
blob: 9fe88f1ade6214a566612e4d241de34b85830693 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Module handling the high-level parsing of the files"""

import csv
from enum import Enum
from functools import reduce
from itertools import islice
from typing import Iterable, Generator

import quality_control.average as avg
import quality_control.standard_error as se
from quality_control.headers import valid_header
from quality_control.errors import (
    ParseError, DuplicateHeader, InvalidCellValue, InvalidHeaderValue)

class FileType(Enum):
    """The kinds of data file this module can parse.

    The member passed to the parsing functions selects, via ``LINE_PARSERS``,
    which validator is applied to the value fields of each data line.
    """
    AVERAGE = 1         # values checked with quality_control.average
    STANDARD_ERROR = 2  # values checked with quality_control.standard_error

def parse_strains(filepath):
    """Parse the tab-separated strains file at `filepath`.

    The first line is the header row: it is split on tabs and each header is
    stripped of surrounding whitespace (including the line terminator) to
    form the field names for the remaining rows.

    Yields one dict per data row mapping header -> cell value, with the
    literal two-character marker ``\\N`` replaced by ``None``.
    """
    # The csv module requires files to be opened with `newline=""` so that it
    # sees the raw line endings; otherwise line breaks embedded in quoted
    # fields are translated and the cell contents are silently altered.
    with open(filepath, encoding="utf-8", newline="") as strains_file:
        fieldnames = [
            header.strip() for header in strains_file.readline().split("\t")]
        reader = csv.DictReader(
            strains_file, fieldnames=fieldnames, delimiter="\t")
        for row in reader:
            yield {
                key: (None if value == "\\N" else value)
                for key, value in row.items()
            }

def __parse_header(line, strains):
    """Split a tab-separated header `line` and validate it.

    Every header cell is stripped of surrounding whitespace. The resulting
    tuple is checked by `valid_header` against the set of known `strains`,
    and whatever `valid_header` returns is handed back unchanged.
    """
    known_strains = set(strains)
    headers = tuple(map(str.strip, line.split("\t")))
    return valid_header(known_strains, headers)

def __validated_fields(line, validate):
    """Return `line` as a tuple: first field as-is, the rest via `validate`."""
    first, rest = line[0], line[1:]
    return (first, *(validate(field) for field in rest))

def __parse_average_line(line):
    """Validate the value fields of one line from an AVERAGE file."""
    return __validated_fields(line, avg.valid_value)

def __parse_standard_error_line(line):
    """Validate the value fields of one line from a STANDARD_ERROR file."""
    return __validated_fields(line, se.valid_value)

# Dispatch table: which line parser handles each kind of file.
LINE_PARSERS = dict((
    (FileType.AVERAGE, __parse_average_line),
    (FileType.STANDARD_ERROR, __parse_standard_error_line),
))

def strain_names(strains):
    """Retrieve all the names of the strains, as a tuple.

    For each strain, in order, its "Name" and then its "Name2" entry is
    collected; entries that are ``None`` or the empty string are skipped.
    Raises KeyError if a strain mapping lacks either key.
    """
    # Build the tuple in a single pass: accumulating it with tuple
    # concatenation would copy the accumulator once per strain (quadratic).
    return tuple(
        name
        for strain in strains
        for name in (strain["Name"], strain["Name2"])
        if name is not None and name != "")

def parse_file(filepath: str, filetype: FileType, strains: list):
    """Lazily parse the tab-separated file at `filepath`.

    Line 0 is validated as the header row against `strains`. Every later
    line is split on tabs, its fields are stripped, and the result is handed
    to the `LINE_PARSERS` entry for `filetype`.

    Yields ``(parsed, position)`` pairs, where `position` is the running
    total of ``len(line)`` over the decoded lines read so far. That is, the
    position just past the current line.

    Raises ParseError (chained from the original exception) when a line
    fails with DuplicateHeader, InvalidCellValue or InvalidHeaderValue. Its
    "position" is where the offending line starts.
    """
    offset = 0
    try:
        with open(filepath, encoding="utf-8") as input_file:
            for line_number, line in enumerate(input_file):
                next_offset = offset + len(line)
                if line_number > 0:
                    fields = tuple(field.strip() for field in line.split("\t"))
                    parsed = LINE_PARSERS[filetype](fields)
                else:
                    parsed = __parse_header(line, strains)
                yield parsed, next_offset
                offset = next_offset
    except (DuplicateHeader, InvalidCellValue, InvalidHeaderValue) as err:
        raise ParseError({
            "filepath": filepath,
            "filetype": filetype,
            "position": offset,
            "line_number": line_number,
            "error": err
        }) from err

def parse_errors(filepath: str, filetype: FileType, strains: list,
                 seek_pos: int = 0) -> Generator:
    """Retrieve ALL the parse errors in the file at `filepath`.

    Every line that fails validation produces one error dict, and checking
    continues with the next line.

    `seek_pos` lets checking resume part-way through the file. It uses the
    same measure as the positions yielded by `parse_file`: a running total
    of ``len(line)`` over the decoded lines. Lines that start before
    `seek_pos` are skipped, so the header line is only validated when
    `seek_pos` is 0. Line numbers and positions in the results are always
    counted from the start of the file.

    Returns a lazy generator of dicts with the keys "filepath", "filetype",
    "position" (where the bad line starts), "line_number", "error" (a
    readable error category) and "message" (the exception's ``args``).
    """
    assert seek_pos >= 0, "The seek position must be at least zero (0)"

    def __error_type(error):
        """Return a nicer string representation for the error type."""
        if isinstance(error, DuplicateHeader):
            return "Duplicated Headers"
        if isinstance(error, InvalidCellValue):
            return "Invalid Value"
        if isinstance(error, InvalidHeaderValue):
            return "Invalid Strain"

    def __errors(filepath, filetype, strains, seek_pos):
        """Yield an error dict for each line at or after `seek_pos` that fails."""
        # The positions count decoded characters (after newline translation),
        # not bytes, so they are not valid `seek()` targets on a text file.
        # Seeking while iterating the file also discards already-read data.
        # Instead, walk the file and skip the lines before `seek_pos`, which
        # also keeps `line_number` correct.
        line_start = 0
        with open(filepath, encoding="utf-8") as input_file:
            for line_number, line in enumerate(input_file):
                line_end = line_start + len(line)
                if line_start >= seek_pos:
                    try:
                        if line_number == 0:
                            __parse_header(line, strains)
                        else:
                            LINE_PARSERS[filetype](tuple(
                                field.strip() for field in line.split("\t")))
                    except (DuplicateHeader, InvalidCellValue,
                            InvalidHeaderValue) as err:
                        yield {
                            "filepath": filepath,
                            "filetype": filetype,
                            "position": line_start,
                            "line_number": line_number,
                            "error": __error_type(err),
                            "message": err.args
                        }
                line_start = line_end

    return __errors(filepath, filetype, strains, seek_pos)

def take(iterable: Iterable, num: int) -> list:
    """Take at most `num` items from `iterable`."""
    iterator = iter(iterable)
    items = []
    try:
        for i in range(0, num):
            items.append(next(iterator))

        return items
    except StopIteration:
        return items