
Bugfix: inconsistent streaming dataloader state (specific to StreamingDataset) #318


Merged: 24 commits into main from bugfix/316-streaming-dataloader-state on Aug 14, 2024
Commits (24)
3658d64
chore: Add reset_state_dict method to StreamingDataset
bhimrazy Aug 9, 2024
8eb1c7f
chore: Update num_workers fallback value in StreamingDataset
bhimrazy Aug 9, 2024
10c10b3
fix: Reset dataset state after each epoch
bhimrazy Aug 9, 2024
391c68b
update
tchaton Aug 9, 2024
5d74ed8
Update src/litdata/streaming/dataset.py
tchaton Aug 9, 2024
7412064
feat: Add test for dataloader with loading states
bhimrazy Aug 9, 2024
0290a30
chore: Add test for dataloader with loading states with persistent wor…
bhimrazy Aug 9, 2024
00c2928
rm comment
bhimrazy Aug 9, 2024
25a87b7
🐛 fix: restore only if there are any remaining batches/samples to str…
bhimrazy Aug 11, 2024
678c3fc
added notes to checkout later
bhimrazy Aug 11, 2024
532dacd
Merge branch 'main' into bugfix/316-streaming-dataloader-state
bhimrazy Aug 11, 2024
9866992
add note
bhimrazy Aug 11, 2024
16bc40f
chore: Add test for dataloader resuming after completing last epoch
bhimrazy Aug 11, 2024
d3f9498
feat: Add test for resuming dataloader with new dataset
bhimrazy Aug 11, 2024
6769694
adds type ignore
bhimrazy Aug 11, 2024
81bc537
update timeout and num of samples
bhimrazy Aug 11, 2024
998fe5a
Add explicit test for resuming dataloader with new dataset
bhimrazy Aug 11, 2024
61120a4
chore: add validation for num_samples_yielded
bhimrazy Aug 11, 2024
faa0213
Merge branch 'main' into bugfix/316-streaming-dataloader-state
bhimrazy Aug 12, 2024
d98681c
removed unrequired test, as it was testing for the wrong thing, when rese…
bhimrazy Aug 12, 2024
743f0dd
removed the unnecessary todo
bhimrazy Aug 12, 2024
2db07e0
chore: Add restore flag to dataloader tests
bhimrazy Aug 12, 2024
fc3a960
chore: Add restore flag to dataloader for StreamingDataset
bhimrazy Aug 13, 2024
4a50cac
update
bhimrazy Aug 13, 2024
5 changes: 5 additions & 0 deletions src/litdata/streaming/combined.py
@@ -134,6 +134,11 @@ def set_drop_last(self, drop_last: bool) -> None:
for dataset in self._datasets:
dataset.set_drop_last(drop_last)

def reset_state_dict(self) -> None:
"""Reset the state of the dataset."""
for dataset in self._datasets:
dataset.reset_state_dict()

def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
if any(not isinstance(d, StreamingDataset) for d in datasets):
raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")
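The delegation above keeps every child dataset in sync when the dataloader clears its resume state. Below is a minimal sketch of where this hook fires, assuming litdata's top-level exports and illustrative dataset paths (`data_a`/`data_b` stand in for litdata-optimized directories):

```python
from litdata import CombinedStreamingDataset, StreamingDataset, StreamingDataLoader

combined = CombinedStreamingDataset(
    datasets=[StreamingDataset("data_a"), StreamingDataset("data_b")]
)
dataloader = StreamingDataLoader(combined, batch_size=4)

# When the dataloader starts a fresh (non-restored) epoch, its __iter__
# calls combined.reset_state_dict(), which fans out to each child
# StreamingDataset and clears any previously restored `_state_dict`,
# so a stale resume state cannot leak into the next epoch.
for batch in dataloader:
    pass
```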
13 changes: 12 additions & 1 deletion src/litdata/streaming/dataloader.py
@@ -615,6 +615,7 @@ def __iter__(self) -> Any:
self.current_epoch += 1
self._num_samples_yielded_combined = {}
self._num_samples_yielded_streaming = 0
self.dataset.reset_state_dict()

self.dataset.set_epoch(self.current_epoch)

@@ -700,13 +701,23 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:

# Inform we are resuming and disable resetting the StreamingDataLoader state.
# This is toggled back to False when the `__iter__` method of the StreamingDataLoader completes.
- self.restore = True
+ # self.restore = True

if isinstance(self.dataset, CombinedStreamingDataset):
self.dataset._set_use_streaming_dataloader(True)
self.dataset.load_state_dict(obj)

# Inform that the dataloader is resuming.
# TODO: Check if the number of samples yielded is less than the length of the dataset.
# Also, len() is not available for CombinedStreamingDataset when weights are provided.
self.restore = True

elif isinstance(self.dataset, StreamingDataset):
self.dataset.load_state_dict(obj["dataset"])

# Inform that the dataloader is resuming.
if self._num_samples_yielded_streaming < len(self.dataset):
self.restore = True
else:
raise RuntimeError("The provided dataset should be a `StreamingDataset` or a `CombinedStreamingDataset`.")

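Net effect: `load_state_dict` now sets `restore` only when the saved state still has samples left to replay, and a state saved after a completed epoch resumes cleanly into the next one. A condensed sketch of the flow the new tests exercise, assuming `fast_data/` is a litdata-optimized directory:

```python
from litdata import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset("fast_data", shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)

# Interrupt mid-epoch and capture the state.
for batch_idx, batch in enumerate(dataloader):
    if batch_idx == 10:
        break
state = dataloader.state_dict()

# A fresh dataloader replays the remainder of the interrupted epoch.
dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
dataloader.load_state_dict(state)
assert dataloader.restore
for batch in dataloader:  # finishes epoch 1
    pass
assert not dataloader.restore  # toggled back once __iter__ completes

# A state captured after a *complete* epoch leaves `restore` False,
# so the next iteration starts epoch 2 from scratch.
dataloader.load_state_dict(dataloader.state_dict())
assert not dataloader.restore
```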
14 changes: 12 additions & 2 deletions src/litdata/streaming/dataset.py
@@ -342,12 +342,14 @@ def __next__(self) -> Any:
# Prevent creating more batches on a given process
if self.global_index >= self.stop_length:
self.current_epoch += 1
self.reset_state_dict()
raise StopIteration

# Lazily re-populate the interval to reduce memory usage.
if len(self.current_indexes) == 0:
if self.chunk_index == self.num_chunks:
self.current_epoch += 1
self.reset_state_dict()
raise StopIteration

# reset index
@@ -392,7 +394,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int

return {
"num_samples_yielded": num_samples_yielded,
"num_workers": num_workers,
"num_workers": num_workers or 1,
"batch_size": batch_size,
"current_epoch": self.current_epoch,
"input_dir_path": self.input_dir.path,
@@ -411,13 +413,15 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# the state is restored within the workers
self._state_dict = state_dict

def reset_state_dict(self) -> None:
self._state_dict = None

def _validate_state_dict(self) -> None:
assert self._state_dict
assert self.worker_env
assert self.cache

state: Dict[str, Any] = self._state_dict

if state["shuffle"] != self.shuffle:
raise ValueError(
"The provided `shuffle` state doesn't match the current one. "
@@ -471,6 +475,12 @@ def _validate_state_dict(self) -> None:
f"Found `{self.drop_last}` instead of `{state['drop_last']}`."
)

if state["num_samples_yielded"] > len(self):
raise ValueError(
"The provided `num_samples_yielded` state is greater than the dataset length. "
f"Found `{state['num_samples_yielded']}` instead of `{len(self)}`."
)

def reset(self) -> None:
# undo all the properties associated with original dataset
default_properties: Dict[str, Any] = {
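Two behaviors in this file complement the dataloader change: the dataset clears its restored state whenever an epoch ends (`reset_state_dict()` before `StopIteration`), and `_validate_state_dict` now rejects a `num_samples_yielded` larger than the dataset length. Here is a hypothetical sketch of the failure mode the new bound catches, e.g. a checkpoint written before the dataset at the same path was regenerated with fewer samples (paths and sizes are illustrative, and the error should surface once iteration begins, since the state is validated lazily):

```python
from litdata import StreamingDataset, StreamingDataLoader

# Assume `fast_data/` is a litdata-optimized directory with 100 samples.
dataset = StreamingDataset("fast_data")
dataloader = StreamingDataLoader(dataset, batch_size=4)

state = dataloader.state_dict()
# Simulate a checkpoint from a previous, larger version of the dataset:
# it claims more samples were yielded than currently exist.
state["dataset"]["num_samples_yielded"] = 400

dataloader.load_state_dict(state)
for batch in dataloader:  # expected: ValueError from _validate_state_dict
    pass
```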
3 changes: 3 additions & 0 deletions tests/streaming/test_combined.py
@@ -17,6 +17,9 @@ class TestCombinedStreamingDataset(CombinedStreamingDataset):
def _check_datasets(self, datasets) -> None:
pass

def reset_state_dict(self):
pass


def test_combined_dataset_num_samples_yield():
dataset = TestCombinedStreamingDataset(
128 changes: 128 additions & 0 deletions tests/streaming/test_dataloader.py
@@ -56,6 +56,9 @@ class TestCombinedStreamingDataset(CombinedStreamingDataset):
def _check_datasets(self, datasets) -> None:
pass

def reset_state_dict(self):
pass


def test_streaming_dataloader():
dataset = TestCombinedStreamingDataset(
@@ -202,3 +205,128 @@ def test_dataloader_no_workers(tmpdir):
assert len(dataset) == 1000
assert len(dataloader) == 1000
assert len(dataset) == 1000


@pytest.mark.timeout(120)
def test_dataloader_with_loading_states(tmpdir):
cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
for i in range(100):
cache[i] = i
cache.done()
cache.merge()

dataset = StreamingDataset(str(tmpdir), shuffle=True)

# Test dataloader without explicit num workers
dataloader = StreamingDataLoader(dataset, batch_size=4)
dataloader.load_state_dict(dataloader.state_dict())
batch = next(iter(dataloader))
assert len(batch) == 4, "Batch size should be 4"
assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"

# Test dataloader with num workers
dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"

# Verify dataloader state after partial iteration
for batch_idx, batch in enumerate(dataloader):
assert dataloader.current_epoch == 1, "Current epoch should be 1"
if batch_idx == 10:
break
dataloader.load_state_dict(dataloader.state_dict())
assert dataloader.restore
# Verify remaining batches in the first epoch
count = 0
for _ in dataloader:
assert dataloader.current_epoch == 1, "Current epoch should be 1"
count += 1
assert count == 15, "There should be 15 batches remaining in the first epoch"
assert not dataloader.restore

# Verify batches in the second epoch
count = 0
for _ in dataloader:
assert dataloader.current_epoch == 2, "Current epoch should be 2"
count += 1
assert count >= 25, "There should be at least 25 batches in the second epoch"

# Verify that the dataloader can resume after completing the last epoch
dataloader.load_state_dict(dataloader.state_dict())
assert not dataloader.restore
count = 0
for _ in dataloader:
assert dataloader.current_epoch == 3, "Current epoch should be 3"
count += 1
assert count >= 25, "There should be at least 25 batches in the third epoch"


@pytest.mark.timeout(120)
def test_dataloader_states_with_persistent_workers(tmpdir):
cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
for i in range(100):
cache[i] = i
cache.done()
cache.merge()

dataset = StreamingDataset(str(tmpdir), shuffle=True)

dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"

# Verify dataloader state after partial iteration
for batch_idx, batch in enumerate(dataloader):
assert dataloader.current_epoch == 1, "Current epoch should be 1"
if batch_idx == 10:
break

prev_dataloader_state = dataloader.state_dict()
dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2, persistent_workers=True)
dataloader.load_state_dict(prev_dataloader_state)
assert dataloader.restore

# Verify remaining batches in the first epoch
count = 0
for _ in dataloader:
assert dataloader.current_epoch == 1, "Current epoch should be 1"
count += 1
assert count == 15, "There should be 15 batches remaining in the first epoch"
assert not dataloader.restore

# Verify batches in the second epoch
count = 0
for _ in dataloader:
assert dataloader.current_epoch == 2, "Current epoch should be 2"
count += 1
assert count >= 25, "There should be at least 25 batches in the second epoch"

# Verify that the dataloader can resume after completing the last epoch
dataloader.load_state_dict(dataloader.state_dict())
assert not dataloader.restore
count = 0
for _ in dataloader:
assert dataloader.current_epoch == 3, "Current epoch should be 3"
count += 1
assert count >= 25, "There should be at least 25 batches in the third epoch"


@pytest.mark.timeout(60)
def test_resume_dataloader_with_new_dataset(tmpdir):
dataset_1_path = tmpdir.join("dataset_1")
dataset_2_path = tmpdir.join("dataset_2")
for dataset in [dataset_1_path, dataset_2_path]:
cache = Cache(input_dir=str(dataset), chunk_bytes="64MB")
for i in range(50):
cache[i] = i
cache.done()
cache.merge()
dataset = StreamingDataset(str(dataset_1_path), shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
for _ in dataloader:
assert dataloader.current_epoch == 1, "Current epoch should be 1"

dataloader_state = dataloader.state_dict()
dataset = StreamingDataset(str(dataset_2_path), shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
dataloader.load_state_dict(dataloader_state)
for _ in dataloader:
assert dataloader.current_epoch == 2, "Current epoch should be 2"