From 42b43f8d46fe0c25703de914a687127726ece35e Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Mon, 15 Jan 2024 17:49:14 +0300 Subject: Extract common functional tools to separate package. --- functional_tools/__init__.py | 33 +++++++++++++++++++++++++++++++++ quality_control/parsing.py | 14 ++------------ r_qtl/r_qtl2.py | 18 +++++++----------- 3 files changed, 42 insertions(+), 23 deletions(-) create mode 100644 functional_tools/__init__.py diff --git a/functional_tools/__init__.py b/functional_tools/__init__.py new file mode 100644 index 0000000..057bd9a --- /dev/null +++ b/functional_tools/__init__.py @@ -0,0 +1,33 @@ +"""Tools to help with a more functional way of doing things.""" +from typing import Iterable +from functools import reduce + +def take(iterable: Iterable, num: int) -> list: + """Take at most `num` items from `iterable`.""" + iterator = iter(iterable) + items = [] + try: + for i in range(0, num): # pylint: disable=[unused-variable] + items.append(next(iterator)) + + return items + except StopIteration: + return items + +def chain(value, *functions): + """ + Flatten nested expressions + + Inspired by, and approximates, Clojure's `->`. + + Useful to rewrite nested expressions like func3(a, b, func2(c, func1(d, e))) + into arguably flatter expressions like: + chain( + d, + partial(func1, e=val1), + partial(func2, c=val2), + partial(func3, a=val3, b=val3)) + + This can probably be improved. 
+ """ + return reduce(lambda result, func: func(result), functions, value) diff --git a/quality_control/parsing.py b/quality_control/parsing.py index 5fc5f62..5b21716 100644 --- a/quality_control/parsing.py +++ b/quality_control/parsing.py @@ -12,6 +12,8 @@ from quality_control.errors import ( from quality_control.headers import ( invalid_header, invalid_headings, duplicate_headings) +from functional_tools import take + class FileType(Enum): """Enumerate the expected file types""" AVERAGE = 1 @@ -121,15 +123,3 @@ def collect_errors( if update_progress: update_progress(line_number, line) - -def take(iterable: Iterable, num: int) -> list: - """Take at most `num` items from `iterable`.""" - iterator = iter(iterable) - items = [] - try: - for i in range(0, num): # pylint: disable=[unused-variable] - items.append(next(iterator)) - - return items - except StopIteration: - return items diff --git a/r_qtl/r_qtl2.py b/r_qtl/r_qtl2.py index b2a7acf..b688404 100644 --- a/r_qtl/r_qtl2.py +++ b/r_qtl/r_qtl2.py @@ -8,14 +8,10 @@ from typing import Iterator, Iterable, Callable import yaml -from quality_control.parsing import take +from functional_tools import take, chain from r_qtl.errors import InvalidFormat -def thread_op(value, *functions): - """Thread the `value` through the sequence of `functions`.""" - return reduce(lambda result, func: func(result), functions, value) - def control_data(zfile: ZipFile) -> dict: """Retrieve the control file from the zip file info.""" files = tuple(filename @@ -140,7 +136,7 @@ def make_process_data_geno(cdata) -> tuple[ return val def __non_transposed__(row: dict) -> dict: return { - key: thread_op(value, replace_genotype_codes, replace_na_strings) + key: chain(value, replace_genotype_codes, replace_na_strings) for key,value in row.items() } def __transposed__(id_key: str, @@ -149,7 +145,7 @@ def make_process_data_geno(cdata) -> tuple[ return tuple( dict(zip( [id_key, vals[0]], - (thread_op(item, replace_genotype_codes, replace_na_strings) + 
(chain(item, replace_genotype_codes, replace_na_strings) for item in items))) for items in zip(ids, vals[1:])) return (__non_transposed__, __transposed__) @@ -179,7 +175,7 @@ def make_process_data_covar(cdata) -> tuple[ rep_cross_info = partial(replace_cross_info, cdata=cdata) def non_transposed(row: dict) -> dict: return { - key: thread_op(value, rep_sex_info, rep_cross_info) + key: chain(value, rep_sex_info, rep_cross_info) for key,value in row.items() } def transposed(id_key: str, @@ -188,7 +184,7 @@ def make_process_data_covar(cdata) -> tuple[ return tuple( dict(zip( [id_key, vals[0]], - (thread_op(item, rep_sex_info, rep_cross_info) + (chain(item, rep_sex_info, rep_cross_info) for item in items))) for items in zip(ids, vals[1:])) return (non_transposed, transposed) @@ -245,7 +241,7 @@ def cross_information(zfile: ZipFile, cdata: dict) -> Iterator[dict]: new_cdata, *make_process_data_covar(cdata)): yield { - key: thread_op(value, partial(replace_cross_info, cdata=cdata)) + key: chain(value, partial(replace_cross_info, cdata=cdata)) for key, value in row.items() if key not in sex_fields} def sex_information(zfile: ZipFile, cdata: dict) -> Iterator[dict]: @@ -263,7 +259,7 @@ def sex_information(zfile: ZipFile, cdata: dict) -> Iterator[dict]: new_cdata, *make_process_data_covar(cdata)): yield { - key: thread_op(value, partial(replace_sex_info, cdata=cdata)) + key: chain(value, partial(replace_sex_info, cdata=cdata)) for key, value in row.items() if key not in ci_fields} def validate_bundle(zfile: ZipFile): -- cgit v1.2.3