-
Notifications
You must be signed in to change notification settings - Fork 534
/
train_ssd.py
338 lines (297 loc) · 15.1 KB
/
train_ssd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import argparse
import os
import logging
import sys
import itertools
import torch
from torch.utils.data import DataLoader, ConcatDataset
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
from vision.utils.misc import str2bool, Timer, freeze_net_layers, store_labels
from vision.ssd.ssd import MatchPrior
from vision.ssd.vgg_ssd import create_vgg_ssd
from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd
from vision.ssd.mobilenetv1_ssd_lite import create_mobilenetv1_ssd_lite
from vision.ssd.mobilenet_v2_ssd_lite import create_mobilenetv2_ssd_lite
from vision.ssd.mobilenetv3_ssd_lite import create_mobilenetv3_large_ssd_lite, create_mobilenetv3_small_ssd_lite
from vision.ssd.squeezenet_ssd_lite import create_squeezenet_ssd_lite
from vision.datasets.voc_dataset import VOCDataset
from vision.datasets.open_images import OpenImagesDataset
from vision.nn.multibox_loss import MultiboxLoss
from vision.ssd.config import vgg_ssd_config
from vision.ssd.config import mobilenetv1_ssd_config
from vision.ssd.config import squeezenet_ssd_config
from vision.ssd.data_preprocessing import TrainAugmentation, TestTransform
# Command-line interface for the SSD training script. Everything below runs at
# import time: the parser is built, logging is configured, the arguments are
# parsed into the module-level ``args`` and the training device is selected.
parser = argparse.ArgumentParser(
    description='Single Shot MultiBox Detector Training With Pytorch')

parser.add_argument("--dataset_type", default="voc", type=str,
                    help='Specify dataset type. Currently support voc and open_images.')

parser.add_argument('--datasets', nargs='+', help='Dataset directory path')
parser.add_argument('--validation_dataset', help='Dataset directory path')
parser.add_argument('--balance_data', action='store_true',
                    help="Balance training data by down-sampling more frequent labels.")

# The listed names must match the values the script dispatches on.
parser.add_argument('--net', default="vgg16-ssd",
                    help="The network architecture, it can be mb1-ssd, mb1-ssd-lite, mb2-ssd-lite, "
                         "mb3-large-ssd-lite, mb3-small-ssd-lite, sq-ssd-lite or vgg16-ssd.")
parser.add_argument('--freeze_base_net', action='store_true',
                    help="Freeze base net layers.")
parser.add_argument('--freeze_net', action='store_true',
                    help="Freeze all the layers except the prediction head.")

parser.add_argument('--mb2_width_mult', default=1.0, type=float,
                    help='Width Multiplier for MobilenetV2')

# Params for SGD
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
                    help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float,
                    help='Momentum value for optim')
parser.add_argument('--weight_decay', default=5e-4, type=float,
                    help='Weight decay for SGD')
parser.add_argument('--gamma', default=0.1, type=float,
                    help='Gamma update for SGD')
parser.add_argument('--base_net_lr', default=None, type=float,
                    help='initial learning rate for base net.')
parser.add_argument('--extra_layers_lr', default=None, type=float,
                    help='initial learning rate for the layers not in base net and prediction heads.')

# Params for loading pretrained basenet or checkpoints.
parser.add_argument('--base_net',
                    help='Pretrained base model')
parser.add_argument('--pretrained_ssd', help='Pre-trained SSD model')
parser.add_argument('--resume', default=None, type=str,
                    help='Checkpoint state_dict file to resume training from')

# Scheduler
parser.add_argument('--scheduler', default="multi-step", type=str,
                    help="Scheduler for SGD. It can be one of multi-step and cosine")

# Params for Multi-step Scheduler
parser.add_argument('--milestones', default="80,100", type=str,
                    help="milestones for MultiStepLR")

# Params for Cosine Annealing
parser.add_argument('--t_max', default=120, type=float,
                    help='T_max value for Cosine Annealing Scheduler.')

# Train params
parser.add_argument('--batch_size', default=32, type=int,
                    help='Batch size for training')
parser.add_argument('--num_epochs', default=120, type=int,
                    help='the number epochs')
parser.add_argument('--num_workers', default=4, type=int,
                    help='Number of workers used in dataloading')
parser.add_argument('--validation_epochs', default=5, type=int,
                    help='Run validation and save a checkpoint every this many epochs')
