advanced noam with decay and accum scheduler (OpenNMT#1367)
vince62s authored Mar 28, 2019
1 parent 6bc8efe commit b7a8c21
Showing 6 changed files with 58 additions and 19 deletions.
2 changes: 1 addition & 1 deletion onmt/inputters/inputter.py
@@ -648,4 +648,4 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True):
device,
is_train,
repeat=not opt.single_pass,
num_batches_multiple=opt.accum_count * opt.world_size)
num_batches_multiple=max(opt.accum_count) * opt.world_size)
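Since accum_count is now a list of values that can change over training (see the onmt/opts.py change below), the iterator is asked to pad to a multiple of the largest accumulation count, so every rank in a multi-GPU run can always fill a complete accumulation cycle. A minimal sketch of the arithmetic, with hypothetical values that are not taken from the commit:

# Hypothetical setting: accumulate 2 batches early in training, 4 later, on 2 GPUs.
accum_count = [2, 4]
world_size = 2
num_batches_multiple = max(accum_count) * world_size
print(num_batches_multiple)  # -> 8 batches per padded group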
7 changes: 5 additions & 2 deletions onmt/opts.py
@@ -385,11 +385,14 @@ def train_opts(parser):
group.add('--normalization', '-normalization', default='sents',
choices=["sents", "tokens"],
help='Normalization method of the gradient.')
group.add('--accum_count', '-accum_count', type=int, default=1,
group.add('--accum_count', '-accum_count', type=int, nargs='+',
default=[1],
help="Accumulate gradient this many times. "
"Approximately equivalent to updating "
"batch_size * accum_count batches at once. "
"Recommended for Transformer.")
group.add('--accum_steps', '-accum_steps', type=int, nargs='+',
default=[0], help="Steps at which accum_count values change")
group.add('--valid_steps', '-valid_steps', type=int, default=10000,
help='Perfom validation every X steps')
group.add('--valid_batch_size', '-valid_batch_size', type=int, default=32,
@@ -479,7 +482,7 @@ def train_opts(parser):
help="Decay every decay_steps")

group.add('--decay_method', '-decay_method', type=str, default="none",
choices=['noam', 'rsqrt', 'none'],
choices=['noam', 'noamwd', 'rsqrt', 'none'],
help="Use a custom decay rate.")
group.add('--warmup_steps', '-warmup_steps', type=int, default=4000,
help="Number of warmup steps for custom decay.")
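The two new flags are read as a pair: -accum_steps lists step thresholds and -accum_count lists the accumulation factor that applies once the corresponding threshold has been passed (the trainer resolves this in _accum_count, shown in the onmt/trainer.py change below). A minimal sketch of that lookup, with assumed example values such as -accum_count 1 4 -accum_steps 0 8000, none of which are defaults from this commit:

def accum_for_step(step, accum_count=(1, 4), accum_steps=(0, 8000)):
    """Return the accumulation count in effect at a given training step.

    Follows the same rule as Trainer._accum_count: the last accum_count
    entry whose accum_steps threshold has been passed wins. The example
    values above are assumptions, not defaults from the commit.
    """
    current = accum_count[0]
    for threshold, count in zip(accum_steps, accum_count):
        if step > threshold:
            current = count
    return current

print(accum_for_step(100))    # -> 1: single batch per update before step 8000
print(accum_for_step(12000))  # -> 4: accumulate 4 batches per update afterwards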
2 changes: 2 additions & 0 deletions onmt/train_single.py
@@ -44,6 +44,8 @@ def main(opt, device_id):
# at this point.
configure_process(opt, device_id)
init_logger(opt.log_file)
assert len(opt.accum_count) == len(opt.accum_steps), \
'Number of accum_count values must match number of accum_steps'
# Load checkpoint if we resume from a previous training.
if opt.train_from:
logger.info('Loading checkpoint from %s' % opt.train_from)
46 changes: 31 additions & 15 deletions onmt/trainer.py
@@ -40,7 +40,8 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
trunc_size = opt.truncated_decoder # Badly named...
shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0
norm_method = opt.normalization
grad_accum_count = opt.accum_count
accum_count = opt.accum_count
accum_steps = opt.accum_steps
n_gpu = opt.world_size
average_decay = opt.average_decay
average_every = opt.average_every
@@ -54,7 +55,8 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
report_manager = onmt.utils.build_report_manager(opt)
trainer = onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size,
shard_size, norm_method,
grad_accum_count, n_gpu, gpu_rank,
accum_count, accum_steps,
n_gpu, gpu_rank,
gpu_verbose_level, report_manager,
model_saver=model_saver if gpu_rank == 0 else None,
average_decay=average_decay,
@@ -80,7 +82,8 @@ class Trainer(object):
shard_size(int): compute loss in shards of this size for efficiency
data_type(string): type of the source input: [text|img|audio]
norm_method(string): normalization methods: [sents|tokens]
grad_accum_count(int): accumulate gradients this many times.
accum_count(list): accumulate gradients this many times.
accum_steps(list): steps for accum gradients changes.
report_manager(:obj:`onmt.utils.ReportMgrBase`):
the object that creates reports, or None
model_saver(:obj:`onmt.models.ModelSaverBase`): the saver is
@@ -90,7 +93,9 @@ class Trainer(object):

def __init__(self, model, train_loss, valid_loss, optim,
trunc_size=0, shard_size=32,
norm_method="sents", grad_accum_count=1, n_gpu=1, gpu_rank=1,
norm_method="sents", accum_count=[1],
accum_steps=[0],
n_gpu=1, gpu_rank=1,
gpu_verbose_level=0, report_manager=None, model_saver=None,
average_decay=0, average_every=1, model_dtype='fp32'):
# Basic attributes.
@@ -101,7 +106,9 @@ def __init__(self, model, train_loss, valid_loss, optim,
self.trunc_size = trunc_size
self.shard_size = shard_size
self.norm_method = norm_method
self.grad_accum_count = grad_accum_count
self.accum_count_l = accum_count
self.accum_count = accum_count[0]
self.accum_steps = accum_steps
self.n_gpu = n_gpu
self.gpu_rank = gpu_rank
self.gpu_verbose_level = gpu_verbose_level
@@ -112,18 +119,26 @@ def __init__(self, model, train_loss, valid_loss, optim,
self.average_every = average_every
self.model_dtype = model_dtype

assert grad_accum_count > 0
if grad_accum_count > 1:
assert self.trunc_size == 0, \
"""To enable accumulated gradients,
you must disable target sequence truncating."""
for i in range(len(self.accum_count_l)):
assert self.accum_count_l[i] > 0
if self.accum_count_l[i] > 1:
assert self.trunc_size == 0, \
"""To enable accumulated gradients,
you must disable target sequence truncating."""

# Set model in training mode.
self.model.train()

def _accum_count(self, step):
for i in range(len(self.accum_steps)):
if step > self.accum_steps[i]:
_accum = self.accum_count_l[i]
return _accum

def _accum_batches(self, iterator):
batches = []
normalization = 0
self.accum_count = self._accum_count(self.optim.training_step)
for batch in iterator:
batches.append(batch)
if self.norm_method == "tokens":
@@ -132,8 +147,9 @@ def _accum_batches(iterator):
normalization += num_tokens.item()
else:
normalization += batch.batch_size
if len(batches) == self.grad_accum_count:
if len(batches) == self.accum_count:
yield batches, normalization
self.accum_count = self._accum_count(self.optim.training_step)
batches = []
normalization = 0
if batches:
@@ -289,7 +305,7 @@ def validate(self, valid_iter, moving_average=None):

def _gradient_accumulation(self, true_batches, normalization, total_stats,
report_stats):
if self.grad_accum_count > 1:
if self.accum_count > 1:
self.optim.zero_grad()

for batch in true_batches:
@@ -313,7 +329,7 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
tgt = tgt_outer[j: j + trunc_size]

# 2. F-prop all but generator.
if self.grad_accum_count == 1:
if self.accum_count == 1:
self.optim.zero_grad()
outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt)
bptt = True
@@ -335,7 +351,7 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
report_stats.update(batch_stats)

# 4. Update the parameters and statistics.
if self.grad_accum_count == 1:
if self.accum_count == 1:
# Multi GPU gradient gather
if self.n_gpu > 1:
grads = [p.grad.data for p in self.model.parameters()
@@ -354,7 +370,7 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,

# in case of multi step gradient accumulation,
# update only after accum batches
if self.grad_accum_count > 1:
if self.accum_count > 1:
if self.n_gpu > 1:
grads = [p.grad.data for p in self.model.parameters()
if p.requires_grad
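The upshot is that the trainer re-reads the schedule each time it finishes an accumulation cycle, so the number of batches grouped into one optimizer update can grow mid-training while the normalization term is recomputed per group. A small, self-contained sketch of that grouping logic under an assumed schedule (the batch sizes and schedule values are illustrative, not taken from the commit):

def group_batches(batch_sizes, schedule):
    """Group a stream of batch sizes into accumulation cycles.

    Mimics the spirit of Trainer._accum_batches with 'sents' normalization:
    `schedule` maps a training step to an accumulation count.
    """
    step = 1                      # stand-in for optim.training_step
    group, normalization = [], 0
    accum = schedule(step)
    for size in batch_sizes:
        group.append(size)
        normalization += size     # sentence-count normalization
        if len(group) == accum:
            yield group, normalization
            step += 1             # one optimizer update per yielded group
            accum = schedule(step)
            group, normalization = [], 0
    if group:                     # flush the last, possibly short, group
        yield group, normalization

def schedule(step):
    # Hypothetical schedule: 1 batch per update for two steps, then 2 batches.
    return 1 if step <= 2 else 2

print(list(group_batches([32] * 8, schedule)))
# -> [([32], 32), ([32], 32), ([32, 32], 64), ([32, 32], 64), ([32, 32], 64)]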
18 changes: 18 additions & 0 deletions onmt/utils/optimizers.py
@@ -108,6 +108,14 @@ def make_learning_rate_decay_fn(opt):
noam_decay,
warmup_steps=opt.warmup_steps,
model_size=opt.rnn_size)
elif opt.decay_method == 'noamwd':
return functools.partial(
noamwd_decay,
warmup_steps=opt.warmup_steps,
model_size=opt.rnn_size,
rate=opt.learning_rate_decay,
decay_steps=opt.decay_steps,
start_step=opt.start_decay_steps)
elif opt.decay_method == 'rsqrt':
return functools.partial(
rsqrt_decay, warmup_steps=opt.warmup_steps)
@@ -128,6 +136,16 @@ def noam_decay(step, warmup_steps, model_size):
min(step ** (-0.5), step * warmup_steps**(-1.5)))


def noamwd_decay(step, warmup_steps,
model_size, rate, decay_steps, start_step=0):
"""Learning rate schedule optimized for huge batches
"""
return (
model_size ** (-0.5) *
min(step ** (-0.5), step * warmup_steps**(-1.5)) *
rate ** (max(step - start_step + decay_steps, 0) // decay_steps))


def exponential_decay(step, rate, decay_steps, start_step=0):
"""A standard exponential decay, scaling the learning rate by :obj:`rate`
every :obj:`decay_steps` steps.
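noamwd keeps the Noam warmup and inverse-square-root shape but multiplies in an extra exponential decay of `rate` every `decay_steps` steps once `start_step` is reached; the new docstring describes it as a schedule for huge batches. A short sketch comparing the two schedules, where the concrete settings (warmup_steps=4000, model_size=512, rate=0.5, decay_steps=8000, start_step=16000) are illustrative only:

def noam_decay(step, warmup_steps, model_size):
    # Standard Noam schedule: linear warmup, then inverse square-root decay.
    return model_size ** (-0.5) * min(step ** (-0.5),
                                      step * warmup_steps ** (-1.5))

def noamwd_decay(step, warmup_steps, model_size,
                 rate, decay_steps, start_step=0):
    # Noam shape times an extra `rate ** k` factor, where k grows by one
    # every `decay_steps` steps once `start_step` has been reached.
    return (model_size ** (-0.5) *
            min(step ** (-0.5), step * warmup_steps ** (-1.5)) *
            rate ** (max(step - start_step + decay_steps, 0) // decay_steps))

for step in (1000, 8000, 16000, 32000):
    plain = noam_decay(step, warmup_steps=4000, model_size=512)
    with_decay = noamwd_decay(step, warmup_steps=4000, model_size=512,
                              rate=0.5, decay_steps=8000, start_step=16000)
    # Equal before step 16000, then scaled down by 0.5 every 8000 steps.
    print(step, plain, with_decay)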
2 changes: 1 addition & 1 deletion onmt/utils/parse.py
@@ -82,7 +82,7 @@ def validate_train_opts(cls, opt):
if opt.epochs:
raise AssertionError(
"-epochs is deprecated please use -train_steps.")
if opt.truncated_decoder > 0 and opt.accum_count > 1:
if opt.truncated_decoder > 0 and max(opt.accum_count) > 1:
raise AssertionError("BPTT is not compatible with -accum > 1")
if opt.gpuid:
raise AssertionError("gpuid is deprecated \
