Disable attention (#1314)
* Allow disabling attention.
* Test no/different copy attn, integrate copy into CNN dec.
* Update docs.
* Add back the 'empty' model test.
flauted authored and vince62s committed Feb 19, 2019
1 parent 518a19d commit 9865443
Showing 8 changed files with 93 additions and 28 deletions.
9 changes: 6 additions & 3 deletions onmt/decoders/cnn_decoder.py
@@ -19,7 +19,8 @@ class CNNDecoder(DecoderBase):
"""

def __init__(self, num_layers, hidden_size, attn_type,
copy_attn, cnn_kernel_width, dropout, embeddings):
copy_attn, cnn_kernel_width, dropout, embeddings,
copy_attn_type):
super(CNNDecoder, self).__init__()

self.cnn_kernel_width = cnn_kernel_width
@@ -42,7 +43,8 @@ def __init__(self, num_layers, hidden_size, attn_type,
# Set up a separate copy attention layer if needed.
assert not copy_attn, "Copy mechanism not yet tested in conv2conv"
if copy_attn:
self.copy_attn = GlobalAttention(hidden_size, attn_type=attn_type)
self.copy_attn = GlobalAttention(
hidden_size, attn_type=copy_attn_type)
else:
self.copy_attn = None

@@ -56,7 +58,8 @@ def from_opt(cls, opt, embeddings):
opt.copy_attn,
opt.cnn_kernel_width,
opt.dropout,
embeddings)
embeddings,
opt.copy_attn_type)

def init_state(self, _, memory_bank, enc_hidden):
"""Init decoder state."""
73 changes: 52 additions & 21 deletions onmt/decoders/decoder.py
@@ -9,7 +9,16 @@


class DecoderBase(nn.Module):
"""Abstract class for decoders."""
"""Abstract class for decoders.
Args:
attentional (bool): The decoder returns non-empty attention.
"""

def __init__(self, attentional=True):
super(DecoderBase, self).__init__()
self.attentional = attentional

@classmethod
def from_opt(cls, opt, embeddings):
"""Alternate constructor.
@@ -67,14 +76,17 @@ class RNNDecoderBase(DecoderBase):
dropout (float) : dropout value for :class:`torch.nn.Dropout`
embeddings (onmt.modules.Embeddings): embedding module to use
reuse_copy_attn (bool): reuse the attention for copying
copy_attn_type (str): The copy attention style. See
:class:`~onmt.modules.GlobalAttention`.
"""

def __init__(self, rnn_type, bidirectional_encoder, num_layers,
hidden_size, attn_type="general", attn_func="softmax",
coverage_attn=False, context_gate=None,
copy_attn=False, dropout=0.0, embeddings=None,
reuse_copy_attn=False):
super(RNNDecoderBase, self).__init__()
reuse_copy_attn=False, copy_attn_type="general"):
super(RNNDecoderBase, self).__init__(
attentional=attn_type != "none" and attn_type is not None)

self.bidirectional_encoder = bidirectional_encoder
self.num_layers = num_layers
@@ -102,19 +114,29 @@ def __init__(self, rnn_type, bidirectional_encoder, num_layers,

# Set up the standard attention.
self._coverage = coverage_attn
self.attn = GlobalAttention(
hidden_size, coverage=coverage_attn,
attn_type=attn_type, attn_func=attn_func
)
if not self.attentional:
if self._coverage:
raise ValueError("Cannot use coverage term with no attention.")
self.attn = None
else:
self.attn = GlobalAttention(
hidden_size, coverage=coverage_attn,
attn_type=attn_type, attn_func=attn_func
)

if copy_attn and not reuse_copy_attn:
if copy_attn_type == "none" or copy_attn_type is None:
raise ValueError(
"Cannot use copy_attn with copy_attn_type none")
self.copy_attn = GlobalAttention(
hidden_size, attn_type=attn_type, attn_func=attn_func
hidden_size, attn_type=copy_attn_type, attn_func=attn_func
)
else:
self.copy_attn = None

self._reuse_copy_attn = reuse_copy_attn and copy_attn
if self._reuse_copy_attn and not self.attentional:
raise ValueError("Cannot reuse copy attention with no attention.")

@classmethod
def from_opt(cls, opt, embeddings):
@@ -131,7 +153,8 @@ def from_opt(cls, opt, embeddings):
opt.copy_attn,
opt.dropout,
embeddings,
opt.reuse_copy_attn)
opt.reuse_copy_attn,
opt.copy_attn_type)

def init_state(self, src, memory_bank, encoder_final):
"""Initialize decoder state with last state of the encoder."""
@@ -266,12 +289,15 @@ def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
aeq(tgt_batch, output_batch)

# Calculate the attention.
dec_outs, p_attn = self.attn(
rnn_output.transpose(0, 1).contiguous(),
memory_bank.transpose(0, 1),
memory_lengths=memory_lengths
)
attns["std"] = p_attn
if not self.attentional:
dec_outs = rnn_output
else:
dec_outs, p_attn = self.attn(
rnn_output.transpose(0, 1).contiguous(),
memory_bank.transpose(0, 1),
memory_lengths=memory_lengths
)
attns["std"] = p_attn

# Calculate the context gate.
if self.context_gate is not None:
@@ -335,7 +361,9 @@ def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
# END Additional args check.

dec_outs = []
attns = {"std": []}
attns = {}
if self.attn is not None:
attns["std"] = []
if self.copy_attn is not None or self._reuse_copy_attn:
attns["copy"] = []
if self._coverage:
@@ -353,10 +381,14 @@ def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
for emb_t in emb.split(1):
decoder_input = torch.cat([emb_t.squeeze(0), input_feed], 1)
rnn_output, dec_state = self.rnn(decoder_input, dec_state)
decoder_output, p_attn = self.attn(
rnn_output,
memory_bank.transpose(0, 1),
memory_lengths=memory_lengths)
if self.attentional:
decoder_output, p_attn = self.attn(
rnn_output,
memory_bank.transpose(0, 1),
memory_lengths=memory_lengths)
attns["std"].append(p_attn)
else:
decoder_output = rnn_output
if self.context_gate is not None:
# TODO: context gate should be employed
# instead of second RNN transform.
Expand All @@ -367,7 +399,6 @@ def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
input_feed = decoder_output

dec_outs += [decoder_output]
attns["std"] += [p_attn]

# Update the coverage attention.
if self._coverage:
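For context, a minimal sketch of how the new `attentional` flag is expected to surface when a decoder is built by hand. This snippet is illustrative and not part of the commit; it assumes the standard `Embeddings` constructor and the `StdRNNDecoder` signature shown above. With `attn_type` set to `"none"` (or `None`), the decoder reports itself as non-attentional, builds no `GlobalAttention` module, and its forward pass returns no `"std"` alignments.

```python
# Illustrative only -- not part of this commit. Assumes onmt.modules.Embeddings
# and onmt.decoders.decoder.StdRNNDecoder with the signatures shown above.
from onmt.decoders.decoder import StdRNNDecoder
from onmt.modules import Embeddings

emb = Embeddings(word_vec_size=8, word_vocab_size=50, word_padding_idx=1)
dec = StdRNNDecoder("LSTM", bidirectional_encoder=False, num_layers=1,
                    hidden_size=8, attn_type="none", embeddings=emb)

assert dec.attentional is False   # set by DecoderBase.__init__
assert dec.attn is None           # no GlobalAttention module is built
```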
3 changes: 2 additions & 1 deletion onmt/modules/global_attention.py
@@ -74,7 +74,8 @@ def __init__(self, dim, coverage=False, attn_type="dot",

self.dim = dim
assert attn_type in ["dot", "general", "mlp"], (
"Please select a valid attention type.")
"Please select a valid attention type (got {:s}).".format(
attn_type))
self.attn_type = attn_type
assert attn_func in ["softmax", "sparsemax"], (
"Please select a valid attention function.")
8 changes: 7 additions & 1 deletion onmt/opts.py
@@ -131,7 +131,8 @@ def model_opts(parser):
# Attention options
group = parser.add_argument_group('Model- Attention')
group.add('--global_attention', '-global_attention',
type=str, default='general', choices=['dot', 'general', 'mlp'],
type=str, default='general',
choices=['dot', 'general', 'mlp', 'none'],
help="The attention type to use: "
"dotprod or general (Luong) or MLP (Bahdanau)")
group.add('--global_attention_function', '-global_attention_function',
@@ -154,6 +155,11 @@
# Generator and loss options.
group.add('--copy_attn', '-copy_attn', action="store_true",
help='Train copy attention layer.')
group.add('--copy_attn_type', '-copy_attn_type',
type=str, default=None,
choices=['dot', 'general', 'mlp', 'none'],
help="The copy attention type to use. Leave as None to use "
"the same as -global_attention.")
group.add('--generator_function', '-generator_function', default="softmax",
choices=["softmax", "sparsemax"],
help="Which function to use for generating "
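A quick sanity check of the new flags, written as a sketch rather than a definitive recipe: it assumes the option parser is `onmt.utils.parse.ArgumentParser` (as used by the training entry point in this version) and that no other option group is required to parse model options in isolation.

```python
# Sketch only: exercises the -global_attention none and -copy_attn_type flags
# added above. Assumes onmt.utils.parse.ArgumentParser and onmt.opts.model_opts.
import onmt.opts as opts
from onmt.utils.parse import ArgumentParser

parser = ArgumentParser()
opts.model_opts(parser)
opt, _ = parser.parse_known_args(
    ["-global_attention", "none", "-copy_attn", "-copy_attn_type", "general"])

assert opt.global_attention == "none"
assert opt.copy_attn is True
assert opt.copy_attn_type == "general"
```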
10 changes: 10 additions & 0 deletions onmt/tests/test_models.py
@@ -284,6 +284,16 @@ def test_method(self):
[('encoder_type', "brnn")],
[('decoder_type', 'cnn'),
('encoder_type', 'cnn')],
[('encoder_type', 'rnn'),
('global_attention', None)],
[('encoder_type', 'rnn'),
('global_attention', None),
('copy_attn', True),
('copy_attn_type', 'general')],
[('encoder_type', 'rnn'),
('global_attention', 'mlp'),
('copy_attn', True),
('copy_attn_type', 'general')],
[],
]

3 changes: 2 additions & 1 deletion onmt/translate/decode_strategy.py
@@ -22,7 +22,8 @@ class DecodeStrategy(object):
``block_ngram_repeat``-grams repeat.
exclusion_tokens (set[int]): If a gram contains any of these
tokens, it may repeat.
return_attention (bool): Whether to work with attention too.
return_attention (bool): Whether to work with attention too. If this
is true, it is assumed that the decoder is attentional.
Attributes:
pad (int): See above.
12 changes: 11 additions & 1 deletion onmt/translate/translator.py
@@ -144,6 +144,9 @@ def __init__(
self.src_reader = src_reader
self.tgt_reader = tgt_reader
self.replace_unk = replace_unk
if self.replace_unk and not self.model.decoder.attentional:
raise ValueError(
"replace_unk requires an attentional decoder.")
self.data_type = data_type
self.verbose = verbose
self.report_bleu = report_bleu
@@ -153,6 +156,10 @@
self.copy_attn = copy_attn

self.global_scorer = global_scorer
if self.global_scorer.has_cov_pen and \
not self.model.decoder.attentional:
raise ValueError(
"Coverage penalty requires an attentional decoder.")
self.out_file = out_file
self.report_score = report_score
self.logger = logger
@@ -544,7 +551,10 @@ def _decode_and_generate(

# Generator forward.
if not self.copy_attn:
attn = dec_attn["std"]
if "std" in dec_attn:
attn = dec_attn["std"]
else:
attn = None
log_probs = self.model.generator(dec_out.squeeze(0))
# returns [(batch_size x beam_size) , vocab ] when 1 step
# or [ tgt_len, batch_size, vocab ] when full sentence
3 changes: 3 additions & 0 deletions onmt/utils/parse.py
@@ -43,6 +43,9 @@ def update_model_opts(cls, model_opt):

model_opt.brnn = model_opt.encoder_type == "brnn"

if model_opt.copy_attn_type is None:
model_opt.copy_attn_type = model_opt.global_attention

@classmethod
def validate_model_opts(cls, model_opt):
assert model_opt.model_type in ["text", "img", "audio"], \
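The fallback above means `-copy_attn_type` only needs to be set when the copy attention should differ from the global attention type. A minimal illustration of the resolution, using a hypothetical stand-in for the parsed options namespace:

```python
# Minimal illustration of the default resolution added above; `opt` is a
# hypothetical stand-in for the parsed model options, not the real namespace.
from types import SimpleNamespace

opt = SimpleNamespace(global_attention="mlp", copy_attn_type=None)
if opt.copy_attn_type is None:
    opt.copy_attn_type = opt.global_attention

assert opt.copy_attn_type == "mlp"  # copy attention reuses the global type
```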
