Failed to Resume Training w/ CombinedStreamingDataset #363

Closed
@schopra8

Description

🐛 Bug

My training run crashed, so I tried to resume it from the most recent PyTorch Lightning checkpoint.

When I do so, I get the following error:

[rank5]: Original Traceback (most recent call last):
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank5]:     fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank5]:     return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank5]:     self.dataset_iter = iter(dataset)
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 160, in __iter__
[rank5]:     self._iterator = _CombinedDatasetIterator(
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 208, in __init__
[rank5]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 208, in <listcomp>
[rank5]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 223, in __iter__
[rank5]:     self._validate_state_dict()
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 479, in _validate_state_dict
[rank5]:     raise ValueError(
[rank5]: ValueError: The provided `num_samples_yielded` state is greater than the dataset length. Found `46320` instead of `44976`
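
The check that fails here can be pictured with a minimal sketch. This is not LitData's actual implementation: `validate_num_samples_yielded` and its parameters are hypothetical names chosen for illustration. Only the error message and the two numbers are taken from the traceback above.

```python
def validate_num_samples_yielded(num_samples_yielded: int, dataset_length: int) -> None:
    """Hypothetical stand-in for the resume-time check seen in the traceback."""
    if num_samples_yielded > dataset_length:
        raise ValueError(
            "The provided `num_samples_yielded` state is greater than the dataset length. "
            f"Found `{num_samples_yielded}` instead of `{dataset_length}`"
        )


# Values reported in the traceback above.
saved_count = 46320
dataset_length = 44976

try:
    validate_num_samples_yielded(saved_count, dataset_length)
    error_message = None
except ValueError as err:
    error_message = str(err)

# How far the restored counter overshoots the dataset length.
overshoot = saved_count - dataset_length
```

With the reported values, the restored counter exceeds the dataset length by 1344 samples, so any check of this shape would reject the checkpointed state.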

To Reproduce

I am unsure what a minimal example of this bug would be.

Code sample

My dataset is initialized as follows:
from litdata import CombinedStreamingDataset
combined_dataset = CombinedStreamingDataset(datasets=train_datasets, iterate_over_all=True)

I do not weight the individual datasets.

Expected behavior

Training resumes from the checkpoint without errors.

Environment

  • PyTorch Version (e.g., 1.0): 2.3.1
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source): N/A
  • Python version: 3.10.12
  • CUDA/cuDNN version: 12.1
  • GPU models and configuration: 2 nodes, 8 H-100s per node
  • Any other relevant information:

Additional context

The error shown above occurs on the latest version of LitData. Previously, we were training on an older version of LitData, and the following error occurred instead:

[rank15]: Traceback (most recent call last):
[rank15]:  File "/home/sahil/project/train.py", line 89, in <module>
[rank15]:   main(config)
[rank15]:  File "/home/sahil/project/train.py", line 69, in main
[rank15]:   trainer.fit(model, datamodule=my_data_module, ckpt_path=trainer_checkpoint_path)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
[rank15]:   call._call_and_handle_interrupt(
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank15]:   return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
[rank15]:   return function(*args, **kwargs)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
[rank15]:   self._run(model, ckpt_path=ckpt_path)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _run
[rank15]:   results = self._run_stage()
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1034, in _run_stage
[rank15]:   self.fit_loop.run()
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
[rank15]:   self.advance()
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
[rank15]:   self.epoch_loop.run(self._data_fetcher)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
[rank15]:   self.advance(data_fetcher)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 212, in advance
[rank15]:   batch, _, __ = next(data_fetcher)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 133, in __next__
[rank15]:   batch = super().__next__()
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 60, in __next__
[rank15]:   batch = next(self.iterator)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 341, in __next__
[rank15]:   out = next(self._iterator)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 78, in __next__
[rank15]:   out[i] = next(self.iterators[i])
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataloader.py", line 631, in __iter__
[rank15]:   for batch in super().__iter__():
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
[rank15]:   data = self._next_data()
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
[rank15]:   return self._process_data(data)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
[rank15]:   data.reraise()
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/_utils.py", line 705, in reraise
[rank15]:   raise exception
[rank15]: IndexError: Caught IndexError in DataLoader worker process 0.
[rank15]: Original Traceback (most recent call last):
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank15]:   fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank15]:   return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank15]:   self.dataset_iter = iter(dataset)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 155, in __iter__
[rank15]:   self._iterator = _CombinedDatasetIterator(
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 203, in __init__
[rank15]:   self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 203, in <listcomp>
[rank15]:   self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 236, in __iter__
[rank15]:   self._resume(workers_chunks, workers_intervals)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 308, in _resume
[rank15]:   interval = self.worker_intervals[self.chunk_index]
[rank15]: IndexError: list index out of range
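
The failing lookup in this older traceback, `self.worker_intervals[self.chunk_index]`, is an ordinary list index. The sketch below shows how that kind of lookup fails when a restored index points past the end of the list. `resume_interval` and the example values are hypothetical and do not come from LitData or from the run above.

```python
def resume_interval(worker_intervals, chunk_index):
    """Hypothetical stand-in for the lookup that fails in the older traceback."""
    return worker_intervals[chunk_index]


# Assumed example values: the worker holds three chunk intervals,
# but the restored chunk index refers to a fourth.
worker_intervals = [(0, 100), (100, 200), (200, 300)]
restored_chunk_index = 3

try:
    resume_interval(worker_intervals, restored_chunk_index)
    failure = None
except IndexError as err:
    failure = str(err)
```

Both tracebacks are therefore consistent with restored state that no longer fits the dataset being iterated, though the issue itself does not establish where that state goes wrong.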

Note also that this error occurs several epochs into training, with data stored locally rather than streamed from an S3 blob store.
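
Because the failure appears only after several epochs, one scenario worth ruling out is a sample counter that keeps accumulating across epoch boundaries. The sketch below is purely illustrative and is not a diagnosis: `samples_counted` is a hypothetical helper, and the per-epoch history is an assumed example chosen to match the numbers in the first traceback, not data from the actual run.

```python
def samples_counted(samples_seen_per_epoch, reset_each_epoch):
    """Return the counter value after the given per-epoch sample counts."""
    counter = 0
    for seen in samples_seen_per_epoch:
        if reset_each_epoch:
            counter = 0  # start counting afresh at each epoch boundary
        counter += seen
    return counter


dataset_length = 44976  # dataset length from the first traceback

# Assumed history: one full epoch, then 1344 samples into the next epoch.
history = [dataset_length, 1344]

without_reset = samples_counted(history, reset_each_epoch=False)
with_reset = samples_counted(history, reset_each_epoch=True)

exceeds_length_without_reset = without_reset > dataset_length
exceeds_length_with_reset = with_reset > dataset_length
```

Under these assumed inputs, the counter that is never reset reaches 46320 and exceeds the dataset length, while the counter that is reset each epoch stays at 1344.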

Labels: bug, duplicate, help wanted
