diff options
Diffstat (limited to 'quality_control/headers.py')
-rw-r--r-- | quality_control/headers.py | 33 |
1 files changed, 8 insertions, 25 deletions
diff --git a/quality_control/headers.py b/quality_control/headers.py index 3b1e0e6..79d7e43 100644 --- a/quality_control/headers.py +++ b/quality_control/headers.py @@ -4,41 +4,22 @@ from functools import reduce from typing import Union, Tuple, Sequence from quality_control.errors import InvalidValue, DuplicateHeading -from quality_control.errors import DuplicateHeader, InvalidHeaderValue - -def valid_header(strains, headers): - "Return the valid headers with reference to strains or throw an error" - if not bool(headers[1:]): - raise InvalidHeaderValue( - "The header MUST contain at least 2 columns") - invalid_headers = tuple( - header for header in headers[1:] if header not in strains) - if invalid_headers: - raise InvalidHeaderValue( - *(f"'{header}' not a valid strain." for header in invalid_headers)) - - unique_headers = set(headers) - if len(unique_headers) != len(headers): - repeated = ( - (header, headers.count(header)) - for header in unique_headers if headers.count(header) > 1) - raise DuplicateHeader(*( - f"'{header}' is present in the header row {times} times." - for header, times in repeated)) - - return headers - def invalid_header( line_number: int, headers: Sequence[str]) -> Union[InvalidValue, None]: + """Return an `InvalidValue` object if the header row has less than 2 + items.""" if len(headers) < 2: return InvalidValue( line_number, 0, "<TAB>".join(headers), "The header MUST contain at least 2 columns") + return None def invalid_headings( line_number: int, strains: Sequence[str], headings: Sequence[str]) -> Union[Tuple[InvalidValue, ...], None]: + """Return tuple of `InvalidValue` objects for each error found for every + column heading.""" return tuple( InvalidValue( line_number, col, header, f"'{header}' not a valid strain.") @@ -47,13 +28,15 @@ def invalid_headings( def duplicate_headings( line_number: int, headers: Sequence[str]) -> Union[InvalidValue, None]: + """Return a tuple of `DuplicateHeading` objects for each column heading that + is a duplicate of another column heading.""" def __update_columns__(acc, item): if item[1] in acc.keys(): return {**acc, item[1]: acc[item[1]] + (item[0],)} return {**acc, item[1]: (item[0],)} repeated = { heading: columns for heading, columns in - reduce(__update_columns__, enumerate(headers, start=1), dict()).items() + reduce(__update_columns__, enumerate(headers, start=1), {}).items() if len(columns) > 1 } return tuple( |