Skip to content

Commit 1346e44

Browse files
committed
add comment about why were updating state dict
1 parent 8a83c5f commit 1346e44

File tree

2 files changed

+170
-58
lines changed

2 files changed

+170
-58
lines changed

torchdata/stateful_dataloader/sampler.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,18 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
109109
assert isinstance(self.sampler_iter, Stateful)
110110
self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE])
111111

112-
if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
113-
self.sampler, _InfiniteConstantSampler
114-
):
112+
if not (
113+
isinstance(self.sampler, Stateful)
114+
or isinstance(self.sampler_iter, Stateful)
115+
) and not isinstance(self.sampler, _InfiniteConstantSampler):
115116
# We skip x samples if underlying sampler is not stateful
116117
for _ in range(self.samples_yielded):
117118
next(self.sampler_iter)
118119

119-
def update_state_dict(self):
120-
if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"):
120+
def update_state_dict(self) -> None:
121+
if isinstance(self.sampler_iter, Stateful) and hasattr(
122+
self.sampler_iter, "update_state_dict"
123+
):
121124
self.sampler_iter.update_state_dict()
122125

123126

0 commit comments

Comments
 (0)