From a3eaeb686bb7ed8b1521243111dc2eb771889b2b Mon Sep 17 00:00:00 2001
From: flauted
Date: Tue, 12 Feb 2019 02:53:27 -0500
Subject: [PATCH] Random sampling scores (#1285)

* In random sampling, make scores be score of EOS.
---
 onmt/tests/test_random_sampling.py | 313 +++++++++++++++++++++++++++++
 onmt/translate/decode_strategy.py  |   4 +-
 onmt/translate/random_sampling.py  |  37 +++-
 onmt/translate/translator.py       |  24 ++-
 4 files changed, 365 insertions(+), 13 deletions(-)
 create mode 100644 onmt/tests/test_random_sampling.py

diff --git a/onmt/tests/test_random_sampling.py b/onmt/tests/test_random_sampling.py
new file mode 100644
index 0000000000..da3a7116fb
--- /dev/null
+++ b/onmt/tests/test_random_sampling.py
@@ -0,0 +1,313 @@
+import unittest
+from onmt.translate.random_sampling import RandomSampling
+
+import torch
+
+
+class TestRandomSampling(unittest.TestCase):
+    BATCH_SZ = 3
+    INP_SEQ_LEN = 53
+    DEAD_SCORE = -1e20
+
+    BLOCKED_SCORE = -10e20
+
+    def test_advance_with_repeats_gets_blocked(self):
+        n_words = 100
+        repeat_idx = 47
+        ngram_repeat = 3
+        for batch_sz in [1, 3]:
+            samp = RandomSampling(
+                0, 1, 2, batch_sz, torch.device("cpu"), 0, ngram_repeat, set(),
+                False, 30, 1., 5, torch.randint(0, 30, (batch_sz,)))
+            for i in range(ngram_repeat + 4):
+                # predict repeat_idx over and over again
+                word_probs = torch.full(
+                    (batch_sz, n_words), -float('inf'))
+                word_probs[:, repeat_idx] = 0
+                attns = torch.randn(1, batch_sz, 53)
+                samp.advance(word_probs, attns)
+                if i <= ngram_repeat:
+                    expected_scores = torch.zeros((batch_sz, 1))
+                    self.assertTrue(samp.topk_scores.equal(expected_scores))
+                else:
+                    self.assertTrue(
+                        samp.topk_scores.equal(
+                            torch.tensor(self.BLOCKED_SCORE)
+                            .repeat(batch_sz, 1)))
+
+    def test_advance_with_some_repeats_gets_blocked(self):
+        # batch 0 and 7 will repeat, the rest will advance
+        n_words = 100
+        repeat_idx = 47
+        other_repeat_idx = 12
+        ngram_repeat = 3
+        for batch_sz in [1, 3, 13]:
+            samp = RandomSampling(
+                0, 1, 2, batch_sz, torch.device("cpu"), 0, ngram_repeat, set(),
+                False, 30, 1., 5, torch.randint(0, 30, (batch_sz,)))
+            for i in range(ngram_repeat + 4):
+                word_probs = torch.full(
+                    (batch_sz, n_words), -float('inf'))
+                # predict the same thing in batch 0 and 7 every i
+                word_probs[0, repeat_idx] = 0
+                if batch_sz > 7:
+                    word_probs[7, other_repeat_idx] = 0
+                # push around what the other batches predict
+                word_probs[1:7, repeat_idx + i] = 0
+                if batch_sz > 7:
+                    word_probs[8:, repeat_idx + i] = 0
+                attns = torch.randn(1, batch_sz, 53)
+                samp.advance(word_probs, attns)
+                if i <= ngram_repeat:
+                    self.assertFalse(
+                        samp.topk_scores.eq(
+                            self.BLOCKED_SCORE).any())
+                else:
+                    # now batch 0 and 7 die
+                    self.assertTrue(samp.topk_scores[0].eq(self.BLOCKED_SCORE))
+                    if batch_sz > 7:
+                        self.assertTrue(samp.topk_scores[7].eq(
+                            self.BLOCKED_SCORE))
+                    self.assertFalse(
+                        samp.topk_scores[1:7].eq(
+                            self.BLOCKED_SCORE).any())
+                    if batch_sz > 7:
+                        self.assertFalse(
+                            samp.topk_scores[8:].eq(
+                                self.BLOCKED_SCORE).any())
+
+    def test_repeating_excluded_index_does_not_die(self):
+        # batch 0 will repeat excluded idx, batch 1 will repeat
+        n_words = 100
+        repeat_idx = 47  # will be repeated and should be blocked
+        repeat_idx_ignored = 7  # will be repeated and should not be blocked
+        ngram_repeat = 3
+        for batch_sz in [1, 3, 17]:
+            samp = RandomSampling(
+                0, 1, 2, batch_sz, torch.device("cpu"), 0, ngram_repeat,
+                {repeat_idx_ignored}, False, 30, 1., 5,
+                torch.randint(0, 30, (batch_sz,)))
+            for i in range(ngram_repeat + 4):
+                word_probs = torch.full(
+                    (batch_sz, n_words), -float('inf'))
+                word_probs[0, repeat_idx_ignored] = 0
+                if batch_sz > 1:
+                    word_probs[1, repeat_idx] = 0
+                    word_probs[2:, repeat_idx + i] = 0
+                attns = torch.randn(1, batch_sz, 53)
+                samp.advance(word_probs, attns)
+                if i <= ngram_repeat:
+                    self.assertFalse(samp.topk_scores.eq(
+                        self.BLOCKED_SCORE).any())
+                else:
+                    # now batch 1 dies
+                    self.assertFalse(samp.topk_scores[0].eq(
+                        self.BLOCKED_SCORE).any())
+                    if batch_sz > 1:
+                        self.assertTrue(samp.topk_scores[1].eq(
+                            self.BLOCKED_SCORE).all())
+                        self.assertFalse(samp.topk_scores[2:].eq(
+                            self.BLOCKED_SCORE).any())
+
+    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
+        # batch 0 will always predict EOS. The other batches will predict
+        # non-eos scores.
+        for batch_sz in [1, 3]:
+            n_words = 100
+            _non_eos_idxs = [47]
+            valid_score_dist = torch.log_softmax(torch.tensor(
+                [6., 5.]), dim=0)
+            min_length = 5
+            eos_idx = 2
+            lengths = torch.randint(0, 30, (batch_sz,))
+            samp = RandomSampling(
+                0, 1, 2, batch_sz, torch.device("cpu"), min_length,
+                False, set(), False, 30, 1., 1, lengths)
+            all_attns = []
+            for i in range(min_length + 4):
+                word_probs = torch.full(
+                    (batch_sz, n_words), -float('inf'))
+                # "best" prediction is eos - that should be blocked
+                word_probs[0, eos_idx] = valid_score_dist[0]
+                # include at least one prediction OTHER than EOS
+                # that is greater than -1e20
+                word_probs[0, _non_eos_idxs[0]] = valid_score_dist[1]
+                word_probs[1:, _non_eos_idxs[0] + i] = 0
+
+                attns = torch.randn(1, batch_sz, 53)
+                all_attns.append(attns)
+                samp.advance(word_probs, attns)
+                if i < min_length:
+                    self.assertTrue(
+                        samp.topk_scores[0].allclose(valid_score_dist[1]))
+                    self.assertTrue(
+                        samp.topk_scores[1:].eq(0).all())
+                elif i == min_length:
+                    # now batch 0 has ended and no others have
+                    self.assertTrue(samp.is_finished[0, :].eq(1).all())
+                    self.assertTrue(samp.is_finished[1:, 1:].eq(0).all())
+                else:  # i > min_length
+                    break
+
+    def test_returns_correct_scores_deterministic(self):
+        for batch_sz in [1, 13]:
+            for temp in [1., 3.]:
+                n_words = 100
+                _non_eos_idxs = [47, 51, 13, 88, 99]
+                valid_score_dist_1 = torch.log_softmax(torch.tensor(
+                    [6., 5., 4., 3., 2., 1.]), dim=0)
+                valid_score_dist_2 = torch.log_softmax(torch.tensor(
+                    [6., 1.]), dim=0)
+                eos_idx = 2
+                lengths = torch.randint(0, 30, (batch_sz,))
+                samp = RandomSampling(
+                    0, 1, 2, batch_sz, torch.device("cpu"), 0,
+                    False, set(), False, 30, temp, 1, lengths)
+
+                # initial step
+                i = 0
+                word_probs = torch.full(
+                    (batch_sz, n_words), -float('inf'))
+                # batch 0 dies on step 0
+                word_probs[0, eos_idx] = valid_score_dist_1[0]
+                # include at least one prediction OTHER than EOS
+                # that is greater than -1e20
+                word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
+                word_probs[1:, _non_eos_idxs[0] + i] = 0
+
+                attns = torch.randn(1, batch_sz, 53)
+                samp.advance(word_probs, attns)
+                self.assertTrue(samp.is_finished[0].eq(1).all())
+                samp.update_finished()
+                self.assertEqual(
+                    samp.scores[0], [valid_score_dist_1[0] / temp])
+                if batch_sz == 1:
+                    self.assertTrue(samp.done)
+                    continue
+                else:
+                    self.assertFalse(samp.done)
+
+                # step 2
+                i = 1
+                word_probs = torch.full(
+                    (batch_sz - 1, n_words), -float('inf'))
+                # (old) batch 8 dies on step 1
+                word_probs[7, eos_idx] = valid_score_dist_2[0]
+                word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
+                word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2
+
+                attns = torch.randn(1, batch_sz, 53)
+                samp.advance(word_probs, attns)
+
+                self.assertTrue(samp.is_finished[7].eq(1).all())
+                samp.update_finished()
+                self.assertEqual(
+                    samp.scores[8], [valid_score_dist_2[0] / temp])
+
+                # step 3
+                i = 2
+                word_probs = torch.full(
+                    (batch_sz - 2, n_words), -float('inf'))
+                # everything dies
+                word_probs[:, eos_idx] = 0
+
+                attns = torch.randn(1, batch_sz, 53)
+                samp.advance(word_probs, attns)
+
+                self.assertTrue(samp.is_finished.eq(1).all())
+                samp.update_finished()
+                for b in range(batch_sz):
+                    if b != 0 and b != 8:
+                        self.assertEqual(samp.scores[b], [0])
+                self.assertTrue(samp.done)
+
+    def test_returns_correct_scores_non_deterministic(self):
+        for batch_sz in [1, 13]:
+            for temp in [1., 3.]:
+                n_words = 100
+                _non_eos_idxs = [47, 51, 13, 88, 99]
+                valid_score_dist_1 = torch.log_softmax(torch.tensor(
+                    [6., 5., 4., 3., 2., 1.]), dim=0)
+                valid_score_dist_2 = torch.log_softmax(torch.tensor(
+                    [6., 1.]), dim=0)
+                eos_idx = 2
+                lengths = torch.randint(0, 30, (batch_sz,))
+                samp = RandomSampling(
+                    0, 1, 2, batch_sz, torch.device("cpu"), 0,
+                    False, set(), False, 30, temp, 2, lengths)
+
+                # initial step
+                i = 0
+                for _ in range(100):
+                    word_probs = torch.full(
+                        (batch_sz, n_words), -float('inf'))
+                    # batch 0 dies on step 0
+                    word_probs[0, eos_idx] = valid_score_dist_1[0]
+                    # include at least one prediction OTHER than EOS
+                    # that is greater than -1e20
+                    word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
+                    word_probs[1:, _non_eos_idxs[0] + i] = 0
+
+                    attns = torch.randn(1, batch_sz, 53)
+                    samp.advance(word_probs, attns)
+                    if samp.is_finished[0].eq(1).all():
+                        break
+                else:
+                    self.fail("Batch 0 never ended (very unlikely but maybe "
+                              "due to stochasticity). If so, please increase "
+                              "the range of the for-loop.")
+                samp.update_finished()
+                self.assertEqual(
+                    samp.scores[0], [valid_score_dist_1[0] / temp])
+                if batch_sz == 1:
+                    self.assertTrue(samp.done)
+                    continue
+                else:
+                    self.assertFalse(samp.done)
+
+                # step 2
+                i = 1
+                for _ in range(100):
+                    word_probs = torch.full(
+                        (batch_sz - 1, n_words), -float('inf'))
+                    # (old) batch 8 dies on step 1
+                    word_probs[7, eos_idx] = valid_score_dist_2[0]
+                    word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
+                    word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2
+
+                    attns = torch.randn(1, batch_sz, 53)
+                    samp.advance(word_probs, attns)
+                    if samp.is_finished[7].eq(1).all():
+                        break
+                else:
+                    self.fail("Batch 8 never ended (very unlikely but maybe "
+                              "due to stochasticity). If so, please increase "
+                              "the range of the for-loop.")
+
+                samp.update_finished()
+                self.assertEqual(
+                    samp.scores[8], [valid_score_dist_2[0] / temp])
+
+                # step 3
+                i = 2
+                for _ in range(250):
+                    word_probs = torch.full(
+                        (samp.alive_seq.shape[0], n_words), -float('inf'))
+                    # everything dies
+                    word_probs[:, eos_idx] = 0
+
+                    attns = torch.randn(1, batch_sz, 53)
+                    samp.advance(word_probs, attns)
+                    if samp.is_finished.any():
+                        samp.update_finished()
+                    if samp.is_finished.eq(1).all():
+                        break
+                else:
+                    self.fail("All batches never ended (very unlikely but "
+                              "maybe due to stochasticity). If so, please "
+                              "increase the range of the for-loop.")
+
+                for b in range(batch_sz):
+                    if b != 0 and b != 8:
+                        self.assertEqual(samp.scores[b], [0])
+                self.assertTrue(samp.done)
diff --git a/onmt/translate/decode_strategy.py b/onmt/translate/decode_strategy.py
index c86a9b7e94..b7e9a3a45a 100644
--- a/onmt/translate/decode_strategy.py
+++ b/onmt/translate/decode_strategy.py
@@ -31,9 +31,9 @@ class DecodeStrategy(object):
         predictions (list[list[torch.LongTensor]]): For each batch, holds a
             list of beam prediction sequences.
         scores (list[list[torch.FloatTensor]]): For each batch, holds a
-            list of beam scores.
+            list of scores.
         attention (list[list[torch.FloatTensor or list[]]]): For each
-            batch, holds a list of beam attention sequence tensors
+            batch, holds a list of attention sequence tensors
             (or empty lists) having shape ``(step, inp_seq_len)`` where
             ``inp_seq_len`` is the length of the sample (not the max
             length of all inp seqs).
diff --git a/onmt/translate/random_sampling.py b/onmt/translate/random_sampling.py
index 27c155164c..a78c7c24f3 100644
--- a/onmt/translate/random_sampling.py
+++ b/onmt/translate/random_sampling.py
@@ -32,6 +32,8 @@ def sample_with_temperature(logits, sampling_temp, keep_topk):
         # For temp=0.0, take the argmax to avoid divide-by-zero errors.
         # keep_topk=1 is also equivalent to argmax.
         topk_scores, topk_ids = logits.topk(1, dim=-1)
+        if sampling_temp > 0:
+            topk_scores /= sampling_temp
     else:
         logits = torch.div(logits, sampling_temp)
 
@@ -55,6 +57,10 @@ def sample_with_temperature(logits, sampling_temp, keep_topk):
 class RandomSampling(DecodeStrategy):
     """Select next tokens randomly from the top k possible next tokens.
 
+    The lists in the ``scores`` attribute hold the score, after applying
+    temperature, of the final prediction (either EOS, or the final token
+    in the event that ``max_length`` is reached).
+
     Args:
         pad (int): See base.
         bos (int): See base.
@@ -73,11 +79,6 @@ class RandomSampling(DecodeStrategy):
             masking attention.
     """
 
-    # NOTE: Currently this class doesn't return "final" scores or any form
-    # of Pr(EOS|pred). That is to say, the scores returned by RandomSampling
-    # # are just the scores of the last token (in a batched setting, that
-    # isn't even necessarily EOS since no early stopping is implemented).
-
     def __init__(self, pad, bos, eos, batch_size, device, min_length,
                  block_ngram_repeat, exclusion_tokens, return_attention,
                  max_length, sampling_temp, keep_topk,
@@ -93,6 +94,8 @@ def __init__(self, pad, bos, eos, batch_size, device,
         self.batch_size = batch_size
         self.select_indices = torch.arange(self.batch_size,
                                            dtype=torch.long, device=device)
+        self.original_batch_idx = torch.arange(self.batch_size,
+                                               dtype=torch.long, device=device)
 
     def advance(self, log_probs, attn):
         """Select next tokens randomly from the top k possible next tokens.
@@ -108,9 +111,12 @@ def advance(self, log_probs, attn):
         """
 
         self.ensure_min_length(log_probs)
+        self.block_ngram_repeats(log_probs)
         topk_ids, self.topk_scores = sample_with_temperature(
             log_probs, self.sampling_temp, self.keep_topk)
 
+        self.is_finished = topk_ids.eq(self.eos)
+
         self.alive_seq = torch.cat([self.alive_seq, topk_ids], -1)
         if self.return_attention:
             if self.alive_attn is None:
@@ -121,10 +127,21 @@ def update_finished(self):
         """Finalize scores and predictions."""
-        assert self.is_finished.all()
-        for b in range(self.batch_size):
-            self.scores[b].append(self.topk_scores[b, 0])
-            self.predictions[b].append(self.alive_seq[b, 1:])
-            self.attention[b].append(
+        # shape: (sum(self.is_finished), 1)
+        finished_batches = self.is_finished.view(-1).nonzero()
+        for b in finished_batches.view(-1):
+            b_orig = self.original_batch_idx[b]
+            self.scores[b_orig].append(self.topk_scores[b, 0])
+            self.predictions[b_orig].append(self.alive_seq[b, 1:])
+            self.attention[b_orig].append(
                 self.alive_attn[:, b, :self.memory_length[b]]
                 if self.alive_attn is not None else [])
+        self.done = self.is_finished.all()
+        if self.done:
+            return
+        is_alive = ~self.is_finished.view(-1)
+        self.alive_seq = self.alive_seq[is_alive]
+        if self.alive_attn is not None:
+            self.alive_attn = self.alive_attn[:, is_alive]
+        self.select_indices = is_alive.nonzero().view(-1)
+        self.original_batch_idx = self.original_batch_idx[is_alive]
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 9e0069c6a1..d18cdcfc58 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -361,8 +361,30 @@ def _translate_random_sampling(
             )
 
             random_sampler.advance(log_probs, attn)
+            any_batch_is_finished = random_sampler.is_finished.any()
+            if any_batch_is_finished:
+                random_sampler.update_finished()
+                if random_sampler.done:
+                    break
+
+            if any_batch_is_finished:
+                select_indices = random_sampler.select_indices
+
+                # Reorder states.
+                if isinstance(memory_bank, tuple):
+                    memory_bank = tuple(x.index_select(1, select_indices)
+                                        for x in memory_bank)
+                else:
+                    memory_bank = memory_bank.index_select(1, select_indices)
+
+                memory_lengths = memory_lengths.index_select(0, select_indices)
+
+                if src_map is not None:
+                    src_map = src_map.index_select(1, select_indices)
+
+                self.model.decoder.map_state(
+                    lambda state, dim: state.index_select(dim, select_indices))
 
-        random_sampler.update_finished()
         results["scores"] = random_sampler.scores
         results["predictions"] = random_sampler.predictions
         results["attention"] = random_sampler.attention
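
For reference (not part of the patch itself): a minimal sketch of the behavior this change establishes. After update_finished(), each list in RandomSampling.scores holds the temperature-scaled score of the step that finished the sequence (EOS, or the final step if max_length is reached), keyed by the original batch index even when finished rows are filtered out mid-decode. The vocabulary size, token indices, and source lengths below are made up for illustration; the constructor arguments follow the signature used in the tests above.

import torch

from onmt.translate.random_sampling import RandomSampling

batch_size = 2
n_words = 10  # toy vocabulary: pad=0, bos=1, eos=2, plus a few word ids
temp = 2.0
lengths = torch.randint(0, 30, (batch_size,))  # dummy source lengths

samp = RandomSampling(
    0, 1, 2, batch_size, torch.device("cpu"),
    0,       # min_length: allow EOS on the first step
    0,       # block_ngram_repeat: disabled
    set(),   # exclusion_tokens
    False,   # return_attention
    30,      # max_length
    temp,    # sampling_temp
    1,       # keep_topk=1 -> argmax, so the sketch is deterministic
    lengths)

# Make EOS the best choice for every batch entry on the first step.
log_probs = torch.full((batch_size, n_words), -float('inf'))
log_probs[:, 2] = -0.5   # EOS
log_probs[:, 4] = -3.0   # some other token, so rows aren't all -inf
attn = torch.randn(1, batch_size, 5)

samp.advance(log_probs, attn)
samp.update_finished()

# Each entry is the EOS score divided by the temperature (-0.5 / 2.0),
# rather than whatever the last sampled token happened to score.
print(samp.scores)  # roughly [[tensor(-0.2500)], [tensor(-0.2500)]]
print(samp.done)    # truthy: every batch entry emitted EOS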