Description
When resuming from a checkpoint, IterableDataset drops samples if num_shards % world_size == 0 and the underlying examples iterable supports iter_arrow and needs to be formatted.
In that case, FormattedExamplesIterable fetches a batch of samples from the child iterable's iter_arrow and yields them one by one (after formatting). However, inside iter_arrow the child increments its shard_example_idx counter by the full batch size before returning the batch. If iteration of the parent is stopped mid-batch, the checkpoint therefore records the whole batch as consumed, and the samples of that batch that were not yet yielded are skipped on resume.
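To make the mechanism concrete, here is a small pure-Python simulation. ChildArrowIterable, FormattedParent and run are hypothetical stand-ins written for illustration, not the actual datasets classes, and the batch size of 4 and the 6-row shard are arbitrary assumptions. The sketch only models the ordering problem: the checkpoint counter advances for a whole batch before the parent has yielded any of it.

```python
from typing import Iterator, List, Tuple


class ChildArrowIterable:
    """Hypothetical child: yields fixed-size batches and advances its
    checkpoint counter *before* handing each batch to the parent."""

    def __init__(self, data: List[int], batch_size: int) -> None:
        self.data = data
        self.batch_size = batch_size
        self.shard_example_idx = 0  # position a resumed iterator would start from

    def iter_arrow(self) -> Iterator[List[int]]:
        while self.shard_example_idx < len(self.data):
            start = self.shard_example_idx
            batch = self.data[start : start + self.batch_size]
            # Counter is bumped for the whole batch before any row is consumed.
            self.shard_example_idx += len(batch)
            yield batch


class FormattedParent:
    """Hypothetical parent: unbatches and yields one example at a time."""

    def __init__(self, child: ChildArrowIterable) -> None:
        self.child = child

    def __iter__(self) -> Iterator[int]:
        for batch in self.child.iter_arrow():
            for example in batch:
                yield example  # formatting step omitted


def run(stop_after: int) -> Tuple[List[int], List[int]]:
    data = list(range(6))
    child = ChildArrowIterable(data, batch_size=4)
    seen: List[int] = []
    for example in FormattedParent(child):
        seen.append(example)
        if len(seen) == stop_after:
            break
    checkpoint = child.shard_example_idx  # what a state_dict would record
    resumed = data[checkpoint:]  # what a resumed iterator would yield
    return seen, resumed
```

With run(3), the consumer sees [0, 1, 2], the checkpoint already points at index 4, and the resumed iterator starts at 4, so example 3 is never yielded.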
One way to avoid this could be to signal to the child iterable which samples within the chunk have been processed by the parent and which have not, so that it can adjust its shard_example_idx counter accordingly. This would also require slicing the chunk when resuming, which should be straightforward to implement.
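A minimal sketch of that idea, again using hypothetical stand-in classes rather than the real datasets code. The child checkpoints the start of the in-flight batch plus a count of rows the parent has consumed; the parent reports each row through a hypothetical mark_consumed hook before yielding it; on resume the child slices the already-consumed rows off that batch.

```python
from typing import Iterator, List, Tuple


class ChildArrowIterable:
    """Hypothetical child that checkpoints the start of the current batch
    plus how many of its rows the parent has already yielded."""

    def __init__(self, data: List[int], batch_size: int) -> None:
        self.data = data
        self.batch_size = batch_size
        self.shard_example_idx = 0  # start of the batch currently in flight
        self.consumed_in_batch = 0  # rows of that batch already yielded by the parent

    def state_dict(self) -> dict:
        return {"shard_example_idx": self.shard_example_idx,
                "consumed_in_batch": self.consumed_in_batch}

    def load_state_dict(self, state: dict) -> None:
        self.shard_example_idx = state["shard_example_idx"]
        self.consumed_in_batch = state["consumed_in_batch"]

    def mark_consumed(self) -> None:
        """Called by the parent for every example it hands out."""
        self.consumed_in_batch += 1

    def iter_arrow(self) -> Iterator[List[int]]:
        while self.shard_example_idx < len(self.data):
            start = self.shard_example_idx
            batch = self.data[start : start + self.batch_size]
            # On resume, drop the rows the parent already yielded.
            yield batch[self.consumed_in_batch :]
            # Advance only when the parent asks for the next batch,
            # i.e. after every row of this one has been consumed.
            self.shard_example_idx = start + len(batch)
            self.consumed_in_batch = 0


class FormattedParent:
    """Hypothetical parent that reports progress back to the child."""

    def __init__(self, child: ChildArrowIterable) -> None:
        self.child = child

    def __iter__(self) -> Iterator[int]:
        for batch in self.child.iter_arrow():
            for example in batch:
                self.child.mark_consumed()  # count it before handing it out
                yield example


def run_with_resume(stop_after: int) -> Tuple[List[int], List[int]]:
    data = list(range(6))
    child = ChildArrowIterable(data, batch_size=4)
    seen: List[int] = []
    for example in FormattedParent(child):
        seen.append(example)
        if len(seen) == stop_after:
            break
    resumed_child = ChildArrowIterable(data, batch_size=4)
    resumed_child.load_state_dict(child.state_dict())
    return seen, list(FormattedParent(resumed_child))
```

run_with_resume(3) returns ([0, 1, 2], [3, 4, 5]), so no sample is lost. Marking before the yield matters: a consumer that checkpoints right after receiving an example expects that example to already be counted.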
The following is a minimal reproducer of the bug:
```python
from datasets import Dataset
from datasets.distributed import split_dataset_by_node

# 24 examples in 4 shards; with world_size=4 each rank gets exactly one shard
ds = Dataset.from_dict({"n": list(range(24))})
ds = ds.to_iterable_dataset(num_shards=4)

world_size = 4
rank = 0
ds_rank = split_dataset_by_node(ds, rank, world_size)

# Consume three examples, then take a checkpoint
it = iter(ds_rank)
examples = []
for idx, example in enumerate(it):
    examples.append(example)
    if idx == 2:
        state_dict = ds_rank.state_dict()
        break

# Restore the checkpoint and start a fresh iterator from it
ds_rank.load_state_dict(state_dict)
it_resumed = iter(ds_rank)
examples_resumed = examples[:]

# Drain both iterators; the two lists should end up identical
for example in it:
    examples.append(example)
for example in it_resumed:
    examples_resumed.append(example)

print("ORIGINAL ITER EXAMPLES:", examples)
print("RESUMED ITER EXAMPLES:", examples_resumed)
```