
Bug: Inconsistent Behavior with StreamingDataloader loading states (specific to CombinedStreamingDataset) #331

Closed
@bhimrazy

🐛 Bug

Bug: Inconsistent Behavior in StreamingDataLoader After Loading States (Specific to CombinedStreamingDataset)

Description:
StreamingDataLoader behaves inconsistently after load_state_dict is called when it wraps a CombinedStreamingDataset. Failures occur in three situations: when the state is loaded before any iteration, after a complete first epoch, and after a partially completed first epoch.

This bug is an extension of #316 for CombinedStreamingDataset.

To Reproduce

Create Optimized Dataset
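from litdata import optimize


def random_data(index):
    # Each sample is just its own index, so batches are easy to recognize in the output.
    return index

if __name__ == "__main__":
    # Build two identical 50-sample datasets to combine in the reproductions below.
    datasets = ["dataset1", "dataset2"]
    for dataset in datasets:
        optimize(fn=random_data, inputs=list(range(50)), output_dir=dataset, num_workers=4, chunk_bytes="64MB")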

Bugs

  1. IndexError raised when loading dataloader state without prior iteration

    from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
    
    if __name__ == "__main__":
        dataset1 = StreamingDataset("dataset1")
        dataset2 = StreamingDataset("dataset2")
        datasets = [dataset1, dataset2]
        combined_dataset = CombinedStreamingDataset(datasets=datasets)
        dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=4)
    
        dataloader.load_state_dict(dataloader.state_dict())

    Output

    Traceback (most recent call last):
      File "/Users/bhimrajyadav/litdata/test_combined_dataset.py", line 10, in <module>
        dataloader.load_state_dict(dataloader.state_dict())
                                   ^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/dataloader.py", line 668, in state_dict
        num_samples_yieled = [0 for _ in range(len(list(self._num_samples_yielded_combined.values())[0]))]
                                                   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
    IndexError: list index out of range                                        
  2. Loading the dataloader state after the first epoch completes raises a ValueError on the next iteration (previously an IndexError; see issue #363, "Failed to Resume Training w/ CombinedStreamingDataset", for a clearer example).

    from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
    
    if __name__ == "__main__":
        dataset1 = StreamingDataset("dataset1")
        dataset2 = StreamingDataset("dataset2")
        datasets = [dataset1, dataset2]
        combined_dataset = CombinedStreamingDataset(datasets=datasets)
        dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=4)
    
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx == 0:
                print("\nEpoch", dataloader.current_epoch)
            print(batch.numpy(), end=" ")
            
        dataloader.load_state_dict(dataloader.state_dict())
    
    
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx == 0:
                print("\nEpoch", dataloader.current_epoch)
            print(batch.numpy(), end=" ")

    Output

      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/combined.py", line 160, in __iter__
        self._iterator = _CombinedDatasetIterator(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/combined.py", line 208, in __init__
        self._dataset_iters = [iter(dataset) for dataset in datasets]
                               ^^^^^^^^^^^^^
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/dataset.py", line 223, in __iter__
        self._validate_state_dict()
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/dataset.py", line 479, in _validate_state_dict
        raise ValueError(
    ValueError: The provided `num_samples_yielded` state is greater than the dataset length. Found `51` instead of `50`.                            
  3. After loading the dataloader state with a partially completed first epoch, the dataloader does not reset correctly upon completing the epoch.

    • Additional details will be added.
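
The first two failures can be illustrated without litdata. The sketch below is not the library's code or its fix. It assumes that _num_samples_yielded_combined is a dict mapping a worker index to a list of per-dataset counters (inferred from the expression in the first traceback), and the function names and the num_datasets parameter are hypothetical.

```python
from typing import Dict, List


def zeros_unguarded(yielded_by_worker: Dict[int, List[int]]) -> List[int]:
    # Same shape as the expression in the first traceback: the first worker's
    # counters are indexed to learn how many datasets there are.
    return [0 for _ in range(len(list(yielded_by_worker.values())[0]))]


def zeros_guarded(yielded_by_worker: Dict[int, List[int]], num_datasets: int) -> List[int]:
    # Hypothetical guard: before any iteration no worker has reported counters,
    # so fall back to the known number of datasets.
    if not yielded_by_worker:
        return [0] * num_datasets
    return [0] * len(next(iter(yielded_by_worker.values())))


def validate_num_samples_yielded(num_samples_yielded: int, dataset_length: int) -> None:
    # Mirrors the check reported in the second traceback: a restored counter
    # larger than the dataset length is rejected.
    if num_samples_yielded > dataset_length:
        raise ValueError(
            "The provided `num_samples_yielded` state is greater than the dataset length. "
            f"Found `{num_samples_yielded}` instead of `{dataset_length}`."
        )


if __name__ == "__main__":
    try:
        zeros_unguarded({})  # state requested before any iteration
    except IndexError as err:
        print("unguarded:", err)

    print("guarded:", zeros_guarded({}, num_datasets=2))

    try:
        validate_num_samples_yielded(51, 50)  # values from the second traceback
    except ValueError as err:
        print("validation:", err)
```

The guarded variant returns one zero per dataset whether or not any worker has reported yet, which is the behavior the first reproduction expects from state_dict() on a fresh dataloader.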

Environment

  • PyTorch Version (e.g., 1.0): 2.4.0
  • OS (e.g., Linux): macOS
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.12.4
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

Labels

bug (Something isn't working), help wanted (Extra attention is needed)
