15
15
from .stateful import Stateful
16
16
17
17
18
- class StatefulRandomSamplerIterator (Iterator [list [ int ] ], Stateful ):
18
+ class StatefulRandomSamplerIterator (Iterator [int ], Stateful ):
19
19
_GENERATOR = "generator"
20
20
_YIELDED = "yielded"
21
21
@@ -25,7 +25,6 @@ def __init__(self, sampler):
25
25
self .yielded = 0
26
26
self .next_yielded = None
27
27
self .n = len (sampler .data_source )
28
- self .generator = sampler .generator
29
28
self .replacement = sampler .replacement
30
29
self .num_samples = sampler .num_samples
31
30
self .chunk_size = 32
@@ -46,7 +45,7 @@ def __next__(self):
46
45
high = self .n ,
47
46
size = (self .chunk_size ,),
48
47
dtype = torch .int64 ,
49
- generator = self .generator ,
48
+ generator = self .sampler . generator ,
50
49
).tolist ()
51
50
self .perm_index = 0
52
51
value = self .perm [self .perm_index ]
@@ -62,7 +61,7 @@ def __next__(self):
62
61
high = self .n ,
63
62
size = (remainder ,),
64
63
dtype = torch .int64 ,
65
- generator = self .generator ,
64
+ generator = self .sampler . generator ,
66
65
).tolist ()
67
66
self .perm_index = 0
68
67
value = self .perm [self .perm_index ]
@@ -78,7 +77,7 @@ def __next__(self):
78
77
remainder = self .num_samples % self .n
79
78
if self .chunk_index < num_full_perms :
80
79
if self .perm is None or not self .perm :
81
- self .perm = torch .randperm (self .n , generator = self .generator ).tolist ()
80
+ self .perm = torch .randperm (self .n , generator = self .sampler . generator ).tolist ()
82
81
self .perm_index = 0
83
82
value = self .perm [self .perm_index ]
84
83
self .perm_index += 1
@@ -89,7 +88,7 @@ def __next__(self):
89
88
return value
90
89
elif remainder > 0 :
91
90
if self .perm is None or not self .perm :
92
- self .perm = torch .randperm (self .n , generator = self .generator ).tolist ()[:remainder ]
91
+ self .perm = torch .randperm (self .n , generator = self .sampler . generator ).tolist ()[:remainder ]
93
92
self .perm_index = 0
94
93
value = self .perm [self .perm_index ]
95
94
self .perm_index += 1
@@ -109,7 +108,7 @@ def state_dict(self) -> dict:
109
108
def load_state_dict (self , state_dict : dict ) -> None :
110
109
self .next_yielded = state_dict [self ._YIELDED ]
111
110
self .generator_state = state_dict [self ._GENERATOR ]
112
- self .generator .set_state (self .generator_state )
111
+ self .sampler . generator .set_state (self .generator_state )
113
112
if self .next_yielded is not None :
114
113
for _ in range (self .next_yielded - self .yielded ):
115
114
next (self )
0 commit comments