Closed
Description
Describe the bug
I've noticed strange behaviour with the iterable dataset state_dict: the value of shard_example_idx is always equal to the number of samples in a shard, regardless of how many examples have been consumed.
Steps to reproduce the bug
I am reusing the example from the doc
from datasets import Dataset
ds = Dataset.from_dict({"a": range(6)}).to_iterable_dataset(num_shards=1)
state_dict = None
# Iterate through the dataset and print examples
for idx, example in enumerate(ds):
    print(example)
    if idx == 2:
        state_dict = ds.state_dict()
        print("checkpoint")
        break
print(state_dict)
Returns:
{'a': 0}
{'a': 1}
checkpoint
{'examples_iterable': {'shard_idx': 0, 'shard_example_idx': 6, 'type': 'ArrowExamplesIterable'}, 'epoch': 0}
Expected behavior
shard_example_idx should be 2 instead of 6
If we run with num_shards=2, then shard_example_idx is 3 instead of 2 and so on.
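One way to check whether a checkpoint is usable, independent of the exact counter value, is to resume from it and look at which examples come back. The sketch below is only an illustration: ToyIterable and resume_after are hypothetical names, and ToyIterable is a pure-Python stand-in that tracks its position per example. It is not the datasets implementation.

```python
from typing import Any, Callable, Dict, List, Tuple


class ToyIterable:
    """Hypothetical stand-in that checkpoints per example (not the datasets code)."""

    def __init__(self, values) -> None:
        self.values = list(values)
        self._next_idx = 0  # index of the next example to yield

    def __iter__(self):
        for i in range(self._next_idx, len(self.values)):
            # Advance the counter before yielding, so a checkpoint taken inside
            # the loop body already counts the example currently being handled.
            self._next_idx = i + 1
            yield {"a": self.values[i]}

    def state_dict(self) -> Dict[str, Any]:
        return {"next_idx": self._next_idx}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self._next_idx = state["next_idx"]


def resume_after(make_ds: Callable[[], Any], stop_idx: int) -> Tuple[List[Any], List[Any]]:
    """Consume examples up to stop_idx, checkpoint, then resume on a fresh object."""
    ds = make_ds()
    seen: List[Any] = []
    state = None
    for idx, example in enumerate(ds):
        seen.append(example)
        if idx == stop_idx:
            state = ds.state_dict()
            break
    if state is None:
        raise ValueError("stop_idx was never reached")
    resumed = make_ds()
    resumed.load_state_dict(state)
    return seen, list(resumed)


seen, rest = resume_after(lambda: ToyIterable(range(6)), stop_idx=2)
print(seen)
print(rest)
```

Because resume_after only relies on state_dict() and load_state_dict(), the same helper could in principle be given a factory that builds the iterable dataset from the report above. That call is not exercised here.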
Environment info
- datasets version: 3.4.1
- Platform: macOS-14.6.1-arm64-arm-64bit
- Python version: 3.12.9
- huggingface_hub version: 0.29.3
- PyArrow version: 19.0.1
- Pandas version: 2.2.3
- fsspec version: 2024.12.0