Skip to content

Commit

Permalink
tgt/src ratio based beam stopping condition (OpenNMT#1344)
Browse files Browse the repository at this point in the history
* tgt/src ratio based beam stopping condition
* Loading branch information
vince62s authored Mar 8, 2019
1 parent 5f08809 commit 9b3083f
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 11 deletions.
2 changes: 2 additions & 0 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,8 @@ def translate_opts(parser):
group.add('--length_penalty', '-length_penalty', default='none',
choices=['none', 'wu', 'avg'],
help="Length Penalty to use.")
group.add('--ratio', '-ratio', type=float, default=-0.,
help="Ratio based beam stop condition")
group.add('--coverage_penalty', '-coverage_penalty', default='none',
choices=['none', 'wu', 'summary'],
help="Coverage Penalty to use.")
Expand Down
16 changes: 8 additions & 8 deletions onmt/tests/test_beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_advance_with_all_repeats_gets_blocked(self):
beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(), 0, 30,
False, ngram_repeat, set(),
torch.randint(0, 30, (batch_sz,)), False)
torch.randint(0, 30, (batch_sz,)), False, 0.)
for i in range(ngram_repeat + 4):
# predict repeat_idx over and over again
word_probs = torch.full(
Expand Down Expand Up @@ -69,7 +69,7 @@ def test_advance_with_some_repeats_gets_blocked(self):
beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(), 0, 30,
False, ngram_repeat, set(),
torch.randint(0, 30, (batch_sz,)), False)
torch.randint(0, 30, (batch_sz,)), False, 0.)
for i in range(ngram_repeat + 4):
# non-interesting beams are going to get dummy values
word_probs = torch.full(
Expand Down Expand Up @@ -116,7 +116,7 @@ def test_repeating_excluded_index_does_not_die(self):
beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(), 0, 30,
False, ngram_repeat, {repeat_idx_ignored},
torch.randint(0, 30, (batch_sz,)), False)
torch.randint(0, 30, (batch_sz,)), False, 0.)
for i in range(ngram_repeat + 4):
# non-interesting beams are going to get dummy values
word_probs = torch.full(
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(),
min_length, 30, False, 0, set(),
lengths, False)
lengths, False, 0.)
all_attns = []
for i in range(min_length + 4):
# non-interesting beams are going to get dummy values
Expand Down Expand Up @@ -228,7 +228,7 @@ def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(),
min_length, 30, False, 0, set(),
torch.randint(0, 30, (batch_sz,)), False)
torch.randint(0, 30, (batch_sz,)), False, 0.)
for i in range(min_length + 4):
# non-interesting beams are going to get dummy values
word_probs = torch.full(
Expand Down Expand Up @@ -286,7 +286,7 @@ def test_beam_returns_attn_with_correct_length(self):
beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(),
min_length, 30, True, 0, set(),
inp_lens, False)
inp_lens, False, 0.)
for i in range(min_length + 2):
# non-interesting beams are going to get dummy values
word_probs = torch.full(
Expand Down Expand Up @@ -497,7 +497,7 @@ def test_beam_advance_against_known_reference(self):
self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
torch.device("cpu"), GlobalScorerStub(),
0, 30, False, 0, set(),
torch.randint(0, 30, (self.BATCH_SZ,)), False)
torch.randint(0, 30, (self.BATCH_SZ,)), False, 0.)

