Skip to content

Commit 0a90c04

Browse files
committed
update generator usage
1 parent 10ba260 commit 0a90c04

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

torchdata/stateful_dataloader/sampler.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .stateful import Stateful
1616

1717

18-
class StatefulRandomSamplerIterator(Iterator[list[int]], Stateful):
18+
class StatefulRandomSamplerIterator(Iterator[int], Stateful):
1919
_GENERATOR = "generator"
2020
_YIELDED = "yielded"
2121

@@ -25,7 +25,6 @@ def __init__(self, sampler):
2525
self.yielded = 0
2626
self.next_yielded = None
2727
self.n = len(sampler.data_source)
28-
self.generator = sampler.generator
2928
self.replacement = sampler.replacement
3029
self.num_samples = sampler.num_samples
3130
self.chunk_size = 32
@@ -46,7 +45,7 @@ def __next__(self):
4645
high=self.n,
4746
size=(self.chunk_size,),
4847
dtype=torch.int64,
49-
generator=self.generator,
48+
generator=self.sampler.generator,
5049
).tolist()
5150
self.perm_index = 0
5251
value = self.perm[self.perm_index]
@@ -62,7 +61,7 @@ def __next__(self):
6261
high=self.n,
6362
size=(remainder,),
6463
dtype=torch.int64,
65-
generator=self.generator,
64+
generator=self.sampler.generator,
6665
).tolist()
6766
self.perm_index = 0
6867
value = self.perm[self.perm_index]
@@ -78,7 +77,7 @@ def __next__(self):
7877
remainder = self.num_samples % self.n
7978
if self.chunk_index < num_full_perms:
8079
if self.perm is None or not self.perm:
81-
self.perm = torch.randperm(self.n, generator=self.generator).tolist()
80+
self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()
8281
self.perm_index = 0
8382
value = self.perm[self.perm_index]
8483
self.perm_index += 1
@@ -89,7 +88,7 @@ def __next__(self):
8988
return value
9089
elif remainder > 0:
9190
if self.perm is None or not self.perm:
92-
self.perm = torch.randperm(self.n, generator=self.generator).tolist()[:remainder]
91+
self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()[:remainder]
9392
self.perm_index = 0
9493
value = self.perm[self.perm_index]
9594
self.perm_index += 1
@@ -109,7 +108,7 @@ def state_dict(self) -> dict:
109108
def load_state_dict(self, state_dict: dict) -> None:
110109
self.next_yielded = state_dict[self._YIELDED]
111110
self.generator_state = state_dict[self._GENERATOR]
112-
self.generator.set_state(self.generator_state)
111+
self.sampler.generator.set_state(self.generator_state)
113112
if self.next_yielded is not None:
114113
for _ in range(self.next_yielded - self.yielded):
115114
next(self)

0 commit comments

Comments
 (0)