forked from GuYuc/WS-DAN.PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_wsdan.py
378 lines (308 loc) · 14.1 KB
/
train_wsdan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
"""TRAINING
Created: May 04,2019 - Yuchong Gu
Revised: May 07,2019 - Yuchong Gu
"""
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
import time
import logging
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from optparse import OptionParser
from utils import accuracy
from models import *
from dataset import *
def main():
    """Parse command-line options, build WS-DAN and alternate training/validation per epoch.

    When ``--ckpt`` is given, network weights (and the feature centers, if the
    checkpoint stores them) are loaded from that file. With ``--init 0`` training
    also resumes from the epoch encoded in the checkpoint file name, and the
    learning-rate schedule is fast-forwarded to that epoch.
    """
    parser = OptionParser()
    parser.add_option('-j', '--workers', dest='workers', default=16, type='int',
                      help='number of data loading workers (default: 16)')
    parser.add_option('-e', '--epochs', dest='epochs', default=80, type='int',
                      help='number of epochs (default: 80)')
    parser.add_option('-b', '--batch-size', dest='batch_size', default=16, type='int',
                      help='batch size (default: 16)')
    parser.add_option('-c', '--ckpt', dest='ckpt', default=False,
                      help='load checkpoint model (default: False)')
    parser.add_option('-v', '--verbose', dest='verbose', default=100, type='int',
                      help='show information for each <verbose> iterations (default: 100)')
    parser.add_option('--lr', '--learning-rate', dest='lr', default=1e-3, type='float',
                      help='learning rate (default: 1e-3)')
    parser.add_option('--sf', '--save-freq', dest='save_freq', default=1, type='int',
                      help='saving frequency of .ckpt models (default: 1)')
    parser.add_option('--sd', '--save-dir', dest='save_dir', default='./models',
                      help='saving directory of .ckpt models (default: ./models)')
    parser.add_option('--init', '--initial-training', dest='initial_training', default=1, type='int',
                      help='train from 1-beginning or 0-resume training (default: 1)')
    options, _ = parser.parse_args()

    logging.basicConfig(format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s', level=logging.INFO)
    warnings.filterwarnings("ignore")

    ##################################
    # Initialize model
    ##################################
    image_size = (512, 512)
    num_classes = 1000
    num_attentions = 32
    start_epoch = 0

    feature_net = inception_v3(pretrained=True)
    net = WSDAN(num_classes=num_classes, M=num_attentions, net=feature_net)

    # feature_center: size of (#classes, #attention_maps, #channel_features)
    feature_center = torch.zeros(num_classes, num_attentions, net.num_features * net.expansion).to(torch.device("cuda"))

    if options.ckpt:
        ckpt = options.ckpt

        if options.initial_training == 0:
            # train() saves checkpoints as '<NNN>.ckpt' with NNN = number of completed
            # epochs, which is also the 0-based index of the next epoch to run.
            # os.path.basename (rather than split('/')) also copes with the
            # platform's own path separator.
            epoch_name = os.path.basename(ckpt).split('.')[0]
            start_epoch = int(epoch_name)

        # Load ckpt and get state_dict
        checkpoint = torch.load(ckpt)
        state_dict = checkpoint['state_dict']

        # Load weights
        net.load_state_dict(state_dict)
        logging.info('Network loaded from {}'.format(options.ckpt))

        # load feature center
        if 'feature_center' in checkpoint:
            feature_center = checkpoint['feature_center'].to(torch.device("cuda"))
            logging.info('feature_center loaded from {}'.format(options.ckpt))

    ##################################
    # Initialize saving directory
    ##################################
    save_dir = options.save_dir
    os.makedirs(save_dir, exist_ok=True)

    ##################################
    # Use cuda
    ##################################
    cudnn.benchmark = True
    net.to(torch.device("cuda"))
    net = nn.DataParallel(net)

    ##################################
    # Load dataset
    ##################################
    train_dataset = CustomDataset(phase='train', shape=image_size)
    validate_dataset = CustomDataset(phase='val', shape=image_size)

    train_loader = DataLoader(train_dataset, batch_size=options.batch_size, shuffle=True,
                              num_workers=options.workers, pin_memory=True)
    validate_loader = DataLoader(validate_dataset, batch_size=options.batch_size * 4, shuffle=False,
                                 num_workers=options.workers, pin_memory=True)

    optimizer = torch.optim.SGD(net.parameters(), lr=options.lr, momentum=0.9, weight_decay=0.00001)
    loss = nn.CrossEntropyLoss()

    ##################################
    # Learning rate scheduling
    ##################################
    # Multiply the learning rate by 0.9 every 2 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)
    # When resuming, advance the schedule to start_epoch so the learning rate
    # matches an uninterrupted run instead of restarting from options.lr.
    # (No-op for fresh training, where start_epoch == 0.)
    for _ in range(start_epoch):
        scheduler.step()

    ##################################
    # TRAINING
    ##################################
    logging.info('')
    logging.info('Start training: Total epochs: {}, Batch size: {}, Training size: {}, Validation size: {}'.
                 format(options.epochs, options.batch_size, len(train_dataset), len(validate_dataset)))

    for epoch in range(start_epoch, options.epochs):
        train(epoch=epoch,
              data_loader=train_loader,
              net=net,
              feature_center=feature_center,
              loss=loss,
              optimizer=optimizer,
              save_freq=options.save_freq,
              save_dir=options.save_dir,
              verbose=options.verbose)
        # validate() returns the mean validation loss, which a plateau-based
        # scheduler (e.g. ReduceLROnPlateau) could consume instead of StepLR.
        validate(data_loader=validate_loader,
                 net=net,
                 loss=loss,
                 verbose=options.verbose)
        scheduler.step()
def train(**kwargs):
    """Run one WS-DAN training epoch over ``data_loader``.

    Each batch drives three successive optimizer steps:
      1. raw images: classification loss plus an MSE term pulling the returned
         feature matrix towards ``feature_center[y]`` (the centers are then moved
         towards the batch features with rate ``beta``);
      2. attention-cropped images: the bounding box of attention > ``theta_c`` is
         cut out and resized to ``crop_size``;
      3. attention-dropped images: pixels whose attention exceeds ``theta_d`` are
         zeroed out.
    A checkpoint ``<epoch + 1>.ckpt`` is written to ``save_dir`` whenever
    ``epoch % save_freq == 0``.

    Keyword Args:
        epoch (int): 0-based epoch index (logged and saved as ``epoch + 1``).
        data_loader: iterable yielding ``(X, y)`` batches.
        net: ``nn.DataParallel``-wrapped model returning
            ``(predictions, feature_matrix, attention_map)``.
        feature_center (torch.Tensor): per-class centers indexed by label; updated in place.
        loss: classification criterion applied to ``(predictions, y)``.
        optimizer: optimizer over the parameters of ``net``.
        save_freq (int): checkpoint when ``epoch % save_freq == 0``.
        save_dir (str): directory receiving the ``.ckpt`` files.
        verbose (int): log running averages every ``verbose`` batches.
    """
    # Retrieve training configuration
    data_loader = kwargs['data_loader']
    net = kwargs['net']
    loss = kwargs['loss']
    optimizer = kwargs['optimizer']
    feature_center = kwargs['feature_center']
    epoch = kwargs['epoch']
    save_freq = kwargs['save_freq']
    save_dir = kwargs['save_dir']
    verbose = kwargs['verbose']

    # Attention Regularization: LA Loss
    l2_loss = nn.MSELoss()

    # Default Parameters
    beta = 1e-4  # feature-center update rate
    theta_c = 0.5  # attention threshold defining the crop box
    theta_d = 0.5  # attention threshold above which pixels are dropped
    crop_size = (256, 256)  # size of cropped images for 'See Better'

    # metrics initialization
    batches = 0
    epoch_loss = np.array([0, 0, 0], dtype='float')  # Loss on Raw/Crop/Drop Images
    epoch_acc = np.array([[0, 0, 0],
                          [0, 0, 0],
                          [0, 0, 0]], dtype='float')  # Top-1/3/5 Accuracy for Raw/Crop/Drop Images

    # begin training
    start_time = time.time()
    logging.info('Epoch %03d, Learning Rate %g' % (epoch + 1, optimizer.param_groups[0]['lr']))
    net.train()
    for i, (X, y) in enumerate(data_loader):
        batch_start = time.time()

        # obtain data for training
        X = X.to(torch.device("cuda"))
        y = y.to(torch.device("cuda"))
        img_h, img_w = X.size(2), X.size(3)

        ##################################
        # Raw Image
        ##################################
        y_pred, feature_matrix, attention_map = net(X)

        # loss
        batch_loss = loss(y_pred, y) + l2_loss(feature_matrix, feature_center[y])
        epoch_loss[0] += batch_loss.item()

        # backward
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Update Feature Center
        # NOTE(review): this is an indexed (index_put_-style) assignment without
        # accumulation; if a label occurs several times in ``y`` only one of the
        # per-sample updates for that class survives. Confirm this is intended.
        feature_center[y] += beta * (feature_matrix.detach() - feature_center[y])

        # metrics: top-1, top-3, top-5 error
        with torch.no_grad():
            epoch_acc[0] += accuracy(y_pred, y, topk=(1, 3, 5))

        ##################################
        # Attention Cropping
        ##################################
        with torch.no_grad():
            # interpolate(mode='bilinear', align_corners=True) is the documented,
            # non-deprecated equivalent of F.upsample_bilinear.
            crop_mask = F.interpolate(attention_map, size=(img_h, img_w),
                                      mode='bilinear', align_corners=True) > theta_c
            crop_images = []
            for batch_index in range(crop_mask.size(0)):
                nonzero_indices = torch.nonzero(crop_mask[batch_index, 0, ...])
                if nonzero_indices.size(0) == 0:
                    # No attention value exceeds theta_c: use the whole image
                    # instead of failing on min()/max() of an empty tensor.
                    height_min, height_max, width_min, width_max = 0, img_h, 0, img_w
                else:
                    # nonzero() gives inclusive indices; +1 makes the max an
                    # exclusive slice end so the last attended row/column is kept
                    # and a one-row/column mask does not produce an empty crop.
                    height_min = nonzero_indices[:, 0].min().item()
                    height_max = nonzero_indices[:, 0].max().item() + 1
                    width_min = nonzero_indices[:, 1].min().item()
                    width_max = nonzero_indices[:, 1].max().item() + 1
                crop_images.append(F.interpolate(
                    X[batch_index:batch_index + 1, :, height_min:height_max, width_min:width_max],
                    size=crop_size, mode='bilinear', align_corners=True))
            crop_images = torch.cat(crop_images, dim=0)

        # crop images forward
        y_pred, _, _ = net(crop_images)

        # loss
        batch_loss = loss(y_pred, y)
        epoch_loss[1] += batch_loss.item()

        # backward
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # metrics: top-1, top-3, top-5 error
        with torch.no_grad():
            epoch_acc[1] += accuracy(y_pred, y, topk=(1, 3, 5))

        ##################################
        # Attention Dropping
        ##################################
        with torch.no_grad():
            # keep only pixels whose upsampled attention is <= theta_d
            drop_mask = F.interpolate(attention_map, size=(img_h, img_w),
                                      mode='bilinear', align_corners=True) <= theta_d
            drop_images = X * drop_mask.float()

        # drop images forward
        y_pred, _, _ = net(drop_images)

        # loss
        batch_loss = loss(y_pred, y)
        epoch_loss[2] += batch_loss.item()

        # backward
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # metrics: top-1, top-3, top-5 error
        with torch.no_grad():
            epoch_acc[2] += accuracy(y_pred, y, topk=(1, 3, 5))

        # end of this batch
        batches += 1
        batch_end = time.time()
        if (i + 1) % verbose == 0:
            logging.info('\tBatch %d: (Raw) Loss %.4f, Accuracy: (%.2f, %.2f, %.2f), (Crop) Loss %.4f, Accuracy: (%.2f, %.2f, %.2f), (Drop) Loss %.4f, Accuracy: (%.2f, %.2f, %.2f), Time %3.2f' %
                         (i + 1,
                          epoch_loss[0] / batches, epoch_acc[0, 0] / batches, epoch_acc[0, 1] / batches, epoch_acc[0, 2] / batches,
                          epoch_loss[1] / batches, epoch_acc[1, 0] / batches, epoch_acc[1, 1] / batches, epoch_acc[1, 2] / batches,
                          epoch_loss[2] / batches, epoch_acc[2, 0] / batches, epoch_acc[2, 1] / batches, epoch_acc[2, 2] / batches,
                          batch_end - batch_start))

    # save checkpoint model
    if epoch % save_freq == 0:
        state_dict = net.module.state_dict()
        # Move tensors to CPU in place (keeping the state_dict object itself)
        # so the checkpoint is not tied to a GPU device.
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cpu()

        torch.save({
            'epoch': epoch,
            'save_dir': save_dir,
            'state_dict': state_dict,
            'feature_center': feature_center.cpu()},
            os.path.join(save_dir, '%03d.ckpt' % (epoch + 1)))

    # end of this epoch
    end_time = time.time()

    # metrics for average (max() guards against an empty data loader)
    epoch_loss /= max(batches, 1)
    epoch_acc /= max(batches, 1)

    # show information for this epoch
    logging.info('Train: (Raw) Loss %.4f, Accuracy: (%.2f, %.2f, %.2f), (Crop) Loss %.4f, Accuracy: (%.2f, %.2f, %.2f), (Drop) Loss %.4f, Accuracy: (%.2f, %.2f, %.2f), Time %3.2f'%
                 (epoch_loss[0], epoch_acc[0, 0], epoch_acc[0, 1], epoch_acc[0, 2],
                  epoch_loss[1], epoch_acc[1, 0], epoch_acc[1, 1], epoch_acc[1, 2],
                  epoch_loss[2], epoch_acc[2, 0], epoch_acc[2, 1], epoch_acc[2, 2],
                  end_time - start_time))
def validate(**kwargs):
    """Evaluate ``net`` on ``data_loader`` and return the mean validation loss.

    For each batch, the prediction on the raw image is averaged with the
    prediction on its attention crop. The crop is the bounding box of the
    upsampled attention map above ``theta_c``, resized to ``crop_size``. Loss and
    top-1/3/5 accuracy are computed on this averaged prediction.

    Keyword Args:
        data_loader: iterable yielding ``(X, y)`` batches.
        net: model returning ``(predictions, feature_matrix, attention_map)``.
        loss: criterion applied to ``(predictions, y)``.
        verbose (int): log running averages every ``verbose`` batches.

    Returns:
        float: batch-averaged validation loss (0.0 for an empty loader).
    """
    # Retrieve training configuration
    data_loader = kwargs['data_loader']
    net = kwargs['net']
    loss = kwargs['loss']
    verbose = kwargs['verbose']

    # Default Parameters
    theta_c = 0.5  # attention threshold defining the crop box
    crop_size = (256, 256)  # size of cropped images for 'See Better'

    # metrics initialization
    batches = 0
    epoch_loss = 0
    epoch_acc = np.array([0, 0, 0], dtype='float')  # top - 1, 3, 5

    # begin validation
    start_time = time.time()
    net.eval()
    with torch.no_grad():
        for i, (X, y) in enumerate(data_loader):
            batch_start = time.time()

            # obtain data
            X = X.to(torch.device("cuda"))
            y = y.to(torch.device("cuda"))
            img_h, img_w = X.size(2), X.size(3)

            ##################################
            # Raw Image
            ##################################
            y_pred_raw, _, attention_map = net(X)

            ##################################
            # Object Localization and Refinement
            ##################################
            # interpolate(mode='bilinear', align_corners=True) is the documented,
            # non-deprecated equivalent of F.upsample_bilinear.
            crop_mask = F.interpolate(attention_map, size=(img_h, img_w),
                                      mode='bilinear', align_corners=True) > theta_c
            crop_images = []
            for batch_index in range(crop_mask.size(0)):
                nonzero_indices = torch.nonzero(crop_mask[batch_index, 0, ...])
                if nonzero_indices.size(0) == 0:
                    # No attention value exceeds theta_c: use the whole image
                    # instead of failing on min()/max() of an empty tensor.
                    height_min, height_max, width_min, width_max = 0, img_h, 0, img_w
                else:
                    # nonzero() gives inclusive indices; +1 makes the max an
                    # exclusive slice end so the last attended row/column is kept.
                    height_min = nonzero_indices[:, 0].min().item()
                    height_max = nonzero_indices[:, 0].max().item() + 1
                    width_min = nonzero_indices[:, 1].min().item()
                    width_max = nonzero_indices[:, 1].max().item() + 1
                crop_images.append(F.interpolate(
                    X[batch_index:batch_index + 1, :, height_min:height_max, width_min:width_max],
                    size=crop_size, mode='bilinear', align_corners=True))
            crop_images = torch.cat(crop_images, dim=0)

            y_pred_crop, _, _ = net(crop_images)

            # final prediction
            y_pred = (y_pred_raw + y_pred_crop) / 2

            # loss
            batch_loss = loss(y_pred, y)
            epoch_loss += batch_loss.item()

            # metrics: top-1, top-3, top-5 error
            epoch_acc += accuracy(y_pred, y, topk=(1, 3, 5))

            # end of this batch
            batches += 1
            batch_end = time.time()
            if (i + 1) % verbose == 0:
                logging.info('\tBatch %d: Loss %.5f, Accuracy: Top-1 %.2f, Top-3 %.2f, Top-5 %.2f, Time %3.2f' %
                             (i + 1, epoch_loss / batches, epoch_acc[0] / batches, epoch_acc[1] / batches, epoch_acc[2] / batches, batch_end - batch_start))

    # end of validation
    end_time = time.time()

    # metrics for average; max() avoids ZeroDivisionError on an empty loader
    epoch_loss /= max(batches, 1)
    epoch_acc /= max(batches, 1)

    # show information for this epoch
    logging.info('Valid: Loss %.5f, Accuracy: Top-1 %.2f, Top-3 %.2f, Top-5 %.2f, Time %3.2f'%
                 (epoch_loss, epoch_acc[0], epoch_acc[1], epoch_acc[2], end_time - start_time))
    logging.info('')

    return epoch_loss
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()