# fast_cifar_10_distributed.py (forked from apple/ml-cifar-10-faster)

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2019 Apple Inc. All Rights Reserved.
#
from utils import *  # repo helpers: cifar10, Batches, Timer, build_network, Network, SGDOpt, lr_schedule, ...
from functools import partial
import argparse

# Explicit imports for names used below (utils may already re-export some of these)
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch import distributed as dist
from torch.jit import script
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel as DDP


@script
def _mish_jit_fwd(x):
    return x.mul(torch.tanh(F.softplus(x)))


@script
def _mish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


class MishJitAutoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]  # saved_tensors replaces the deprecated saved_variables
        return _mish_jit_bwd(x, grad_output)


def mish(x):
    return MishJitAutoFn.apply(x)


class MishJit(Module):
    def forward(self, x):
        return MishJitAutoFn.apply(x)
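

# Note: the custom autograd Function above saves only the input and recomputes
# sigmoid/softplus in the backward pass instead of storing intermediate
# activations, trading a little extra compute for lower memory use. MishJit is
# used further down as a drop-in replacement for nn.ReLU via the `types`
# mapping passed to build_network.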

# We use pytorch's distributed package with NCCL for inter-gpu communication
# Define the process group
dist.init_process_group(
    backend='nccl',
    init_method='env://'
)

# Set variables for the local worker to determine its rank and the world size
rank = dist.get_rank()
is_rank0 = rank == 0
world_size = dist.get_world_size()

# Assign a device for this worker.
device = torch.device(
    "cuda:{}".format(rank) if torch.cuda.is_available() else "cpu"
)
torch.random.manual_seed(rank)
torch.backends.cudnn.benchmark = True
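
# Each process drives one GPU: the process with rank r uses cuda:r, and the
# per-rank manual seed makes the workers draw different random augmentations.
# With init_method='env://', RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are
# read from the environment (a launcher such as torchrun sets them).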


def run_benchmark(lr_scaler=1.0,
                  lr_end_fraction=0.1,
                  epochs=16,
                  batch_size=512,
                  ema_epochs=2,
                  n_runs=1,
                  warmup_fraction=5):
    # Wait for the GPUs to be initialized
    torch.cuda.synchronize()

    # Download the dataset
    dataset = cifar10(root='./data/')  # downloads dataset

    # Start timing all processes together
    dist.barrier()
    timer = Timer(synch=torch.cuda.synchronize)

    # Copy the dataset to the GPUs
    dataset = map_nested(to(device), dataset)
    dist.barrier()
    timer()
    data_transfer_time = timer.total_time
    if rank == 0:
        print(f"Uploaded data to GPUs in {data_transfer_time:.3f}s")

    # Select a shard of the training dataset for this worker, and all of the validation dataset
    selector = list(range(rank, len(dataset['train']['data']), world_size))
    dataset = {'train': {'data': dataset['train']['data'][selector],
                         'targets': dataset['train']['targets'][selector]},
               'valid': dataset['valid']}
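
    # Each worker keeps a strided shard of the training set (indices rank,
    # rank + world_size, rank + 2 * world_size, ...), so no two workers see the
    # same examples and the effective global batch size is world_size * batch_size.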

    # Upload the mean and standard deviations to the GPU
    mean, std = [torch.tensor(x, device=device, dtype=torch.float16) for x in (CIFAR10_MEAN, CIFAR10_STD)]
    train_set = preprocess(dataset['train'], [partial(pad, border=4), transpose,
                                              partial(normalise, mean=mean, std=std), to(torch.float16)])
    valid_set = preprocess(dataset['valid'], [transpose,
                                              partial(normalise, mean=mean, std=std), to(torch.float16)])
    train_batches = partial(
        Batches,
        dataset=train_set,
        shuffle=True,
        drop_last=True,
        max_options=200,
        device=device
    )
    valid_batches = partial(
        Batches,
        dataset=valid_set,
        shuffle=False,
        drop_last=False,
        device=device
    )

    # Data pre-processing
    dist.barrier()
    timer()
    eigen_values, eigen_vectors = compute_patch_whitening_statistics(train_set)
    timer(update_total=False)  # We do not count the data pre-processing time
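
    # The eigen decomposition of the covariance of small input patches is computed
    # once, outside the timed region; whitening_block (see utils) uses it below to
    # initialise a patch-whitening convolution in the network's prep block.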

    # Run the training process n_runs times
    logs = []
    for run in range(n_runs):
        # Network construction
        # Architecture
        channels = {'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 512}
        input_whitening_net = build_network(
            channels=channels, extra_layers=(), res_layers=('layer1', 'layer3'),
            conv_pool_block=conv_pool_block_pre,
            prep_block=partial(whitening_block,
                               eigen_values=eigen_values,
                               eigen_vectors=eigen_vectors),
            scale=1 / 16,
            types={
                # nn.ReLU: partial(nn.CELU, 0.3),
                nn.ReLU: MishJit,
                BatchNorm: partial(GhostBatchNorm, num_splits=16, weight=False)
            }
        )
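
        # The `types` mapping swaps modules when the network is built: nn.ReLU
        # activations become MishJit, and BatchNorm layers become GhostBatchNorm,
        # which computes its statistics over 16 smaller "ghost" splits of each batch.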

        # Model to evaluate after the distributed model is trained
        local_eval_model = Network(input_whitening_net, label_smoothing_loss(0.2)).half().to(device)
        # Distributed model to be trained by all workers
        distributed_model = Network(input_whitening_net, label_smoothing_loss(0.2)).half().to(device)
        is_bias = group_by_key(('bias' in k, v) for k, v in trainable_params(distributed_model).items())
        loss = distributed_model.loss

        # Make sure all workers start timing here
        dist.barrier()
        timer = Timer(torch.cuda.synchronize)

        # Wrap with DistributedDataParallel; this registers hooks that all-reduce
        # the gradients across workers during back-propagation
        distributed_model = DDP(distributed_model, device_ids=[rank])
        if is_rank0:
            # Save the model on rank 0 to initialize all the others
            with open('initialized.model', 'wb') as f:
                torch.save(distributed_model.state_dict(), f)
        dist.barrier()
        with open('initialized.model', 'rb') as f:
            distributed_model.load_state_dict(torch.load(f))
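
        # Rank 0 writes its freshly initialised weights to disk and every worker
        # loads them, so all replicas start from identical parameters before the
        # first all-reduce. (DDP also broadcasts parameters from rank 0 at
        # construction; the explicit save/load keeps the initial state on disk.)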

        # Data iterators
        transforms = (Crop(32, 32), FlipLR())
        tbatches = train_batches(batch_size, transforms)
        train_batch_count = len(tbatches)
        vbatches = valid_batches(batch_size)

        # Construct the learning rate, weight decay and momentum schedules.
        opt_params = {
            'lr': lr_schedule(
                [0, epochs / warmup_fraction, epochs - ema_epochs],
                [0.0, lr_scaler * 1.0, lr_scaler * lr_end_fraction],
                batch_size, train_batch_count
            ),
            'weight_decay': Const(5e-4 * lr_scaler * batch_size),
            'momentum': Const(0.9)
        }
        opt_params_bias = {
            'lr': lr_schedule(
                [0, epochs / warmup_fraction, epochs - ema_epochs],
                [0.0, lr_scaler * 1.0 * 64, lr_scaler * lr_end_fraction * 64],
                batch_size, train_batch_count
            ),
            'weight_decay': Const(5e-4 * lr_scaler * batch_size / 64),
            'momentum': Const(0.9)
        }
        opt = SGDOpt(
            weight_param_schedule=opt_params,
            bias_param_schedule=opt_params_bias,
            weight_params=is_bias[False],
            bias_params=is_bias[True]
        )
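
        # lr_schedule (from utils) interpolates between the knots above: the learning
        # rate ramps from 0 to its peak at epochs / warmup_fraction, then decays to
        # lr_end_fraction of the peak by epochs - ema_epochs. Bias parameters get a
        # 64x larger learning rate and a 64x smaller weight decay than the weights.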

        # Train the network
        distributed_model.train(True)
        epochs_log = []
        for epoch in range(epochs):
            activations_log = []
            for tb in tbatches:
                # Forward step
                out = loss(distributed_model(tb))
                distributed_model.zero_grad()
                out['loss'].sum().backward()
                opt.step()
                # Log activations
                activations_log.append(('loss', out['loss'].detach()))
                activations_log.append(('acc', out['acc'].detach()))
            # Compute the average over the activation logs for this epoch
            res = map_values((lambda xs: to_numpy(torch.cat(xs)).astype(np.float64)), group_by_key(activations_log))
            train_summary = mean_members(res)
            timer()

            # Evaluate the model
            # Copy the weights to the local model, stripping the 'module.' prefix that DDP adds
            model_dict = {k[7:]: v for k, v in distributed_model.state_dict().items()}
            local_eval_model.load_state_dict(model_dict)
            valid_summary = eval_on_batches(local_eval_model, loss, vbatches)
            timer(update_total=False)
            time_to_epoch_end = timer.total_time + data_transfer_time
            epochs_log.append(
                {
                    'valid': valid_summary,
                    'train': train_summary,
                    'time': time_to_epoch_end
                }
            )

        # Wait until all workers have finished training
        dist.barrier()
        timer()

        # Print output
        if is_rank0:
            print("Train acc {:.3f} loss {:.3f}, validation acc {:.3f} loss {:.3f} wall time {:3.3f}s".format(
                train_summary['acc'], train_summary['loss'],
                valid_summary['acc'], valid_summary['loss'],
                timer.total_time + data_transfer_time
            ))
            if run == 0:
                save_log_to_tsv(epochs_log, path='timing_log.tsv')
                # Save the model
                torch.save(local_eval_model.state_dict(), 'replica_0_model')
        logs.append(
            {
                'train_acc': train_summary['acc'],
                'train_loss': train_summary['loss'],
                'valid_acc': valid_summary['acc'],
                'valid_loss': valid_summary['loss'],
                'time': timer.total_time
            }
        )

        dist.barrier()

    # Summarize the accuracies and training times across runs
    times = [d['time'] for d in logs]
    accuracies = [d['valid_acc'] for d in logs]
    if is_rank0:
        print("Maximum training time {} median {}".format(np.max(times), np.median(times)))
        print("Lowest accuracy {} median {}".format(np.min(accuracies), np.median(accuracies)))
        print("{} runs reached 0.94 out of {}".format(
            np.count_nonzero(np.array(accuracies) >= 0.94),
            n_runs
        ))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr_scaler', type=float, default=1.5,
                        help='Multiplicative scaling factor for the learning rate schedule')
    parser.add_argument('--lr_scaler_end_fraction', type=float, default=0.1,
                        help='Fraction of the peak learning rate used for the final step')
    parser.add_argument('--epochs', type=int, default=18,
                        help='Total number of training epochs')
    parser.add_argument('--warmup_fraction', type=float, default=5,
                        help='The peak learning rate is reached after epochs / warmup_fraction epochs')
    parser.add_argument('--ema_ep', type=float, default=2,
                        help='Number of epochs (at the end of training) '
                             'where the learning rate is held constant')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='Per-GPU batch size')
    parser.add_argument('--runs', type=int, default=1,
                        help='Number of benchmark runs')
    args = parser.parse_args()
    run_benchmark(
        lr_scaler=args.lr_scaler,
        lr_end_fraction=args.lr_scaler_end_fraction,
        epochs=args.epochs,
        ema_epochs=args.ema_ep,
        n_runs=args.runs,
        batch_size=args.batch_size,
        warmup_fraction=args.warmup_fraction
    )
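
# Example launch (a sketch; exact flags depend on your setup). On a single node
# with 8 GPUs, torchrun starts one process per GPU and sets the RANK, WORLD_SIZE,
# MASTER_ADDR and MASTER_PORT variables consumed by init_method='env://':
#
#   torchrun --nproc_per_node=8 fast_cifar_10_distributed.py --batch_size 256
#
# The global batch size is the per-GPU --batch_size times the number of processes.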