This repository has been archived by the owner on Jun 15, 2022. It is now read-only.

Commit

Merge pull request #33 from DeNA/feature/use_scheduler
use the official scheduler
hirotomusiker authored Apr 18, 2019
2 parents 4c64fba + 388a33c commit a42ecd1
Showing 1 changed file with 15 additions and 21 deletions.
train.py (36 changes: 15 additions & 21 deletions)
@@ -57,7 +57,6 @@ def main():
 
     print("successfully loaded config file: ", cfg)
 
-    lr = cfg['TRAIN']['LR']
     momentum = cfg['TRAIN']['MOMENTUM']
     decay = cfg['TRAIN']['DECAY']
     burn_in = cfg['TRAIN']['BURN_IN']
@@ -67,12 +66,22 @@ def main():
     subdivision = cfg['TRAIN']['SUBDIVISION']
     ignore_thre = cfg['TRAIN']['IGNORETHRE']
     random_resize = cfg['AUGMENTATION']['RANDRESIZE']
+    base_lr = cfg['TRAIN']['LR'] / batch_size / subdivision
 
     print('effective_batch_size = batch_size * iter_size = %d * %d' %
           (batch_size, subdivision))
 
     # Learning rate setup
-    base_lr = lr
+    def burnin_schedule(i):
+        if i < burn_in:
+            factor = pow(i / burn_in, 4)
+        elif i < steps[0]:
+            factor = 1.0
+        elif i < steps[1]:
+            factor = 0.1
+        else:
+            factor = 0.01
+        return factor
 
     # Initiate model
     model = YOLOv3(cfg['MODEL'], ignore_thre=ignore_thre)
@@ -137,13 +146,7 @@ def main():
         optimizer.load_state_dict(state['optimizer_state_dict'])
         iter_state = state['iter'] + 1
 
-    # TODO: replace the following scheduler with the PyTorch's official one
-
-    tmp_lr = base_lr
-
-    def set_lr(tmp_lr):
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = tmp_lr / batch_size / subdivision
+    scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)
 
     # start training loop
     for iter_i in range(iter_state, iter_size + 1):
@@ -156,17 +159,6 @@ def set_lr(tmp_lr):
                 tblogger.add_scalar('val/COCOAP50', ap50, iter_i)
                 tblogger.add_scalar('val/COCOAP50_95', ap50_95, iter_i)
 
-        # learning rate scheduling
-        if iter_i < burn_in:
-            tmp_lr = base_lr * pow(iter_i / burn_in, 4)
-            set_lr(tmp_lr)
-        elif iter_i == burn_in:
-            tmp_lr = base_lr
-            set_lr(tmp_lr)
-        elif iter_i in steps:
-            tmp_lr = tmp_lr * 0.1
-            set_lr(tmp_lr)
-
         # subdivision loop
         optimizer.zero_grad()
         for inner_iter_i in range(subdivision):
@@ -181,12 +173,14 @@ def set_lr(tmp_lr):
             loss.backward()
 
         optimizer.step()
+        scheduler.step()
 
         if iter_i % 10 == 0:
             # logging
+            current_lr = scheduler.get_lr()[0] * batch_size * subdivision
            print('[Iter %d/%d] [lr %f] '
                  '[Losses: xy %f, wh %f, conf %f, cls %f, total %f, imgsize %d]'
-                 % (iter_i, iter_size, tmp_lr,
+                 % (iter_i, iter_size, current_lr,
                     model.loss_dict['xy'], model.loss_dict['wh'],
                     model.loss_dict['conf'], model.loss_dict['cls'],
                     model.loss_dict['l2'], imgsize),
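Note: the snippet below is a minimal, self-contained sketch (not part of this commit) of how the LambdaLR scheduler introduced in the diff drives the learning rate. The linear model, SGD hyperparameters, and the burn_in/steps/base_lr values are placeholders; only burnin_schedule and the optimizer.step()/scheduler.step() ordering mirror the change above.

import torch
from torch import optim

# Placeholder schedule constants (the real values come from the YOLOv3 cfg file).
burn_in = 1000
steps = (400000, 450000)
# As in the diff: the configured LR is pre-divided by the effective batch size.
base_lr = 0.001 / 64

def burnin_schedule(i):
    """Multiplicative factor applied by LambdaLR at iteration i (copied from the diff)."""
    if i < burn_in:
        factor = pow(i / burn_in, 4)
    elif i < steps[0]:
        factor = 1.0
    elif i < steps[1]:
        factor = 0.1
    else:
        factor = 0.01
    return factor

model = torch.nn.Linear(10, 10)  # stand-in for the YOLOv3 model
optimizer = optim.SGD(model.parameters(), lr=base_lr,
                      momentum=0.9, weight_decay=0.0005)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)

for iter_i in range(5):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()  # one scheduler step per (effective) training iteration
    # The diff reads scheduler.get_lr()[0]; param_groups[0]['lr'] works on any PyTorch version.
    print(iter_i, optimizer.param_groups[0]['lr'])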
