## 🐛 Bug

Inconsistent behavior in `StreamingDataLoader` after loading states (specific to `CombinedStreamingDataset`)

### Description
The `StreamingDataLoader` exhibits inconsistent behavior when handling loaded states across different scenarios. Specifically, issues arise when iterating over the dataloader after loading states with a complete or partial first epoch.

This bug is an extension of #316 for `CombinedStreamingDataset`.
### To Reproduce

**Create an optimized dataset**

```python
from litdata import optimize


def random_data(index):
    return index


if __name__ == "__main__":
    datasets = ["dataset1", "dataset2"]
    for dataset in datasets:
        optimize(fn=random_data, inputs=list(range(50)), output_dir=dataset, num_workers=4, chunk_bytes="64MB")
```
### Bugs

- **`IndexError` raised when loading the dataloader state without prior iteration**

  ```python
  from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

  if __name__ == "__main__":
      dataset1 = StreamingDataset("dataset1")
      dataset2 = StreamingDataset("dataset2")
      datasets = [dataset1, dataset2]
      combined_dataset = CombinedStreamingDataset(datasets=datasets)
      dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=4)
      dataloader.load_state_dict(dataloader.state_dict())
  ```

  Output:

  ```
  Traceback (most recent call last):
    File "/Users/bhimrajyadav/litdata/test_combined_dataset.py", line 10, in <module>
      dataloader.load_state_dict(dataloader.state_dict())
                                 ^^^^^^^^^^^^^^^^^^^^^^^
    File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/dataloader.py", line 668, in state_dict
      num_samples_yieled = [0 for _ in range(len(list(self._num_samples_yielded_combined.values())[0]))]
                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
  IndexError: list index out of range
  ```
- **`ValueError` thrown after loading the dataloader state following completion of the first epoch** (previously an `IndexError`; see the clearer example in issue #363, "Failed to Resume Training w/ CombinedStreamingDataset")

  ```python
  from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

  if __name__ == "__main__":
      dataset1 = StreamingDataset("dataset1")
      dataset2 = StreamingDataset("dataset2")
      datasets = [dataset1, dataset2]
      combined_dataset = CombinedStreamingDataset(datasets=datasets)
      dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=4)

      for batch_idx, batch in enumerate(dataloader):
          if batch_idx == 0:
              print("\nEpoch", dataloader.current_epoch)
          print(batch.numpy(), end=" ")

      dataloader.load_state_dict(dataloader.state_dict())

      for batch_idx, batch in enumerate(dataloader):
          if batch_idx == 0:
              print("\nEpoch", dataloader.current_epoch)
          print(batch.numpy(), end=" ")
  ```

  Output:

  ```
  File "/Users/bhimrajyadav/itdata/venv/lib/python3.12/site-packages/litdata/streaming/combined.py", line 160, in __iter__
    self._iterator = _CombinedDatasetIterator(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/combined.py", line 208, in __init__
    self._dataset_iters = [iter(dataset) for dataset in datasets]
                          ^^^^^^^^^^^^^
  File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/dataset.py", line 223, in __iter__
    self._validate_state_dict()
  File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/dataset.py", line 479, in _validate_state_dict
    raise ValueError(
  ValueError: The provided `num_samples_yielded` state is greater than the dataset length. Found `51` instead of `50`.
  ```
- **After loading the dataloader state with a partially completed first epoch, the dataloader does not reset correctly upon completing the epoch.**

  - Additional details will be added.
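For quick intuition, both tracebacks can be mimicked in plain Python. This is a simplified sketch, not litdata's actual code: the first part mirrors the failing expression from `dataloader.py` line 668 (an empty tracking dict before any iteration), and `validate` is a hypothetical stand-in for the length check in `_validate_state_dict`.

```python
# Sketch 1: the IndexError. If state_dict() is called before any iteration,
# the per-dataset sample-count dict is still empty, so taking its first
# value fails (mirrors the expression at dataloader.py line 668).
num_samples_yielded_combined = {}  # empty: no samples yielded yet
try:
    [0 for _ in range(len(list(num_samples_yielded_combined.values())[0]))]
except IndexError as e:
    print("IndexError:", e)  # list index out of range


# Sketch 2: the ValueError. A hypothetical stand-in for the check in
# _validate_state_dict: after a completed first epoch the restored
# counter has overshot the dataset length by one (51 vs. 50).
def validate(num_samples_yielded: int, dataset_length: int) -> None:
    if num_samples_yielded > dataset_length:
        raise ValueError(
            f"The provided `num_samples_yielded` state is greater than the "
            f"dataset length. Found `{num_samples_yielded}` instead of `{dataset_length}`."
        )


try:
    validate(51, 50)
except ValueError as e:
    print("ValueError:", e)
```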
### Environment

- PyTorch Version (e.g., 1.0): 2.4.0
- OS (e.g., Linux): macOS
- How you installed PyTorch (`conda`, `pip`, source): pip
- Build command you used (if compiling from source):
- Python version: 3.12.4
- CUDA/cuDNN version:
- GPU models and configuration:
- Any other relevant information: