Skip to content

Commit 5b4a93e

Browse files
committed
Merge branch 'master' into fully_sparse_dynamic
2 parents 6b40284 + dc5677e commit 5b4a93e

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

mnist_cifar/main.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def main():
155155
parser.add_argument('--iters', type=int, default=1, help='How many times the model should be run after each other. Default=1')
156156
parser.add_argument('--save-features', action='store_true', help='Resumes a saved model and saves its feature data to disk for plotting.')
157157
parser.add_argument('--bench', action='store_true', help='Enables the benchmarking of layers and estimates sparse speedups')
158+
parser.add_argument('--max-threads', type=int, default=10, help='How many threads to use for data loading.')
158159
sparselearning.core.add_sparse_args(parser)
159160

160161
args = parser.parse_args()
@@ -180,7 +181,7 @@ def main():
180181
if args.data == 'mnist':
181182
train_loader, valid_loader, test_loader = get_mnist_dataloaders(args, validation_split=args.valid_split)
182183
else:
183-
train_loader, valid_loader, test_loader = get_cifar10_dataloaders(args, args.valid_split)
184+
train_loader, valid_loader, test_loader = get_cifar10_dataloaders(args, args.valid_split, max_threads=args.max_threads)
184185

185186
if args.model not in models:
186187
print('You need to select an existing model via the --model argument. Available models include: ')

sparselearning/utils.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __getitem__(self,index):
2424
assert index < len(self),"index out of bounds in split_datset"
2525
return self.parent_dataset[index + self.split_start]
2626

27-
def get_cifar10_dataloaders(args, validation_split=0.0):
27+
def get_cifar10_dataloaders(args, validation_split=0.0, max_threads=10):
2828
"""Creates augmented train, validation, and test data loaders."""
2929

3030
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
@@ -50,6 +50,15 @@ def get_cifar10_dataloaders(args, validation_split=0.0):
5050
test_dataset = datasets.CIFAR10('_dataset', False, test_transform, download=False)
5151

5252

53+
# we need at least two threads
54+
max_threads = 2 if max_threads < 2 else max_threads
55+
if max_threads >= 6:
56+
val_threads = 2
57+
train_threads = max_threads - val_threads
58+
else:
59+
val_threads = 1
60+
train_threads = max_threads - 1
61+
5362

5463
valid_loader = None
5564
if validation_split > 0.0:
@@ -59,12 +68,12 @@ def get_cifar10_dataloaders(args, validation_split=0.0):
5968
train_loader = torch.utils.data.DataLoader(
6069
train_dataset,
6170
args.batch_size,
62-
num_workers=8,
71+
num_workers=train_threads,
6372
pin_memory=True, shuffle=True)
6473
valid_loader = torch.utils.data.DataLoader(
6574
val_dataset,
6675
args.test_batch_size,
67-
num_workers=2,
76+
num_workers=val_threads,
6877
pin_memory=True)
6978
else:
7079
train_loader = torch.utils.data.DataLoader(

0 commit comments

Comments
 (0)