1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
|
"""Module handling the high-level parsing of the files"""
import os
import collections
import collections.abc
import itertools
from enum import Enum
from functools import partial
from typing import Union, Iterable, Generator, Callable

import quality_control.average as avg
import quality_control.standard_error as se
from quality_control.errors import InvalidValue
from quality_control.headers import (
    invalid_header, invalid_headings, duplicate_headings)
class FileType(Enum):
    """The kinds of data file that can be checked.

    The chosen member decides which per-line checker is applied to the
    rows that follow the header.
    """
    AVERAGE = 1         # a file of averages
    STANDARD_ERROR = 2  # a file of standard errors
def strain_names(filepath):
    """Retrieve the strain names from the given file.

    The first row is treated as a header and skipped. Each remaining row
    is split on whitespace. Its 2nd and 3rd columns are collected as
    strain names. Its 6th column, when present, is collected as an alias
    unless it is "P" or "\\N". Blank rows are ignored.

    :param filepath: path to the strains file (read as UTF-8).
    :return: a set containing every strain name and alias found.
    """
    strains = set()
    with open(filepath, encoding="utf-8") as strains_file:
        next(strains_file, None)  # skip the header row (no-op on an empty file)
        for line in strains_file:
            parts = line.split()
            if not parts:
                # e.g. a trailing blank line: indexing parts[1] would raise
                # IndexError, so skip it instead.
                continue
            # split() yields non-empty, already-stripped tokens.
            strains.update((parts[1], parts[2]))
            if len(parts) >= 6 and parts[5] not in ("P", "\\N"):
                strains.add(parts[5])
    return strains
def header_errors(line_number, fields, strains):
    """Collect every problem found in the header row.

    Combines the single result of ``invalid_header`` (for the whole row)
    with the results of ``invalid_headings`` (for every field after the
    first, checked against `strains`) and ``duplicate_headings``.
    The latter two are concatenated with ``+``, so they are expected to
    return tuples.
    """
    whole_row_error = invalid_header(line_number, fields)
    heading_errors = invalid_headings(line_number, strains, fields[1:])
    duplicate_errors = duplicate_headings(line_number, fields)
    return (whole_row_error,) + heading_errors + duplicate_errors
def empty_value(line_number, column_number, value):
    """Flag a field whose value is the empty string.

    Only the exact empty string is flagged; whitespace is not stripped here.

    :return: an ``InvalidValue`` describing the empty field, or ``None``
        when `value` is non-empty.
    """
    if value != "":
        return None
    return InvalidValue(
        line_number, column_number, value, "Empty value for column")
def _data_row_errors(line_number, fields, value_checker):
    """Gather the errors for one data row, checking values with `value_checker`.

    The first field is only checked for being empty. Every later field is
    passed to ``value_checker(line_number, column_number, value)`` with
    1-based column numbers, so the first value column is 2. Entries of
    the returned tuple may be ``None`` where a check found nothing
    (``empty_value`` returns ``None`` for a non-empty field).
    """
    return (
        (empty_value(line_number, 1, fields[0]),) +
        tuple(
            value_checker(line_number, column_number, value)
            for column_number, value in enumerate(fields[1:], start=2)))


def average_errors(line_number, fields):
    """Gather all errors for a line in an averages file."""
    return _data_row_errors(line_number, fields, avg.invalid_value)


def se_errors(line_number, fields):
    """Gather all errors for a line in a standard-errors file."""
    return _data_row_errors(line_number, fields, se.invalid_value)
def collect_errors(
        filepath: str, filetype: FileType, strains: list,
        update_progress: Union[Callable, None] = None) -> Generator:
    """Run checks against file and collect all the errors.

    Lazily yields the errors found in the tab-separated file at
    `filepath`. The first line is checked as the header (against
    `strains`). Every following line is checked as data with
    ``average_errors`` when `filetype` is ``FileType.AVERAGE`` and with
    ``se_errors`` otherwise.

    :param filepath: path of the file to check (read as UTF-8).
    :param filetype: selects the checker used for the data lines.
    :param strains: known strain names, forwarded to ``header_errors``.
    :param update_progress: optional callback, called as
        ``update_progress(line_number, line)`` after each line's errors
        have been yielded.
    :return: a generator of error objects; ``None`` results from the
        individual checks are dropped.
    """
    # Both checkers are fixed for the whole file, so choose them once.
    header_checker = partial(header_errors, strains=strains)
    data_checker = (
        average_errors if filetype == FileType.AVERAGE else se_errors)

    def _line_errors(line_number, line, error_checker_fn):
        """Run `error_checker_fn` on one line; return its errors as a tuple."""
        errs = error_checker_fn(
            line_number,
            tuple(field.strip() for field in line.split("\t")))
        if errs is None:
            return tuple()
        # A checker may return a sequence of (possibly None) results or a
        # single error. `collections.Sequence` no longer exists on
        # Python >= 3.10; the ABC lives in `collections.abc`.
        if isinstance(errs, collections.abc.Sequence):
            return tuple(error for error in errs if error is not None)
        return (errs,)

    with open(filepath, encoding="utf-8") as input_file:
        for line_number, line in enumerate(input_file, start=1):
            checker = header_checker if line_number == 1 else data_checker
            yield from _line_errors(line_number, line, checker)
            if update_progress:
                update_progress(line_number, line)
def take(iterable: Iterable, num: int) -> list:
    """Take at most `num` items from `iterable`.

    If `iterable` is an iterator, exactly the returned items are consumed
    from it; nothing beyond them is read.

    :param iterable: any iterable; it is consumed lazily.
    :param num: maximum number of items to take; zero or a negative value
        gives an empty list.
    :return: a list of at most `num` items.
    """
    # islice raises ValueError for a negative stop, so clamp to 0 to keep
    # returning an empty list in that case.
    return list(itertools.islice(iterable, max(num, 0)))
|