Skip to content

Commit 6d49b4f

Browse files
committed
update _StatefulRandomSamplerIterator
update state dict if the iterator has finished add comment about why were updating state dict run precommit
1 parent d783247 commit 6d49b4f

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

torchdata/stateful_dataloader/sampler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def __next__(self) -> int:
3232

3333
self.yielded = self.next_yielded
3434
self.next_yielded = None
35-
3635
val = next(self.parent_iterator)
3736
self.yielded += 1
3837
return val
@@ -42,6 +41,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
4241
self.sampler.generator.set_state(state_dict[self._GENERATOR])
4342
self.next_yielded = state_dict[self._YIELDED]
4443

44+
def update_state_dict(self) -> None:
45+
self.generator_state = self.sampler.generator.get_state()
46+
4547
def state_dict(self) -> Dict[str, Any]:
4648
return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded}
4749

@@ -114,6 +116,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
114116
for _ in range(self.samples_yielded):
115117
next(self.sampler_iter)
116118

119+
def update_state_dict(self) -> None:
120+
if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"):
121+
self.sampler_iter.update_state_dict()
122+
117123

118124
class BatchSampler(torch.utils.data.sampler.BatchSampler):
119125
def __init__(self, sampler, batch_size, drop_last):

torchdata/stateful_dataloader/stateful_dataloader.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,11 @@ def __next__(self):
449449
try:
450450
return super().__next__()
451451
except StopIteration:
452+
# If we are at the end of the iteration, we want to update the state dict of _sampler_iter.
453+
# because in __iter__ after self._iterator is set using self._get_iterator() [which makes self.next_iter_state = None],
454+
# it is checked if self._iterator._finished is True, and if it is, self._iterator is reset with next_iter_state = None.
455+
if hasattr(self._sampler_iter, "update_state_dict"):
456+
self._sampler_iter.update_state_dict()
452457
self._finished = True
453458
raise
454459

@@ -531,11 +536,7 @@ def load_state_dict(self, state_dict):
531536
self._sampler_iter = iter(self._index_sampler)
532537
if state_dict[_SAMPLER_ITER_STATE] is not None:
533538
self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE])
534-
if state_dict[_ITERATOR_FINISHED]:
535-
try:
536-
next(self._sampler_iter)
537-
except StopIteration:
538-
pass
539+
539540
else:
540541
if not isinstance(
541542
self._index_sampler,

0 commit comments

Comments
 (0)