parser.add_argument('--debug_steps', default=100, type=int,
                    help='Set the debug log output frequency.')
parser.add_argument('--use_cuda', default=True, type=str2bool,
                    help='Use CUDA to train model')

parser.add_argument('--checkpoint_folder', default='models/',
                    help='Directory for saving checkpoint models')


logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
args = parser.parse_args()

# CUDA is used only when it is both requested and available; otherwise CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu")

if args.use_cuda and torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    logging.info("Use Cuda.")
def train(loader, net, criterion, optimizer, device, debug_steps=100, epoch=-1):
    """Run one training pass over ``loader`` and log windowed average losses.

    Args:
        loader: iterable yielding ``(images, boxes, labels)`` batches; each
            element is moved to ``device`` with ``.to``.
        net: model put into training mode, called as ``net(images)`` and
            expected to return ``(confidence, locations)``.
        criterion: callable invoked as
            ``criterion(confidence, locations, labels, boxes)`` and returning
            ``(regression_loss, classification_loss)``.
        optimizer: optimizer that is zeroed and stepped once per batch.
        device: device the batch tensors are moved to.
        debug_steps: a log line is emitted whenever the batch index is a
            non-zero multiple of this value; values <= 0 disable logging.
        epoch: epoch number, used only in the log message.
    """
    net.train(True)
    loss_sum = 0.0
    regression_sum = 0.0
    classification_sum = 0.0
    # Number of batches accumulated since the last log line. The first window
    # covers indices 0..debug_steps (one more batch than later windows), so the
    # averages must divide by the real count rather than by debug_steps.
    batches_in_window = 0
    for i, (images, boxes, labels) in enumerate(loader):
        images = images.to(device)
        boxes = boxes.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        confidence, locations = net(images)
        regression_loss, classification_loss = criterion(confidence, locations, labels, boxes)  # TODO CHANGE BOXES
        loss = regression_loss + classification_loss
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        regression_sum += regression_loss.item()
        classification_sum += classification_loss.item()
        batches_in_window += 1

        # The debug_steps > 0 guard avoids a modulo-by-zero when logging is disabled.
        if debug_steps > 0 and i and i % debug_steps == 0:
            avg_loss = loss_sum / batches_in_window
            avg_reg_loss = regression_sum / batches_in_window
            avg_clf_loss = classification_sum / batches_in_window
            logging.info(
                f"Epoch: {epoch}, Step: {i}, " +
                f"Average Loss: {avg_loss:.4f}, " +
                f"Average Regression Loss {avg_reg_loss:.4f}, " +
                f"Average Classification Loss: {avg_clf_loss:.4f}"
            )
            loss_sum = 0.0
            regression_sum = 0.0
            classification_sum = 0.0
            batches_in_window = 0
def test(loader, net, criterion, device):
    """Evaluate ``net`` on ``loader`` and return per-batch average losses.

    Args:
        loader: iterable yielding ``(images, boxes, labels)`` batches; each
            element is moved to ``device`` with ``.to``.
        net: model switched to eval mode, called as ``net(images)`` and
            expected to return ``(confidence, locations)``.
        criterion: callable invoked as
            ``criterion(confidence, locations, labels, boxes)`` and returning
            ``(regression_loss, classification_loss)``.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(loss, regression_loss, classification_loss)`` where each entry
        is the sum of the per-batch values divided by the number of batches.

    Raises:
        ValueError: if ``loader`` yields no batches, since no average exists.
    """
    net.eval()
    total_loss = 0.0
    total_regression_loss = 0.0
    total_classification_loss = 0.0
    num_batches = 0
    # Gradients are never needed during evaluation.
    with torch.no_grad():
        for images, boxes, labels in loader:
            images = images.to(device)
            boxes = boxes.to(device)
            labels = labels.to(device)
            num_batches += 1

            confidence, locations = net(images)
            regression_loss, classification_loss = criterion(confidence, locations, labels, boxes)
            loss = regression_loss + classification_loss

            total_loss += loss.item()
            total_regression_loss += regression_loss.item()
            total_classification_loss += classification_loss.item()

    if num_batches == 0:
        # Previously surfaced as a bare ZeroDivisionError from the division below.
        raise ValueError("Cannot compute validation losses: the loader yielded no batches.")
    return (total_loss / num_batches,
            total_regression_loss / num_batches,
            total_classification_loss / num_batches)


