Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

changed beam decoder stopping condition #965

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions tensor2tensor/models/transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,6 @@ def testBeamVsFast(self):
beam_res = beam_result.eval()
fast_res = fast_result.eval()

self.assertEqual(fast_res.shape,
(BATCH_SIZE, INPUT_LENGTH + decode_length))
self.assertAllClose(beam_res, fast_res)

def testTransformerWithoutProblem(self):
Expand Down
37 changes: 20 additions & 17 deletions tensor2tensor/utils/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
finished_flags, states)

def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
finished_scores, finished_in_finished, unused_states):
finished_scores, unused_finished_in_finished, unused_states):
"""Checking termination condition.

We terminate when we decoded up to decode_length or the lowest scoring item
Expand All @@ -472,30 +472,33 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
alive_log_probs: probabilities of the beams. [batch_size, beam_size]
finished_scores: scores for each of these sequences.
[batch_size, beam_size]
finished_in_finished: finished bools for each of these sequences.
[batch_size, beam_size]

Returns:
Bool.
"""
if not stop_early:
return tf.less(i, decode_length)
max_length_penalty = tf.pow(((5. + tf.to_float(decode_length)) / 6.), alpha)
# The best possible score of the most likely alive sequence.
lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty

# Now to compute the lowest score of a finished sequence in finished
# If the sequence isn't finished, we multiply its score by 0. since
# scores are all -ve, taking the min will give us the score of the lowest
# finished item.
lowest_score_of_finished_in_finished = tf.reduce_min(
finished_scores * tf.to_float(finished_in_finished), axis=1)
# If none of the sequences have finished, then the min will be 0 and
# we have to replace it by -ve INF if it is. The score of any seq in alive
# will be much higher than -ve INF and the termination condition will not
# be met.
lowest_score_of_finished_in_finished += (
(1. - tf.to_float(tf.reduce_any(finished_in_finished, 1))) * -INF)
if not stop_early:
# by considering the min score (in the top N beams) we ensure that
# the decoder will keep decoding until there is at least one beam
# (in the top N) that can be improved (w.r.t. the alive beams).
# any unfinished beam will have score -INF - thus the min
# will always be -INF if there is at least one unfinished beam -
# which means the bound_is_met condition cannot be true in this case.
lowest_score_of_finished_in_finished = tf.reduce_min(finished_scores)
else:
# by taking the max score we only care about the first beam;
# as soon as this first beam cannot be beaten from the alive beams
# the beam decoder can stop.
# similarly to the above, if the top beam is not completed, its
# finished_score is -INF, thus it will not activate the
# bound_is_met condition. (i.e., decoder will keep going on).
# note we need to find the max for every sequence separately - so, we need
# to keep the batch dimension (see axis=1)
lowest_score_of_finished_in_finished = tf.reduce_max(finished_scores,
axis=1)

bound_is_met = tf.reduce_all(
tf.greater(lowest_score_of_finished_in_finished,
Expand Down
44 changes: 42 additions & 2 deletions tensor2tensor/utils/beam_search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def symbols_to_logits(ids):
self.assertAllEqual([[[0, 0, 1]]], ids)
self.assertAllClose([[0.7 * 0.6]], np.exp(probs))

def testNotGreedyBeamTwo(self):
def testNotGreedyBeamTwoWithStopEarly(self):
batch_size = 1
beam_size = 2
vocab_size = 3
Expand All @@ -151,11 +151,51 @@ def symbols_to_logits(ids):
decode_length,
vocab_size,
0.0,
eos_id=1)
eos_id=1,
stop_early=True)  # default value, but just to make this explicit

with self.test_session():
ids = final_ids.eval()
probs = final_probs.eval()
# given stop_early = True, the only 'assurance' is w.r.t. the first beam
# (i.e., other beams may not even be completed)
# so, we check only the first beam
first_beam = ids[:, 0]
first_probs = probs[:, 0]
self.assertAllEqual([[0, 2, 1]], first_beam)
self.assertAllClose([0.8 * 0.5], np.exp(first_probs))

def testNotGreedyBeamTwoWithoutStopEarly(self):
  """With stop_early=False every beam is returned, so all beams are checked."""
  batch_items = 1
  beam_width = 2
  vocab = 3
  max_decode_steps = 3

  # GO token for the single batch entry.
  start_ids = tf.constant([0] * batch_items)
  # Per-step, per-beam next-token distributions (step, beam, vocab).
  step_probs = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                            [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                            [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

  def symbols_to_logits(ids):
    # Current decode position selects which step's distribution to emit.
    position = tf.shape(ids)[1]
    return tf.to_float(tf.log(step_probs[position - 1, :]))

  final_ids, final_probs = beam_search.beam_search(
      symbols_to_logits,
      start_ids,
      beam_width,
      max_decode_steps,
      vocab,
      0.0,
      eos_id=1,
      stop_early=False)

  with self.test_session():
    decoded = final_ids.eval()
    log_scores = final_probs.eval()
  # given stop_early = False, the algorithm will return all the beams
  # so we can test all of them here
  self.assertAllEqual([[[0, 2, 1, 0], [0, 2, 0, 1]]], decoded)
  self.assertAllClose([[0.8 * 0.5, 0.8 * 0.4 * 0.9]], np.exp(log_scores))

Expand Down