Skip to content

Commit 48bf488

Browse files
authored
Add support for grain.IterDataset in sampling (apple#980)
1 parent d4b563c commit 48bf488

File tree

2 files changed

+55
-30
lines changed

2 files changed

+55
-30
lines changed

axlearn/common/input_grain.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def sample_from_datasets(
131131
sources: Sequence[Dataset],
132132
weights: Sequence[float],
133133
) -> Dataset:
134-
"""Mixes one or more data sources.
134+
"""Mixes one or more repeated data sources.
135135
136136
Different from `input_tf_data.sample_from_datasets`, the mixing is deterministic:
137137
https://github.com/google/grain/blob/ddf825c68b6d2c811f9e599d7fb7ae7572affd8c/grain/_src/python/dataset/transformations/mix.py#L222
@@ -148,22 +148,24 @@ def sample_from_datasets(
148148
A Dataset for the mixed data source.
149149
"""
150150

151-
# Without repeat, mixing stops as soon as the first dataset is exhausted.
152-
def maybe_repeat(ds: Dataset):
153-
if not isinstance(ds, grain.MapDataset):
154-
raise ValueError(
155-
f"{sample_from_datasets.__name__} requires {grain.MapDataset.__name__}"
156-
)
157-
# Only repeat if not already infinite.
158-
if len(ds) != sys.maxsize:
159-
ds = ds.repeat()
160-
return ds
161-
162-
# TODO(markblee): Support mixing grain.IterDataset.
163-
return grain.MapDataset.mix(
164-
datasets=[maybe_repeat(source) for source in sources],
165-
weights=weights,
166-
)
151+
def _ensure_repeated(sources: Sequence[Dataset]):
152+
# There is no easy way to check if a grain.IterDataset is repeated.
153+
for source in sources:
154+
if isinstance(source, grain.MapDataset) and len(source) != sys.maxsize:
155+
raise ValueError(
156+
f"sample_from_datasets requires each dataset to be repeated, {source} is not."
157+
)
158+
if isinstance(source, grain.IterDataset):
159+
logging.info(
160+
"Sampling from grain.IterDataset, please make sure your dataset is repeated."
161+
)
162+
163+
_ensure_repeated(sources)
164+
# If any of the datasets are grain.IterDataset, we should use grain.IterDataset.mix().
165+
if any(isinstance(ds, grain.IterDataset) for ds in sources):
166+
return grain.IterDataset.mix(datasets=sources, weights=weights)
167+
168+
return grain.MapDataset.mix(datasets=sources, weights=weights)
167169

168170

169171
def default_pad_example_fn(example: utils.Nested[Any]) -> utils.Nested[Any]:

axlearn/common/input_grain_test.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -85,37 +85,65 @@ def check(starting_step, ds):
8585
dict(
8686
sources=[slice(0, 10, 2), slice(1, 5, 2)],
8787
weights=[1, 1],
88+
is_iter_dataset=[False, False],
8889
take=10,
8990
expected=[0, 1, 2, 3, 4, 1, 6, 3, 8, 1],
9091
),
9192
dict(
9293
sources=[slice(0, 10, 2), slice(1, 5, 2)],
9394
weights=[2, 1],
95+
is_iter_dataset=[False, False],
9496
take=10,
9597
expected=[0, 2, 1, 4, 6, 3, 8, 0, 2, 1],
9698
),
9799
dict(
98100
sources=[slice(0, 10, 2), slice(1, 5, 2)],
99101
weights=[1, 1e-9],
102+
is_iter_dataset=[False, False],
100103
take=10,
101104
expected=[0, 2, 4, 6, 8, 0, 2, 4, 6, 8],
102105
),
106+
# IterDataset
107+
dict(
108+
sources=[slice(0, 10, 2), slice(1, 5, 2)],
109+
weights=[1, 1],
110+
is_iter_dataset=[True, True],
111+
take=10,
112+
expected=[0, 1, 2, 3, 4, 1, 6, 3, 8, 1],
113+
),
114+
# Mixture of IterDataset and MapDataset.
115+
dict(
116+
sources=[slice(0, 10, 2), slice(1, 5, 2)],
117+
weights=[1, 1],
118+
is_iter_dataset=[True, False],
119+
take=10,
120+
expected=[0, 1, 2, 3, 4, 1, 6, 3, 8, 1],
121+
),
103122
)
104123
def test_sample_from_datasets(
105124
self,
106125
sources: list[slice],
107126
weights: list[int],
127+
is_iter_dataset: list[bool],
108128
take: Optional[int],
109129
expected: list[int],
110130
):
131+
sources = [
132+
range_dataset(start=src.start, stop=src.stop, step=src.step).repeat() for src in sources
133+
]
134+
sources = [
135+
source.to_iter_dataset() if should_convert else source
136+
for source, should_convert in zip(sources, is_iter_dataset)
137+
]
111138
ds = sample_from_datasets(
112-
sources=[
113-
range_dataset(start=src.start, stop=src.stop, step=src.step) for src in sources
114-
],
139+
sources=sources,
115140
weights=weights,
116141
)
117-
ds = ds.slice(slice(0, take))
118-
self.assertCountEqual(expected, list(ds))
142+
ds_iter = iter(ds)
143+
result = []
144+
for _ in range(take):
145+
result.append(next(ds_iter))
146+
self.assertCountEqual(expected, list(result))
119147

120148
def test_sample_from_datasets_errors(self):
121149
ds = range_dataset(start=0, stop=2)
@@ -124,17 +152,12 @@ def test_sample_from_datasets_errors(self):
124152
repeated_ds = sample_from_datasets(sources=[ds], weights=[1]).slice(slice(0, 4))
125153
self.assertEqual([0, 1, 0, 1], list(repeated_ds))
126154

127-
# Make sure that non-map dataset raises.
128-
with self.assertRaisesRegex(ValueError, "MapDataset"):
129-
ds = ds.to_iter_dataset()
130-
sample_from_datasets(sources=[ds], weights=[1])
131-
132155
def test_shuffle_dataset(self):
133156
# Test without repeat.
134157
ds = sample_from_datasets(
135158
sources=[
136-
range_dataset(start=0, stop=10, step=2),
137-
range_dataset(start=1, stop=5, step=2),
159+
range_dataset(start=0, stop=10, step=2).repeat(),
160+
range_dataset(start=1, stop=5, step=2).repeat(),
138161
],
139162
weights=[2, 1],
140163
)
@@ -174,7 +197,7 @@ def test_slice_dataset(self, s: slice, expected: list[int]):
174197

175198
def test_batch(self):
176199
# [0, 1, 2, 3, 4].
177-
ds = range_dataset(start=0, stop=5, seed=123)
200+
ds = range_dataset(start=0, stop=5, seed=123).repeat()
178201
# [1, 2, 3, 4, 5].
179202
other_ds = ds.map(_PlusOne())
180203
# [0, 1, 2, 1, 3, 4, 2, 5, 1, 3, ...].

0 commit comments

Comments
 (0)