expected_beam_scores = self.init_step(beam, 1)
expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
Expand All @@ -515,7 +515,7 @@ def test_beam_advance_against_known_reference(self):
self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
torch.device("cpu"), scorer,
0, 30, False, 0, set(),
torch.randint(0, 30, (self.BATCH_SZ,)), False)
torch.randint(0, 30, (self.BATCH_SZ,)), False, 0.)
expected_beam_scores = self.init_step(beam, 1.)
expected_beam_scores = self.first_step(beam, expected_beam_scores, 3)
expected_beam_scores = self.second_step(beam, expected_beam_scores, 4)
Expand Down
20 changes: 17 additions & 3 deletions onmt/translate/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class BeamSearch(DecodeStrategy):
def __init__(self, beam_size, batch_size, pad, bos, eos, n_best, mb_device,
global_scorer, min_length, max_length, return_attention,
block_ngram_repeat, exclusion_tokens, memory_lengths,
stepwise_penalty):
stepwise_penalty, ratio):
super(BeamSearch, self).__init__(
pad, bos, eos, batch_size, mb_device, beam_size, min_length,
block_ngram_repeat, exclusion_tokens, return_attention,
Expand All @@ -66,12 +66,16 @@ def __init__(self, beam_size, batch_size, pad, bos, eos, n_best, mb_device,
self.beam_size = beam_size
self.n_best = n_best
self.batch_size = batch_size
self.ratio = ratio

# result caching
self.hypotheses = [[] for _ in range(batch_size)]

# beam state
self.top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8)
self.best_scores = torch.full([batch_size], -1e10, dtype=torch.float,
device=mb_device)

self._batch_offset = torch.arange(batch_size, dtype=torch.long)
self._beam_offset = torch.arange(
0, batch_size * beam_size, step=beam_size, dtype=torch.long,
Expand Down Expand Up @@ -213,15 +217,25 @@ def update_finished(self):
finished_hyp = self.is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
for j in finished_hyp:
if self.ratio > 0:
s = self.topk_scores[i, j] / (step + 1)
if self.best_scores[b] < s:
self.best_scores[b] = s
self.hypotheses[b].append((
self.topk_scores[i, j],
predictions[i, j, 1:], # Ignore start_token.
attention[:, i, j, :self._memory_lengths[i]]
if attention is not None else None))
# End condition is the top beam finished and we can return
# n_best hypotheses.
if self.top_beam_finished[i] and len(
self.hypotheses[b]) >= self.n_best:
if self.ratio > 0:
pred_len = self._memory_lengths[i] * self.ratio
finish_flag = ((self.topk_scores[i, 0] / pred_len)
<= self.best_scores[b]) or \
self.is_finished[i].all()
else:
finish_flag = self.top_beam_finished[i] != 0
if finish_flag and len(self.hypotheses[b]) >= self.n_best:
best_hyp = sorted(
self.hypotheses[b], key=lambda x: x[0], reverse=True)
for n, (score, pred, attn) in enumerate(best_hyp):
Expand Down
6 changes: 6 additions & 0 deletions onmt/translate/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
n_best=1,
min_length=0,
max_length=100,
ratio=0.,
beam_size=30,
random_sampling_topk=1,
random_sampling_temp=1,
Expand Down Expand Up @@ -135,6 +136,7 @@ def __init__(
self.sample_from_topk = random_sampling_topk

self.min_length = min_length
self.ratio = ratio
self.stepwise_penalty = stepwise_penalty
self.dump_beam = dump_beam
self.block_ngram_repeat = block_ngram_repeat
Expand Down Expand Up @@ -218,6 +220,7 @@ def from_opt(
n_best=opt.n_best,
min_length=opt.min_length,
max_length=opt.max_length,
ratio=opt.ratio,
beam_size=opt.beam_size,
random_sampling_topk=opt.random_sampling_topk,
random_sampling_temp=opt.random_sampling_temp,
Expand Down Expand Up @@ -507,6 +510,7 @@ def translate_batch(self, batch, src_vocabs, attn_debug):
src_vocabs,
self.max_length,
min_length=self.min_length,
ratio=self.ratio,
n_best=self.n_best,
return_attention=attn_debug or self.replace_unk)

Expand Down Expand Up @@ -588,6 +592,7 @@ def _translate_batch(
src_vocabs,
max_length,
min_length=0,
ratio=0.,
n_best=1,
return_attention=False):
# TODO: support these blacklisted features.
Expand Down Expand Up @@ -636,6 +641,7 @@ def _translate_batch(
eos=self._tgt_eos_idx,
bos=self._tgt_bos_idx,
min_length=min_length,
ratio=ratio,
max_length=max_length,
mb_device=mb_device,
return_attention=return_attention,
Expand Down

0 comments on commit 9b3083f

Please sign in to comment.