Introduce dropout scheduler (OpenNMT#1421)
* add update_dropout methods approx. everywhere, dropout scheduler
* more meaningful log
* forgot some layers in audio_encoder
francoishernandez authored and vince62s committed May 16, 2019
1 parent 674062f commit 607c091
Showing 16 changed files with 104 additions and 15 deletions.
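In short, this commit makes dropout schedulable during training: -dropout now accepts one or more values, the new -dropout_steps option lists the step at which each value starts to apply, and the trainer calls model.update_dropout(...) once a scheduled step is passed, with the new update_dropout methods propagating the rate through encoder, decoder, attention, feed-forward and embedding modules. A hypothetical invocation (the option values are illustrative, not taken from the commit) could look like:

    python train.py ... -dropout 0.3 0.1 -dropout_steps 0 8000

which trains with dropout 0.3 and switches the whole model to 0.1 on the first step after 8000.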
6 changes: 5 additions & 1 deletion onmt/decoders/cnn_decoder.py
@@ -57,7 +57,7 @@ def from_opt(cls, opt, embeddings):
opt.global_attention,
opt.copy_attn,
opt.cnn_kernel_width,
opt.dropout,
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
embeddings,
opt.copy_attn_type)

@@ -127,3 +127,7 @@ def forward(self, tgt, memory_bank, step=None, **kwargs):
self.state["previous_input"] = tgt
# TODO change the way attns is returned dict => list or tuple (onnx)
return dec_outs, attns

def update_dropout(self, dropout):
for layer in self.conv_layers:
layer.dropout.p = dropout
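The pattern introduced here and repeated in the files below relies on two details. At build time only the initial rate is needed, hence the recurring opt.dropout[0] if type(opt.dropout) is list else opt.dropout in each from_opt. At run time, torch.nn.Dropout reads its p attribute on every forward call, so mutating p on a live module changes the rate without rebuilding anything. A minimal sketch of that second point (illustrative, not part of the commit):

    import torch
    import torch.nn as nn

    drop = nn.Dropout(p=0.3)
    drop.train()            # dropout only acts in training mode
    x = torch.ones(8)

    print(drop(x))          # ~30% of entries zeroed, the rest scaled by 1/0.7
    drop.p = 0.1            # same module object, new probability
    print(drop(x))          # later calls zero only ~10% of entries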
12 changes: 11 additions & 1 deletion onmt/decoders/decoder.py
@@ -151,7 +151,8 @@ def from_opt(cls, opt, embeddings):
opt.coverage_attn,
opt.context_gate,
opt.copy_attn,
opt.dropout,
opt.dropout[0] if type(opt.dropout) is list
else opt.dropout,
embeddings,
opt.reuse_copy_attn,
opt.copy_attn_type)
@@ -233,6 +234,10 @@ def forward(self, tgt, memory_bank, memory_lengths=None, step=None):
attns[k] = torch.stack(attns[k])
return dec_outs, attns

def update_dropout(self, dropout):
self.dropout.p = dropout
self.embeddings.update_dropout(dropout)


class StdRNNDecoder(RNNDecoderBase):
"""Standard fully batched RNN decoder with attention.
@@ -427,3 +432,8 @@ def _build_rnn(self, rnn_type, input_size,
def _input_size(self):
"""Using input feed by concatenating input with attention vectors."""
return self.embeddings.embedding_size + self.hidden_size

def update_dropout(self, dropout):
self.dropout.p = dropout
self.rnn.dropout.p = dropout
self.embeddings.update_dropout(dropout)
13 changes: 12 additions & 1 deletion onmt/decoders/transformer.py
@@ -88,6 +88,12 @@ def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,

return output, attn

def update_dropout(self, dropout):
self.self_attn.update_dropout(dropout)
self.context_attn.update_dropout(dropout)
self.feed_forward.update_dropout(dropout)
self.drop.p = dropout


class TransformerDecoder(DecoderBase):
"""The Transformer decoder from "Attention is All You Need".
@@ -151,7 +157,7 @@ def from_opt(cls, opt, embeddings):
opt.transformer_ff,
opt.copy_attn,
opt.self_attn_type,
opt.dropout,
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
embeddings,
opt.max_relative_positions)

@@ -232,3 +238,8 @@ def _init_cache(self, memory_bank):
layer_cache["self_keys"] = None
layer_cache["self_values"] = None
self.state["cache"]["layer_{}".format(i)] = layer_cache

def update_dropout(self, dropout):
self.embeddings.update_dropout(dropout)
for layer in self.transformer_layers:
layer.update_dropout(dropout)
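The same delegation repeats at every level of the model: containers forward update_dropout to their children until a leaf finally mutates an nn.Dropout. A self-contained toy version of the cascade (the class names here are made up, not OpenNMT classes):

    import torch.nn as nn

    class ToyLayer(nn.Module):
        def __init__(self, p):
            super().__init__()
            self.drop = nn.Dropout(p)

        def update_dropout(self, dropout):
            # leaf: mutate the live nn.Dropout
            self.drop.p = dropout

    class ToyStack(nn.Module):
        def __init__(self, n, p):
            super().__init__()
            self.layers = nn.ModuleList(ToyLayer(p) for _ in range(n))

        def update_dropout(self, dropout):
            # container: delegate to every child, as the decoders/encoders do
            for layer in self.layers:
                layer.update_dropout(dropout)

    stack = ToyStack(3, 0.3)
    stack.update_dropout(0.1)
    print([layer.drop.p for layer in stack.layers])   # [0.1, 0.1, 0.1]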
15 changes: 11 additions & 4 deletions onmt/encoders/audio_encoder.py
@@ -51,8 +51,10 @@ def __init__(self, rnn_type, enc_layers, dec_layers, brnn,
enc_pooling = [int(p) for p in enc_pooling]
self.enc_pooling = enc_pooling

if dropout > 0:
self.dropout = nn.Dropout(dropout)
if type(dropout) is not list:
dropout = [dropout]
if max(dropout) > 0:
self.dropout = nn.Dropout(dropout[0])
else:
self.dropout = None
self.W = nn.Linear(enc_rnn_size, dec_rnn_size, bias=False)
@@ -62,7 +64,7 @@ def __init__(self, rnn_type, enc_layers, dec_layers, brnn,
input_size=input_size,
hidden_size=enc_rnn_size_real,
num_layers=1,
dropout=dropout,
dropout=dropout[0],
bidirectional=brnn)
self.pool_0 = nn.MaxPool1d(enc_pooling[0])
for l in range(enc_layers - 1):
@@ -72,7 +74,7 @@ def __init__(self, rnn_type, enc_layers, dec_layers, brnn,
input_size=enc_rnn_size,
hidden_size=enc_rnn_size_real,
num_layers=1,
dropout=dropout,
dropout=dropout[0],
bidirectional=brnn)
setattr(self, 'rnn_%d' % (l + 1), rnn)
setattr(self, 'pool_%d' % (l + 1),
@@ -137,3 +139,8 @@ def forward(self, src, lengths=None):
else:
encoder_final = state
return encoder_final, memory_bank, orig_lengths.new_tensor(lengths)

def update_dropout(self, dropout):
self.dropout.p = dropout
for i in range(self.enc_layers - 1):
getattr(self, 'rnn_%d' % i).dropout = dropout
5 changes: 4 additions & 1 deletion onmt/encoders/cnn_encoder.py
@@ -31,7 +31,7 @@ def from_opt(cls, opt, embeddings):
opt.enc_layers,
opt.enc_rnn_size,
opt.cnn_kernel_width,
opt.dropout,
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
embeddings)

def forward(self, input, lengths=None, hidden=None):
@@ -50,3 +50,6 @@ def forward(self, input, lengths=None, hidden=None):

return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
out.squeeze(3).transpose(0, 1).contiguous(), lengths

def update_dropout(self, dropout):
self.cnn.dropout.p = dropout
6 changes: 5 additions & 1 deletion onmt/encoders/image_encoder.py
@@ -41,6 +41,7 @@ def __init__(self, num_layers, bidirectional, rnn_size, dropout,
self.batch_norm3 = nn.BatchNorm2d(512)

src_size = 512
dropout = dropout[0] if type(dropout) is list else dropout
self.rnn = nn.LSTM(src_size, int(rnn_size / self.num_directions),
num_layers=num_layers,
dropout=dropout,
@@ -61,7 +62,7 @@ def from_opt(cls, opt, embeddings=None):
opt.enc_layers,
opt.brnn,
opt.enc_rnn_size,
opt.dropout,
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
image_channel_size
)

@@ -125,3 +126,6 @@ def forward(self, src, lengths=None):
out = torch.cat(all_outputs, 0)

return hidden_t, out, lengths

def update_dropout(self, dropout):
self.rnn.dropout = dropout
5 changes: 4 additions & 1 deletion onmt/encoders/rnn_encoder.py
@@ -56,7 +56,7 @@ def from_opt(cls, opt, embeddings):
opt.brnn,
opt.enc_layers,
opt.enc_rnn_size,
opt.dropout,
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
embeddings,
opt.bridge)

@@ -113,3 +113,6 @@ def bottle_hidden(linear, states):
else:
outs = bottle_hidden(self.bridge[0], hidden)
return outs

def update_dropout(self, dropout):
self.rnn.dropout = dropout
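Two variants of the same idea appear in these encoder and decoder changes: modules that own an nn.Dropout (the CNN and transformer layers, the input-feed decoder's stacked RNN) set .dropout.p, while encoders built directly on torch.nn.LSTM/GRU assign self.rnn.dropout = dropout, since those RNNs keep their inter-layer dropout as a plain float. A small sketch of the difference (illustrative assumptions, not from the commit):

    import torch.nn as nn

    # nn.Dropout keeps its probability on .p ...
    layer_drop = nn.Dropout(0.3)
    layer_drop.p = 0.1

    # ... whereas nn.LSTM stores inter-layer dropout as a float attribute,
    # so it is reassigned directly (it only has an effect when num_layers > 1).
    rnn = nn.LSTM(input_size=16, hidden_size=32, num_layers=2, dropout=0.3)
    rnn.dropout = 0.1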
12 changes: 11 additions & 1 deletion onmt/encoders/transformer.py
@@ -50,6 +50,11 @@ def forward(self, inputs, mask):
out = self.dropout(context) + inputs
return self.feed_forward(out)

def update_dropout(self, dropout):
self.self_attn.update_dropout(dropout)
self.feed_forward.update_dropout(dropout)
self.dropout.p = dropout


class TransformerEncoder(EncoderBase):
"""The Transformer encoder from "Attention is All You Need"
@@ -102,7 +107,7 @@ def from_opt(cls, opt, embeddings):
opt.enc_rnn_size,
opt.heads,
opt.transformer_ff,
opt.dropout,
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
embeddings,
opt.max_relative_positions)

@@ -123,3 +128,8 @@ def forward(self, src, lengths=None):
out = self.layer_norm(out)

return emb, out.transpose(0, 1).contiguous(), lengths

def update_dropout(self, dropout):
self.embeddings.update_dropout(dropout)
for layer in self.transformer:
layer.update_dropout(dropout)
2 changes: 1 addition & 1 deletion onmt/model_builder.py
@@ -44,7 +44,7 @@ def build_embeddings(opt, text_field, for_encoder=True):
feat_merge=opt.feat_merge,
feat_vec_exponent=opt.feat_vec_exponent,
feat_vec_size=opt.feat_vec_size,
dropout=opt.dropout,
dropout=opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
word_padding_idx=word_padding_idx,
feat_padding_idx=feat_pad_indices,
word_vocab_size=num_word_embeddings,
4 changes: 4 additions & 0 deletions onmt/models/model.py
@@ -45,3 +45,7 @@ def forward(self, src, tgt, lengths, bptt=False):
dec_out, attns = self.decoder(tgt, memory_bank,
memory_lengths=lengths)
return dec_out, attns

def update_dropout(self, dropout):
self.encoder.update_dropout(dropout)
self.decoder.update_dropout(dropout)
4 changes: 4 additions & 0 deletions onmt/modules/embeddings.py
@@ -245,3 +245,7 @@ def forward(self, source, step=None):
source = self.make_embedding(source)

return source

def update_dropout(self, dropout):
if self.position_encoding:
self._modules['make_embedding'][1].dropout.p = dropout
3 changes: 3 additions & 0 deletions onmt/modules/multi_headed_attn.py
@@ -227,3 +227,6 @@ def unshape(x):
.contiguous()

return output, top_attn

def update_dropout(self, dropout):
self.dropout.p = dropout
4 changes: 4 additions & 0 deletions onmt/modules/position_ffn.py
@@ -35,3 +35,7 @@ def forward(self, x):
inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x))))
output = self.dropout_2(self.w_2(inter))
return output + x

def update_dropout(self, dropout):
self.dropout_1.p = dropout
self.dropout_2.p = dropout
4 changes: 3 additions & 1 deletion onmt/opts.py
@@ -426,8 +426,10 @@ def train_opts(parser):
help="If the norm of the gradient vector exceeds this, "
"renormalize it to have the norm equal to "
"max_grad_norm")
group.add('--dropout', '-dropout', type=float, default=0.3,
group.add('--dropout', '-dropout', type=float, default=[0.3], nargs='+',
help="Dropout probability; applied in LSTM stacks.")
group.add('--dropout_steps', '-dropout_steps', type=int, nargs='+',
default=[0], help="Steps at which dropout changes.")
group.add('--truncated_decoder', '-truncated_decoder', type=int, default=0,
help="""Truncated bptt.""")
group.add('--adam_beta1', '-adam_beta1', type=float, default=0.9,
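With nargs='+', argparse always produces a Python list for these options, and the default is already a list, which is what the type(opt.dropout) is list checks elsewhere in the commit account for. A minimal parsing sketch with plain argparse (the real opts.py builds option groups via group.add, so this is only an approximation):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument('--dropout', '-dropout', type=float, default=[0.3], nargs='+')
    parser.add_argument('--dropout_steps', '-dropout_steps', type=int, nargs='+', default=[0])

    print(parser.parse_args([]).dropout)        # [0.3]  (the default is already a list)
    opt = parser.parse_args('-dropout 0.3 0.1 -dropout_steps 0 8000'.split())
    print(opt.dropout, opt.dropout_steps)       # [0.3, 0.1] [0, 8000]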
20 changes: 18 additions & 2 deletions onmt/trainer.py
@@ -46,6 +46,8 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
n_gpu = opt.world_size
average_decay = opt.average_decay
average_every = opt.average_every
dropout = opt.dropout
dropout_steps = opt.dropout_steps
if device_id >= 0:
gpu_rank = opt.gpu_ranks[device_id]
else:
@@ -67,7 +69,9 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
average_decay=average_decay,
average_every=average_every,
model_dtype=opt.model_dtype,
earlystopper=earlystopper)
earlystopper=earlystopper,
dropout=dropout,
dropout_steps=dropout_steps)
return trainer


@@ -104,7 +108,7 @@ def __init__(self, model, train_loss, valid_loss, optim,
n_gpu=1, gpu_rank=1,
gpu_verbose_level=0, report_manager=None, model_saver=None,
average_decay=0, average_every=1, model_dtype='fp32',
earlystopper=None):
earlystopper=None, dropout=[0.3], dropout_steps=[0]):
# Basic attributes.
self.model = model
self.train_loss = train_loss
@@ -126,6 +130,8 @@ def __init__(self, model, train_loss, valid_loss, optim,
self.average_every = average_every
self.model_dtype = model_dtype
self.earlystopper = earlystopper
self.dropout = dropout
self.dropout_steps = dropout_steps

for i in range(len(self.accum_count_l)):
assert self.accum_count_l[i] > 0
@@ -143,6 +149,13 @@ def _accum_count(self, step):
_accum = self.accum_count_l[i]
return _accum

def _maybe_update_dropout(self, step):
for i in range(len(self.dropout_steps)):
if step > 1 and step == self.dropout_steps[i] + 1:
self.model.update_dropout(self.dropout[i])
logger.info("Updated dropout to %f from step %d"
% (self.dropout[i], step))

def _accum_batches(self, iterator):
batches = []
normalization = 0
@@ -216,6 +229,9 @@ def train(self,
self._accum_batches(train_iter)):
step = self.optim.training_step

# UPDATE DROPOUT
self._maybe_update_dropout(step)

if self.gpu_verbose_level > 1:
logger.info("GpuRank %d: index: %d", self.gpu_rank, i)
if self.gpu_verbose_level > 0:
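The guard step > 1 and step == self.dropout_steps[i] + 1 means the first scheduled value (step 0 by default) is never re-applied, since it was already baked in when the model was built, and every later value takes effect on the first step after its threshold, at which point model.update_dropout pushes it through the whole network and the change is logged. A standalone sketch of just the scheduling condition (values are illustrative):

    dropout = [0.3, 0.1]         # from -dropout
    dropout_steps = [0, 8000]    # from -dropout_steps

    def scheduled_rate(step):
        # mirrors the condition in Trainer._maybe_update_dropout; the real
        # method calls self.model.update_dropout(rate) and logs the change
        for rate, start in zip(dropout, dropout_steps):
            if step > 1 and step == start + 1:
                return rate
        return None

    print(scheduled_rate(1))      # None -> 0.3 was already set at build time
    print(scheduled_rate(8001))   # 0.1  -> applied on the first step after 8000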
4 changes: 4 additions & 0 deletions onmt/utils/parse.py
@@ -84,6 +84,7 @@ def validate_train_opts(cls, opt):
"-epochs is deprecated please use -train_steps.")
if opt.truncated_decoder > 0 and max(opt.accum_count) > 1:
raise AssertionError("BPTT is not compatible with -accum > 1")

if opt.gpuid:
raise AssertionError(
"gpuid is deprecated see world_size and gpu_ranks")
@@ -100,6 +101,9 @@ def validate_train_opts(cls, opt):
"-gpu_ranks should have master(=0) rank "
"unless -world_size is greater than len(gpu_ranks).")

assert len(opt.dropout) == len(opt.dropout_steps), \
"Number of dropout values must match number of accum_steps"

@classmethod
def validate_translate_opts(cls, opt):
if opt.beam_size != 1 and opt.random_sampling_topk != 1:
