about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-01-15 17:49:14 +0300
committerFrederick Muriuki Muriithi2024-01-15 17:49:14 +0300
commit42b43f8d46fe0c25703de914a687127726ece35e (patch)
treef30490a689be720969a1a57cd3bd92e9abf68739
parentef6da7313f96390b9fecb126f9b7e9beb1afe034 (diff)
downloadgn-uploader-42b43f8d46fe0c25703de914a687127726ece35e.tar.gz
Extract common functional tools to separate package.
-rw-r--r--functional_tools/__init__.py33
-rw-r--r--quality_control/parsing.py14
-rw-r--r--r_qtl/r_qtl2.py18
3 files changed, 42 insertions, 23 deletions
diff --git a/functional_tools/__init__.py b/functional_tools/__init__.py
new file mode 100644
index 0000000..057bd9a
--- /dev/null
+++ b/functional_tools/__init__.py
@@ -0,0 +1,33 @@
+"""Tools to help with a more functional way of doing things."""
+from typing import Iterable
+from functools import reduce
+
+def take(iterable: Iterable, num: int) -> list:
+    """Take at most `num` items from `iterable`."""
+    iterator = iter(iterable)
+    items = []
+    try:
+        for i in range(0, num): # pylint: disable=[unused-variable]
+            items.append(next(iterator))
+
+        return items
+    except StopIteration:
+        return items
+
+def chain(value, *functions):
+    """
+    Flatten nested expressions
+
+    Inspired by, and approximates, Clojure's `->`.
+
+    Useful to rewrite nested expressions like func3(a, b, func2(c, func1(d e)))
+    into arguably flatter expressions like:
+    chain(
+      d,
+      partial(func1, e=val1),
+      partial(func2, c=val2),
+      partial(func3, a=val3, b=val3))
+
+    This can probably be improved.
+    """
+    return reduce(lambda result, func: func(result), functions, value)
diff --git a/quality_control/parsing.py b/quality_control/parsing.py
index 5fc5f62..5b21716 100644
--- a/quality_control/parsing.py
+++ b/quality_control/parsing.py
@@ -12,6 +12,8 @@ from quality_control.errors import (
 from quality_control.headers import (
     invalid_header, invalid_headings, duplicate_headings)
 
+from functional_tools import take
+
 class FileType(Enum):
     """Enumerate the expected file types"""
     AVERAGE = 1
@@ -121,15 +123,3 @@ def collect_errors(
 
             if update_progress:
                 update_progress(line_number, line)
-
-def take(iterable: Iterable, num: int) -> list:
-    """Take at most `num` items from `iterable`."""
-    iterator = iter(iterable)
-    items = []
-    try:
-        for i in range(0, num): # pylint: disable=[unused-variable]
-            items.append(next(iterator))
-
-        return items
-    except StopIteration:
-        return items
diff --git a/r_qtl/r_qtl2.py b/r_qtl/r_qtl2.py
index b2a7acf..b688404 100644
--- a/r_qtl/r_qtl2.py
+++ b/r_qtl/r_qtl2.py
@@ -8,14 +8,10 @@ from typing import Iterator, Iterable, Callable
 
 import yaml
 
-from quality_control.parsing import take
+from functional_tools import take, chain
 
 from r_qtl.errors import InvalidFormat
 
-def thread_op(value, *functions):
-    """Thread the `value` through the sequence of `functions`."""
-    return reduce(lambda result, func: func(result), functions, value)
-
 def control_data(zfile: ZipFile) -> dict:
     """Retrieve the control file from the zip file info."""
     files = tuple(filename
@@ -140,7 +136,7 @@ def make_process_data_geno(cdata) -> tuple[
         return val
     def __non_transposed__(row: dict) -> dict:
         return {
-            key: thread_op(value, replace_genotype_codes, replace_na_strings)
+            key: chain(value, replace_genotype_codes, replace_na_strings)
             for key,value in row.items()
         }
     def __transposed__(id_key: str,
@@ -149,7 +145,7 @@ def make_process_data_geno(cdata) -> tuple[
         return tuple(
             dict(zip(
                 [id_key, vals[0]],
-                (thread_op(item, replace_genotype_codes, replace_na_strings)
+                (chain(item, replace_genotype_codes, replace_na_strings)
                  for item in items)))
             for items in zip(ids, vals[1:]))
     return (__non_transposed__, __transposed__)
@@ -179,7 +175,7 @@ def make_process_data_covar(cdata) -> tuple[
     rep_cross_info = partial(replace_cross_info, cdata=cdata)
     def non_transposed(row: dict) -> dict:
         return {
-            key: thread_op(value, rep_sex_info, rep_cross_info)
+            key: chain(value, rep_sex_info, rep_cross_info)
             for key,value in row.items()
         }
     def transposed(id_key: str,
@@ -188,7 +184,7 @@ def make_process_data_covar(cdata) -> tuple[
         return tuple(
             dict(zip(
                 [id_key, vals[0]],
-                (thread_op(item, rep_sex_info, rep_cross_info)
+                (chain(item, rep_sex_info, rep_cross_info)
                  for item in items)))
             for items in zip(ids, vals[1:]))
     return (non_transposed, transposed)
@@ -245,7 +241,7 @@ def cross_information(zfile: ZipFile, cdata: dict) -> Iterator[dict]:
                          new_cdata,
                          *make_process_data_covar(cdata)):
         yield {
-            key: thread_op(value, partial(replace_cross_info, cdata=cdata))
+            key: chain(value, partial(replace_cross_info, cdata=cdata))
             for key, value in row.items() if key not in sex_fields}
 
 def sex_information(zfile: ZipFile, cdata: dict) -> Iterator[dict]:
@@ -263,7 +259,7 @@ def sex_information(zfile: ZipFile, cdata: dict) -> Iterator[dict]:
                          new_cdata,
                          *make_process_data_covar(cdata)):
         yield {
-            key: thread_op(value, partial(replace_sex_info, cdata=cdata))
+            key: chain(value, partial(replace_sex_info, cdata=cdata))
             for key, value in row.items() if key not in ci_fields}
 
 def validate_bundle(zfile: ZipFile):