diff options
-rw-r--r-- | r_qtl/r_qtl2.py | 26 | ||||
-rw-r--r-- | tests/r_qtl/test_r_qtl2_pheno.py | 11 |
2 files changed, 28 insertions, 9 deletions
diff --git a/r_qtl/r_qtl2.py b/r_qtl/r_qtl2.py index f03aff5..1e28bc0 100644 --- a/r_qtl/r_qtl2.py +++ b/r_qtl/r_qtl2.py @@ -368,19 +368,27 @@ def read_geno_file_data( replace_genotype_codes, genocodes=cdata.get("genotypes", {}))) -def load_geno_samples(zipfilepath: Union[str, Path]) -> tuple[str, ...]: - """Load the samples/cases/individuals from the 'geno' file(s).""" +def load_samples( + zipfilepath: Union[str, Path], filetype: str) -> tuple[str, ...]: + """Load the samples/cases/individuals from file(s) of type 'filetype'.""" cdata = read_control_file(zipfilepath) - samples = set() - for genofile in cdata.get("geno", []): - gdata = read_geno_file_data(zipfilepath, genofile) - if cdata.get("geno_transposed", False): - samples.update(next(gdata)[1:]) + samples: set[str] = set() + for afile in cdata.get(filetype, []): + filedata = read_geno_file_data(zipfilepath, afile) + if cdata.get(f"{filetype}_transposed", False): + samples.update( + item for item in next(filedata)[1:] if item is not None) else: try: - next(gdata)# Ignore first row. - samples.update(line[0] for line in gdata) + next(filedata)# Ignore first row. + samples.update( + line[0] for line in filedata if line[0] is not None) except StopIteration:# Empty file. pass return tuple(samples) + + +load_geno_samples = partial(load_samples, filetype="geno") +load_founder_geno_samples = partial(load_samples, filetype="founder_geno") +load_pheno_samples = partial(load_samples, filetype="pheno") diff --git a/tests/r_qtl/test_r_qtl2_pheno.py b/tests/r_qtl/test_r_qtl2_pheno.py index c7c0c86..d31ad54 100644 --- a/tests/r_qtl/test_r_qtl2_pheno.py +++ b/tests/r_qtl/test_r_qtl2_pheno.py @@ -57,3 +57,14 @@ def test_parse_phenocovar_files(filepath, expected): with ZipFile(Path(filepath).absolute(), "r") as zfile: cdata = rqtl2.control_data(zfile) assert tuple(rqtl2.file_data(zfile, "phenocovar", cdata)) == expected + + +@pytest.mark.unit_test +@pytest.mark.parametrize( + "filepath,expected", + (("tests/r_qtl/test_files/test_pheno.zip", + ("1", "2", "3", "4", "5")), + ("tests/r_qtl/test_files/test_pheno_transposed.zip", + ("1", "2", "3", "4", "5")))) +def test_load_geno_samples(filepath, expected): + assert sorted(rqtl2.load_pheno_samples(filepath)) == sorted(expected) |