# The name matches pytest's default test-function prefix; flag the helper so it
# is never collected as a test case when imported into a test module.
test.__test__ = False
if __name__ == '__main__':
    # Script entry point: build the datasets, network, optimizer and scheduler
    # from the parsed ``args``, then run the train / validate / checkpoint loop.
    timer = Timer()

    logging.info(args)

    # Each supported --net value maps to its (factory, config) pair. The
    # MobileNetV2/V3 lite variants reuse the MobileNetV1 SSD config.
    net_registry = {
        'vgg16-ssd': (create_vgg_ssd, vgg_ssd_config),
        'mb1-ssd': (create_mobilenetv1_ssd, mobilenetv1_ssd_config),
        'mb1-ssd-lite': (create_mobilenetv1_ssd_lite, mobilenetv1_ssd_config),
        'sq-ssd-lite': (create_squeezenet_ssd_lite, squeezenet_ssd_config),
        'mb2-ssd-lite': (
            lambda num: create_mobilenetv2_ssd_lite(num, width_mult=args.mb2_width_mult),
            mobilenetv1_ssd_config,
        ),
        'mb3-large-ssd-lite': (create_mobilenetv3_large_ssd_lite, mobilenetv1_ssd_config),
        'mb3-small-ssd-lite': (create_mobilenetv3_small_ssd_lite, mobilenetv1_ssd_config),
    }
    if args.net not in net_registry:
        logging.fatal("The net type is wrong.")
        parser.print_help(sys.stderr)
        sys.exit(1)
    create_net, config = net_registry[args.net]

    train_transform = TrainAugmentation(config.image_size, config.image_mean, config.image_std)
    target_transform = MatchPrior(config.priors, config.center_variance,
                                  config.size_variance, 0.5)

    test_transform = TestTransform(config.image_size, config.image_mean, config.image_std)

    # --datasets is optional for argparse; fail with a clear message instead of
    # a TypeError from iterating None.
    if not args.datasets:
        logging.fatal("No training dataset given; pass at least one path via --datasets.")
        parser.print_help(sys.stderr)
        sys.exit(1)

    # Label files and checkpoints are written here, so make sure it exists.
    if args.checkpoint_folder:
        os.makedirs(args.checkpoint_folder, exist_ok=True)

    logging.info("Prepare training datasets.")
    datasets = []
    for dataset_path in args.datasets:
        if args.dataset_type == 'voc':
            dataset = VOCDataset(dataset_path, transform=train_transform,
                                 target_transform=target_transform)
            label_file = os.path.join(args.checkpoint_folder, "voc-model-labels.txt")
            store_labels(label_file, dataset.class_names)
            num_classes = len(dataset.class_names)
        elif args.dataset_type == 'open_images':
            dataset = OpenImagesDataset(dataset_path,
                                        transform=train_transform, target_transform=target_transform,
                                        dataset_type="train", balance_data=args.balance_data)
            label_file = os.path.join(args.checkpoint_folder, "open-images-model-labels.txt")
            store_labels(label_file, dataset.class_names)
            logging.info(dataset)
            num_classes = len(dataset.class_names)
        else:
            raise ValueError(f"Dataset type {args.dataset_type} is not supported.")
        datasets.append(dataset)
    # label_file / num_classes hold the values from the last dataset processed.
    logging.info(f"Stored labels into file {label_file}.")
    train_dataset = ConcatDataset(datasets)
    logging.info("Train dataset size: {}".format(len(train_dataset)))
    train_loader = DataLoader(train_dataset, args.batch_size,
                              num_workers=args.num_workers,
                              shuffle=True)

    logging.info("Prepare Validation datasets.")
    if args.dataset_type == "voc":
        val_dataset = VOCDataset(args.validation_dataset, transform=test_transform,
                                 target_transform=target_transform, is_test=True)
    elif args.dataset_type == 'open_images':
        # NOTE(review): this reads the "test" split from the last training
        # dataset path and ignores --validation_dataset; behavior kept as-is,
        # confirm whether that is intended.
        val_dataset = OpenImagesDataset(dataset_path,
                                        transform=test_transform, target_transform=target_transform,
                                        dataset_type="test")
        logging.info(val_dataset)
    logging.info("validation dataset size: {}".format(len(val_dataset)))

    val_loader = DataLoader(val_dataset, args.batch_size,
                            num_workers=args.num_workers,
                            shuffle=False)

    logging.info("Build network.")
    net = create_net(num_classes)
    last_epoch = -1

    # Per-group learning rates fall back to --lr when not given explicitly.
    base_net_lr = args.base_net_lr if args.base_net_lr is not None else args.lr
    extra_layers_lr = args.extra_layers_lr if args.extra_layers_lr is not None else args.lr
    if args.freeze_base_net:
        logging.info("Freeze base net.")
        freeze_net_layers(net.base_net)
        params = [
            {'params': itertools.chain(
                net.source_layer_add_ons.parameters(),
                net.extras.parameters()
            ), 'lr': extra_layers_lr},
            {'params': itertools.chain(
                net.regression_headers.parameters(),
                net.classification_headers.parameters()
            )}
        ]
    elif args.freeze_net:
        freeze_net_layers(net.base_net)
        freeze_net_layers(net.source_layer_add_ons)
        freeze_net_layers(net.extras)
        params = itertools.chain(net.regression_headers.parameters(), net.classification_headers.parameters())
        logging.info("Freeze all the layers except prediction heads.")
    else:
        params = [
            {'params': net.base_net.parameters(), 'lr': base_net_lr},
            {'params': itertools.chain(
                net.source_layer_add_ons.parameters(),
                net.extras.parameters()
            ), 'lr': extra_layers_lr},
            {'params': itertools.chain(
                net.regression_headers.parameters(),
                net.classification_headers.parameters()
            )}
        ]

    # Weight initialisation precedence: --resume, then --base_net, then --pretrained_ssd.
    timer.start("Load Model")
    if args.resume:
        logging.info(f"Resume from the model {args.resume}")
        net.load(args.resume)
    elif args.base_net:
        logging.info(f"Init from base net {args.base_net}")
        net.init_from_base_net(args.base_net)
    elif args.pretrained_ssd:
        logging.info(f"Init from pretrained ssd {args.pretrained_ssd}")
        net.init_from_pretrained_ssd(args.pretrained_ssd)
    logging.info(f'Took {timer.end("Load Model"):.2f} seconds to load the model.')

    net.to(DEVICE)

    criterion = MultiboxLoss(config.priors, iou_threshold=0.5, neg_pos_ratio=3,
                             center_variance=0.1, size_variance=0.2, device=DEVICE)
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)
    logging.info(f"Learning rate: {args.lr}, Base net learning rate: {base_net_lr}, "
                 + f"Extra Layers learning rate: {extra_layers_lr}.")

    if args.scheduler == 'multi-step':
        logging.info("Uses MultiStepLR scheduler.")
        milestones = [int(v.strip()) for v in args.milestones.split(",")]
        # Honour --gamma (default 0.1) instead of a hard-coded decay factor.
        scheduler = MultiStepLR(optimizer, milestones=milestones,
                                gamma=args.gamma, last_epoch=last_epoch)
    elif args.scheduler == 'cosine':
        logging.info("Uses CosineAnnealingLR scheduler.")
        scheduler = CosineAnnealingLR(optimizer, args.t_max, last_epoch=last_epoch)
    else:
        logging.fatal(f"Unsupported Scheduler: {args.scheduler}.")
        parser.print_help(sys.stderr)
        sys.exit(1)

    logging.info(f"Start training from epoch {last_epoch + 1}.")
    for epoch in range(last_epoch + 1, args.num_epochs):
        train(train_loader, net, criterion, optimizer,
              device=DEVICE, debug_steps=args.debug_steps, epoch=epoch)
        # Step the scheduler after the epoch's optimizer updates; since
        # PyTorch 1.1.0, stepping it first skips the first value of the schedule.
        scheduler.step()

        # Validate and checkpoint every validation_epochs epochs and on the final epoch.
        if epoch % args.validation_epochs == 0 or epoch == args.num_epochs - 1:
            val_loss, val_regression_loss, val_classification_loss = test(val_loader, net, criterion, DEVICE)
            logging.info(
                f"Epoch: {epoch}, " +
                f"Validation Loss: {val_loss:.4f}, " +
                f"Validation Regression Loss {val_regression_loss:.4f}, " +
                f"Validation Classification Loss: {val_classification_loss:.4f}"
            )
            model_path = os.path.join(args.checkpoint_folder, f"{args.net}-Epoch-{epoch}-Loss-{val_loss}.pth")
            net.save(model_path)
            logging.info(f"Saved model {model_path}")