aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-01-15 17:49:14 +0300
committerFrederick Muriuki Muriithi2024-01-15 17:49:14 +0300
commit42b43f8d46fe0c25703de914a687127726ece35e (patch)
treef30490a689be720969a1a57cd3bd92e9abf68739
parentef6da7313f96390b9fecb126f9b7e9beb1afe034 (diff)
downloadgn-uploader-42b43f8d46fe0c25703de914a687127726ece35e.tar.gz
Extract common functional tools to separate package.
-rw-r--r--functional_tools/__init__.py33
-rw-r--r--quality_control/parsing.py14
-rw-r--r--r_qtl/r_qtl2.py18
3 files changed, 42 insertions, 23 deletions
diff --git a/functional_tools/__init__.py b/functional_tools/__init__.py
new file mode 100644
index 0000000..057bd9a
--- /dev/null
+++ b/functional_tools/__init__.py
@@ -0,0 +1,33 @@
+"""Tools to help with a more functional way of doing things."""
+from typing import Iterable
+from functools import reduce
+
+def take(iterable: Iterable, num: int) -> list:
+ """Take at most `num` items from `iterable`."""
+ iterator = iter(iterable)
+ items = []
+ try:
+ for i in range(0, num): # pylint: disable=[unused-variable]
+ items.append(next(iterator))
+
+ return items
+ except StopIteration:
+ return items
+
+def chain(value, *functions):
+ """
+ Flatten nested expressions
+
+ Inspired by, and approximates, Clojure's `->`.
+
+ Useful to rewrite nested expressions like func3(a, b, func2(c, func1(d e)))
+ into arguably flatter expressions like:
+ chain(
+ d,
+ partial(func1, e=val1),
+ partial(func2, c=val2),
+ partial(func3, a=val3, b=val3))
+
+ This can probably be improved.
+ """
+ return reduce(lambda result, func: func(result), functions, value)
diff --git a/quality_control/parsing.py b/quality_control/parsing.py
index 5fc5f62..5b21716 100644
--- a/quality_control/parsing.py
+++ b/quality_control/parsing.py
@@ -12,6 +12,8 @@ from quality_control.errors import (
from quality_control.headers import (
invalid_header, invalid_headings, duplicate_headings)
+from functional_tools import take
+
class FileType(Enum):
"""Enumerate the expected file types"""
AVERAGE = 1
@@ -121,15 +123,3 @@ def collect_errors(
if update_progress:
update_progress(line_number, line)
-
-def take(iterable: Iterable, num: int) -> list:
- """Take at most `num` items from `iterable`."""
- iterator = iter(iterable)
- items = []
- try:
- for i in range(0, num): # pylint: disable=[unused-variable]
- items.append(next(iterator))
-
- return items
- except StopIteration:
- return items
diff --git a/r_qtl/r_qtl2.py b/r_qtl/r_qtl2.py
index b2a7acf..b688404 100644
--- a/r_qtl/r_qtl2.py
+++ b/r_qtl/r_qtl2.py
@@ -8,14 +8,10 @@ from typing import Iterator, Iterable, Callable
import yaml
-from quality_control.parsing import take
+from functional_tools import take, chain
from r_qtl.errors import InvalidFormat
-def thread_op(value, *functions):
- """Thread the `value` through the sequence of `functions`."""
- return reduce(lambda result, func: func(result), functions, value)
-
def control_data(zfile: ZipFile) -> dict:
"""Retrieve the control file from the zip file info."""
files = tuple(filename
@@ -140,7 +136,7 @@ def make_process_data_geno(cdata) -> tuple[
return val
def __non_transposed__(row: dict) -> dict:
return {
- key: thread_op(value, replace_genotype_codes, replace_na_strings)
+ key: chain(value, replace_genotype_codes, replace_na_strings)
for key,value in row.items()
}
def __transposed__(id_key: str,
@@ -149,7 +145,7 @@ def make_process_data_geno(cdata) -> tuple[
return tuple(
dict(zip(
[id_key, vals[0]],
- (thread_op(item, replace_genotype_codes, replace_na_strings)
+ (chain(item, replace_genotype_codes, replace_na_strings)
for item in items)))
for items in zip(ids, vals[1:]))
return (__non_transposed__, __transposed__)
@@ -179,7 +175,7 @@ def make_process_data_covar(cdata) -> tuple[
rep_cross_info = partial(replace_cross_info, cdata=cdata)
def non_transposed(row: dict) -> dict:
return {
- key: thread_op(value, rep_sex_info, rep_cross_info)
+ key: chain(value, rep_sex_info, rep_cross_info)
for key,value in row.items()
}
def transposed(id_key: str,
@@ -188,7 +184,7 @@ def make_process_data_covar(cdata) -> tuple[
return tuple(
dict(zip(
[id_key, vals[0]],
- (thread_op(item, rep_sex_info, rep_cross_info)
+ (chain(item, rep_sex_info, rep_cross_info)
for item in items)))
for items in zip(ids, vals[1:]))
return (non_transposed, transposed)
@@ -245,7 +241,7 @@ def cross_information(zfile: ZipFile, cdata: dict) -> Iterator[dict]:
new_cdata,
*make_process_data_covar(cdata)):
yield {
- key: thread_op(value, partial(replace_cross_info, cdata=cdata))
+ key: chain(value, partial(replace_cross_info, cdata=cdata))
for key, value in row.items() if key not in sex_fields}
def sex_information(zfile: ZipFile, cdata: dict) -> Iterator[dict]:
@@ -263,7 +259,7 @@ def sex_information(zfile: ZipFile, cdata: dict) -> Iterator[dict]:
new_cdata,
*make_process_data_covar(cdata)):
yield {
- key: thread_op(value, partial(replace_sex_info, cdata=cdata))
+ key: chain(value, partial(replace_sex_info, cdata=cdata))
for key, value in row.items() if key not in ci_fields}
def validate_bundle(zfile: ZipFile):