@@ -24,7 +24,7 @@ def __getitem__(self,index):
24
24
assert index < len (self ),"index out of bounds in split_datset"
25
25
return self .parent_dataset [index + self .split_start ]
26
26
27
- def get_cifar10_dataloaders (args , validation_split = 0.0 ):
27
+ def get_cifar10_dataloaders (args , validation_split = 0.0 , max_threads = 10 ):
28
28
"""Creates augmented train, validation, and test data loaders."""
29
29
30
30
normalize = transforms .Normalize ((0.4914 , 0.4822 , 0.4465 ),
@@ -50,6 +50,15 @@ def get_cifar10_dataloaders(args, validation_split=0.0):
50
50
test_dataset = datasets .CIFAR10 ('_dataset' , False , test_transform , download = False )
51
51
52
52
53
+ # we need at least two threads
54
+ max_threads = 2 if max_threads < 2 else max_threads
55
+ if max_threads >= 6 :
56
+ val_threads = 2
57
+ train_threads = max_threads - val_threads
58
+ else :
59
+ val_threads = 1
60
+ train_threads = max_threads - 1
61
+
53
62
54
63
valid_loader = None
55
64
if validation_split > 0.0 :
@@ -59,12 +68,12 @@ def get_cifar10_dataloaders(args, validation_split=0.0):
59
68
train_loader = torch .utils .data .DataLoader (
60
69
train_dataset ,
61
70
args .batch_size ,
62
- num_workers = 8 ,
71
+ num_workers = train_threads ,
63
72
pin_memory = True , shuffle = True )
64
73
valid_loader = torch .utils .data .DataLoader (
65
74
val_dataset ,
66
75
args .test_batch_size ,
67
- num_workers = 2 ,
76
+ num_workers = val_threads ,
68
77
pin_memory = True )
69
78
else :
70
79
train_loader = torch .utils .data .DataLoader (
0 commit comments