about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--r_qtl/r_qtl2.py89
-rw-r--r--tests/r_qtl/test_files/test_pheno.zipbin485 -> 503 bytes
-rw-r--r--tests/r_qtl/test_files/test_pheno_transposed.zipbin536 -> 557 bytes
-rw-r--r--tests/r_qtl/test_r_qtl2_pheno.py8
4 files changed, 57 insertions, 40 deletions
diff --git a/r_qtl/r_qtl2.py b/r_qtl/r_qtl2.py
index b688404..13ac355 100644
--- a/r_qtl/r_qtl2.py
+++ b/r_qtl/r_qtl2.py
@@ -4,7 +4,7 @@ import csv
 import json
 from zipfile import ZipFile
 from functools import reduce, partial
-from typing import Iterator, Iterable, Callable
+from typing import Iterator, Iterable, Callable, Optional
 
 import yaml
 
@@ -28,6 +28,13 @@ def control_data(zfile: ZipFile) -> dict:
             if files[0].endswith(".json")
             else yaml.safe_load(zfile.read(files[0])))
 
+def replace_na_strings(cdata, val):
+    """Replace values indicated in `na.strings` with `None`."""
+    nastrings = cdata.get("na.strings")
+    if bool(nastrings):
+        return (None if val in nastrings else val)
+    return val
+
 def with_non_transposed(zfile: ZipFile,
                         member_key: str,
                         cdata: dict,
@@ -46,24 +53,27 @@ def with_non_transposed(zfile: ZipFile,
 
     sep = cdata.get("sep", ",")
     with zfile.open(cdata[member_key]) as innerfile:
-        wrapped_file = io.TextIOWrapper(innerfile)
-        firstrow = tuple(
-            field.strip() for field in
-            next(filter(not_comment_line, wrapped_file)).strip().split(sep))
-        id_key = firstrow[0]
-        wrapped_file.seek(0)
-        reader = csv.DictReader(filter(not_comment_line, wrapped_file),
-                                delimiter=sep)
-        for row in reader:
-            processed = process_value(row)
-            yield {
-                "id": processed[id_key],
-                **{
-                    key: value
-                    for key, value in processed.items()
-                    if key != id_key
+        try:
+            wrapped_file = io.TextIOWrapper(innerfile)
+            firstrow = tuple(
+                field.strip() for field in
+                next(filter(not_comment_line, wrapped_file)).strip().split(sep))
+            id_key = firstrow[0]
+            wrapped_file.seek(0)
+            reader = csv.DictReader(filter(not_comment_line, wrapped_file),
+                                    delimiter=sep)
+            for row in reader:
+                processed = process_value(row)
+                yield {
+                    "id": processed[id_key],
+                    **{
+                        key: value
+                        for key, value in processed.items()
+                        if key != id_key
+                    }
                 }
-            }
+        except StopIteration as exc:
+            raise InvalidFormat("The file has no rows!") from exc
 
 def __make_organise_by_id__(id_key):
     """Return a function to use with `reduce` to organise values by some
@@ -129,14 +139,10 @@ def make_process_data_geno(cdata) -> tuple[
     def replace_genotype_codes(val):
         return cdata["genotypes"].get(val, val)
 
-    def replace_na_strings(val):
-        nastrings = cdata.get("na.strings")
-        if bool(nastrings):
-            return (None if val in nastrings else val)
-        return val
     def __non_transposed__(row: dict) -> dict:
         return {
-            key: chain(value, replace_genotype_codes, replace_na_strings)
+            key: chain(value, replace_genotype_codes,
+                       partial(replace_na_strings, cdata))
             for key,value in row.items()
         }
     def __transposed__(id_key: str,
@@ -145,7 +151,7 @@ def make_process_data_geno(cdata) -> tuple[
         return tuple(
             dict(zip(
                 [id_key, vals[0]],
-                (chain(item, replace_genotype_codes, replace_na_strings)
+                (chain(item, replace_genotype_codes, partial(replace_na_strings, cdata))
                  for item in items)))
             for items in zip(ids, vals[1:]))
     return (__non_transposed__, __transposed__)
@@ -189,22 +195,33 @@ def make_process_data_covar(cdata) -> tuple[
             for items in zip(ids, vals[1:]))
     return (non_transposed, transposed)
 
-def __default_process_value_transposed__(
-        id_key: str,
-        ids: tuple[str, ...],
-        vals: tuple[str, ...]) -> tuple[dict, ...]:
-    """Default values processor for transposed files."""
-    return tuple(
-        dict(zip([id_key, vals[0]], items)) for items in zip(ids, vals[1:]))
-
 def file_data(zfile: ZipFile,
               member_key: str,
               cdata: dict,
-              process_value: Callable[[dict], dict] = lambda val: val,
-              process_transposed_value: Callable[
+              process_value: Optional[Callable[[dict], dict]] = None,
+              process_transposed_value: Optional[Callable[
                   [str, tuple[str, ...], tuple[str, ...]],
-                  tuple[dict, ...]] = __default_process_value_transposed__) -> Iterator[dict]:
+                  tuple[dict, ...]]] = None) -> Iterator[dict]:
     """Load data from files in R/qtl2 zip bundle."""
+    def __default_process_value_non_transposed__(val: dict) -> dict:
+        return {
+            key: replace_na_strings(cdata, value) for key,value in val.items()
+        }
+
+    def __default_process_value_transposed__(
+            id_key: str,
+            ids: tuple[str, ...],
+            vals: tuple[str, ...]) -> tuple[dict, ...]:
+        """Default values processor for transposed files."""
+        return tuple(
+            dict(zip([id_key, replace_na_strings(cdata, vals[0])], items))
+            for items in zip(
+                    ids, (replace_na_strings(cdata, val) for val in vals[1:])))
+
+    process_value = process_value or __default_process_value_non_transposed__
+    process_transposed_value = (
+        process_transposed_value or __default_process_value_transposed__)
+
     try:
         if isinstance(cdata[member_key], list):
             for row in (line for lines in
diff --git a/tests/r_qtl/test_files/test_pheno.zip b/tests/r_qtl/test_files/test_pheno.zip
index 5c709e7..ba9bbb0 100644
--- a/tests/r_qtl/test_files/test_pheno.zip
+++ b/tests/r_qtl/test_files/test_pheno.zip
Binary files differdiff --git a/tests/r_qtl/test_files/test_pheno_transposed.zip b/tests/r_qtl/test_files/test_pheno_transposed.zip
index 9bff030..e6a87aa 100644
--- a/tests/r_qtl/test_files/test_pheno_transposed.zip
+++ b/tests/r_qtl/test_files/test_pheno_transposed.zip
Binary files differdiff --git a/tests/r_qtl/test_r_qtl2_pheno.py b/tests/r_qtl/test_r_qtl2_pheno.py
index a7de675..c7c0c86 100644
--- a/tests/r_qtl/test_r_qtl2_pheno.py
+++ b/tests/r_qtl/test_r_qtl2_pheno.py
@@ -13,14 +13,14 @@ from r_qtl import r_qtl2 as rqtl2
       ({"id": "1", "liver": "61.92", "spleen": "153.16"},
        {"id": "2", "liver": "88.33", "spleen": "178.58"},
        {"id": "3", "liver": "58", "spleen": "131.91"},
-       {"id": "4", "liver": "78.06", "spleen": "126.13"},
-       {"id": "5", "liver": "65.31", "spleen": "181.05"})),
+       {"id": "4", "liver": "78.06", "spleen": None},
+       {"id": "5", "liver": None, "spleen": "181.05"})),
      ("tests/r_qtl/test_files/test_pheno_transposed.zip",
       ({"id": "1", "liver": "61.92", "spleen": "153.16"},
        {"id": "2", "liver": "88.33", "spleen": "178.58"},
        {"id": "3", "liver": "58", "spleen": "131.91"},
-       {"id": "4", "liver": "78.06", "spleen": "126.13"},
-       {"id": "5", "liver": "65.31", "spleen": "181.05"}))))
+       {"id": "4", "liver": "78.06", "spleen": None},
+       {"id": "5", "liver": None, "spleen": "181.05"}))))
 def test_parse_pheno_files(filepath, expected):
     """Test parsing of 'pheno' files from the R/qtl2 bundle.