@@ -89,7 +89,9 @@ def __init__(self, *args, **kwargs):
 
     def __iter__(self):
         if self.generator is None:
-            self.generator = torch.Generator(device=torch.get_default_device())
+            self.generator = torch.Generator(
+                device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
+            )
         self.generator.manual_seed(self.initial_seed)
 
         # Allow `self.epoch` to modify the seed of the generator
@@ -1156,13 +1158,19 @@ def prepare_data_loader(
             data_source=sampler.data_source,
             replacement=sampler.replacement,
             num_samples=sampler._num_samples,
-            generator=getattr(sampler, "generator", torch.Generator(device=torch.get_default_device())),
+            generator=getattr(
+                sampler,
+                "generator",
+                torch.Generator(device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"),
+            ),
             data_seed=data_seed,
         )
 
     if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
         # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
-        generator = torch.Generator(device=torch.get_default_device()).manual_seed(42)
+        generator = torch.Generator(
+            device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
+        ).manual_seed(42)
         dataloader.generator = generator
         dataloader.sampler.generator = generator
     # No change if no multiprocess
@@ -1181,7 +1189,9 @@ def prepare_data_loader(
         else:
             if not use_seedable_sampler and hasattr(sampler, "generator"):
                 if sampler.generator is None:
-                    sampler.generator = torch.Generator(device=torch.get_default_device())
+                    sampler.generator = torch.Generator(
+                        device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
+                    )
                 synchronized_generator = sampler.generator
             batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
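
For reference, a minimal standalone sketch of the compatibility pattern these hunks apply: guard `torch.get_default_device()` with `hasattr`, since it is only available on newer PyTorch releases, and fall back to `"cpu"` otherwise. The helper name `_default_generator_device` below is illustrative and not part of the patch.

```python
import torch


def _default_generator_device():
    # Hypothetical helper mirroring the fallback used in the diff above:
    # newer PyTorch exposes torch.get_default_device(); older releases do not,
    # so fall back to "cpu" instead of raising AttributeError.
    return torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"


# Same construction pattern as the patched lines.
generator = torch.Generator(device=_default_generator_device())
generator.manual_seed(42)
```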