about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--r_qtl/r_qtl2.py73
-rw-r--r--tests/r_qtl/test_r_qtl2_gmap.py4
2 files changed, 43 insertions, 34 deletions
diff --git a/r_qtl/r_qtl2.py b/r_qtl/r_qtl2.py
index 16bb652..79a0656 100644
--- a/r_qtl/r_qtl2.py
+++ b/r_qtl/r_qtl2.py
@@ -63,6 +63,33 @@ def __batch_of_n__(iterable: Iterable, num):
             break
         yield items
 
+def with_transposed(zfile: ZipFile,
+                    member_key: str,
+                    cdata: dict,
+                    merge_function: Callable[
+                        [str, tuple[str, ...], tuple[str, ...]],
+                        tuple[dict, ...]]) -> Iterator[dict]:
+    """Abstracts away common file-opening for transposed R/qtl2 files."""
+    with zfile.open(cdata[member_key]) as innerfile:
+        lines = (tuple(field.strip() for field in
+                       line.strip().split(cdata.get("sep", ",")))
+                 for line in
+                 filter(lambda line: not line.startswith("#"),
+                        io.TextIOWrapper(innerfile)))
+        try:
+            id_line = next(lines)
+            id_key, headers = id_line[0], id_line[1:]
+            for _key, row in reduce(# type: ignore[var-annotated]
+                    __make_organise_by_id__(id_key),
+                    (row
+                     for batch in __batch_of_n__(lines, 300)
+                     for line in batch
+                     for row in merge_function(id_key, headers, line)),
+                    {}).items():
+                yield row
+        except StopIteration:
+            pass
+
 def genotype_data(zfile: ZipFile, cdata: dict) -> Iterator[dict]:
     """Load the genotype file, making use of the control data."""
     def replace_genotype_codes(val):
@@ -84,6 +111,7 @@ def genotype_data(zfile: ZipFile, cdata: dict) -> Iterator[dict]:
                     for key,value in row.items()
                 }):
             yield line
+        return None
 
     def __merge__(key, samples, line):
         marker = line[0]
@@ -94,26 +122,10 @@ def genotype_data(zfile: ZipFile, cdata: dict) -> Iterator[dict]:
                  for item in items)))
             for items in zip(samples, line[1:]))
 
-    if cdata.get("geno_transposed", False):
-        with zfile.open(cdata["geno"]) as genofile:
-            lines = (line.strip().split(cdata.get("sep", ","))
-                     for line in filter(lambda line: not line.startswith("#"),
-                                         io.TextIOWrapper(genofile)))
-            try:
-                id_line = next(lines)
-                id_key, samples = id_line[0], id_line[1:]
-                for _key, row in reduce(# type: ignore[var-annotated]
-                        __make_organise_by_id__(id_key),
-                        (row
-                         for batch in __batch_of_n__(lines, 300)
-                         for line in batch
-                         for row in __merge__(id_key, samples, line)),
-                        {}).items():
-                    yield row
-            except StopIteration:
-                return None
-
-def map_data(zfile: ZipFile, map_type: str, cdata: dict) -> tuple[dict, ...]:
+    for row in with_transposed(zfile, "geno", cdata, __merge__):
+        yield row
+
+def map_data(zfile: ZipFile, map_type: str, cdata: dict) -> Iterator[dict]:
     """Read gmap files to get the genome mapping data"""
     assert map_type in ("genetic-map", "physical-map"), "Invalid map type"
     map_file_key = {
@@ -125,17 +137,14 @@ def map_data(zfile: ZipFile, map_type: str, cdata: dict) -> tuple[dict, ...]:
         "physical-map": "pmap_transposed"
     }
     if not cdata.get(transposed_dict[map_type], False):
-        return tuple(with_non_transposed(zfile, map_file_key, cdata))
+        for row in with_non_transposed(zfile, map_file_key, cdata):
+            yield row
+        return None
 
-    with zfile.open(cdata[map_file_key]) as gmapfile:
-        lines = [[field.strip() for field in
-                  line.strip().split(cdata.get("sep", ","))]
-                 for line in
-                 filter(lambda line: not line.startswith("#"),
-                        io.TextIOWrapper(gmapfile))]
+    def __merge__(key, samples, line):
+        marker = line[0]
+        return tuple(dict(zip([key, marker], items))
+                     for items in zip(samples, line[1:]))
 
-    headers = tuple(line[0] for line in lines)
-    return reduce(
-        lambda gmap, row: gmap + (dict(zip(headers, row)),),
-        zip(*(line[1:] for line in lines)),
-        tuple())
+    for row in with_transposed(zfile, map_file_key, cdata, __merge__):
+        yield row
diff --git a/tests/r_qtl/test_r_qtl2_gmap.py b/tests/r_qtl/test_r_qtl2_gmap.py
index 64774c2..ba46c42 100644
--- a/tests/r_qtl/test_r_qtl2_gmap.py
+++ b/tests/r_qtl/test_r_qtl2_gmap.py
@@ -44,5 +44,5 @@ def test_parse_map_files(relpath, mapfiletype, expected):
     THEN: ensure the parsed data is as expected.
     """
     with ZipFile(Path(relpath).absolute(), "r") as zfile:
-        assert rqtl2.map_data(
-            zfile, mapfiletype, rqtl2.control_data(zfile)) == expected
+        assert tuple(rqtl2.map_data(
+            zfile, mapfiletype, rqtl2.control_data(zfile))) == expected