Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

transformer_fast_decode, beam search: take an optional cache and return it #999

Merged
merged 1 commit into from
Jan 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensor2tensor/layers/latent_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def symbols_to_logits_fn(ids):

initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
length = tf.shape(latents_dense_in)[1]
ids, _ = beam_search.beam_search(
ids, _, _ = beam_search.beam_search(
symbols_to_logits_fn,
initial_ids,
1,
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/models/research/transformer_nat.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def symbols_to_logits_fn(ids):

initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
length = tf.shape(latents_dense_in)[1]
ids, _ = beam_search.beam_search(
ids, _, _ = beam_search.beam_search(
symbols_to_logits_fn,
initial_ids,
beam_size=1,
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/models/research/transformer_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def symbols_to_logits_fn(ids):

initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
length = tf.shape(latents_dense_in)[1]
ids, _ = beam_search.beam_search(
ids, _, _ = beam_search.beam_search(
symbols_to_logits_fn, initial_ids, beam_size, length,
vocab_size, alpha=0.0, eos_id=-1, stop_early=False)

Expand Down
15 changes: 9 additions & 6 deletions tensor2tensor/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,8 @@ def fast_decode(encoder_output,
eos_id=beam_search.EOS_ID,
batch_size=None,
force_decode_length=False,
scope_prefix="body/"):
scope_prefix="body/",
cache=None):
"""Given encoder output and a symbols to logits function, does fast decoding.

Implements both greedy and beam search decoding, uses beam search iff
Expand Down Expand Up @@ -859,7 +860,9 @@ def fast_decode(encoder_output,
vars_3d_num_heads = (
hparams.num_heads if hparams.get("attention_variables_3d") else 0)

cache = {
if cache is None:
cache = dict()
cache.update({
"layer_%d" % layer: {
"k":
common_attention.split_heads(
Expand All @@ -870,7 +873,7 @@ def fast_decode(encoder_output,
"f":
tf.zeros([batch_size, 0, hparams.hidden_size]),
} for layer in range(num_layers)
}
})

if encoder_output is not None:
for layer in range(num_layers):
Expand All @@ -894,7 +897,7 @@ def fast_decode(encoder_output,

if beam_size > 1: # Beam Search
initial_ids = sos_id * tf.ones([batch_size], dtype=tf.int32)
decoded_ids, scores = beam_search.beam_search(
decoded_ids, scores, cache = beam_search.beam_search(
symbols_to_logits_fn,
initial_ids,
beam_size,
Expand Down Expand Up @@ -940,7 +943,7 @@ def is_not_finished(i, hit_eos, *_):
hit_eos = tf.fill([batch_size], False)
next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
_, _, _, decoded_ids, _, log_prob = tf.while_loop(
_, _, _, decoded_ids, cache, log_prob = tf.while_loop(
is_not_finished,
inner_loop, [
tf.constant(0), hit_eos, next_id, decoded_ids, cache,
Expand All @@ -956,7 +959,7 @@ def is_not_finished(i, hit_eos, *_):
])
scores = log_prob

return {"outputs": decoded_ids, "scores": scores}
return {"outputs": decoded_ids, "scores": scores, "cache": cache}


@registry.register_model
Expand Down
4 changes: 2 additions & 2 deletions tensor2tensor/utils/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
tf.less(i, decode_length), tf.logical_not(bound_is_met))

(_, alive_seq, alive_log_probs, finished_seq, finished_scores,
finished_flags, _) = tf.while_loop(
finished_flags, states) = tf.while_loop(
_is_finished,
inner_loop, [
tf.constant(0), alive_seq, alive_log_probs, finished_seq,
Expand Down Expand Up @@ -535,4 +535,4 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
finished_scores = tf.where(
tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
return finished_seq, finished_scores
return finished_seq, finished_scores, states
51 changes: 43 additions & 8 deletions tensor2tensor/utils/beam_search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def symbols_to_logits(_):
# Just return random logits
return tf.random_uniform((batch_size * beam_size, vocab_size))

final_ids, final_probs = beam_search.beam_search(
final_ids, final_probs, _ = beam_search.beam_search(
symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
0.)

Expand Down Expand Up @@ -113,7 +113,7 @@ def symbols_to_logits(ids):
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
return logits

final_ids, final_probs = beam_search.beam_search(
final_ids, final_probs, _ = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
Expand Down Expand Up @@ -144,7 +144,7 @@ def symbols_to_logits(ids):
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
return logits

final_ids, final_probs = beam_search.beam_search(
final_ids, final_probs, _ = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
Expand Down Expand Up @@ -173,7 +173,7 @@ def symbols_to_logits(ids):
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
return logits

final_ids, final_probs = beam_search.beam_search(
final_ids, final_probs, _ = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
Expand Down Expand Up @@ -213,7 +213,7 @@ def symbols_to_logits(ids):
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
return logits

final_ids, final_scores = beam_search.beam_search(
final_ids, final_scores, _ = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
Expand Down Expand Up @@ -256,7 +256,7 @@ def symbols_to_logits(ids):
return logits

# Disable early stopping
final_ids, final_scores = beam_search.beam_search(
final_ids, final_scores, _ = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
Expand Down Expand Up @@ -302,7 +302,7 @@ def symbols_to_logits(ids, _, states):
states["state"] = tf.placeholder_with_default(
states["state"], shape=(None, 1))

final_ids, _ = beam_search.beam_search(
final_ids, _, _ = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
Expand All @@ -319,6 +319,41 @@ def symbols_to_logits(ids, _, states):
except tf.errors.InvalidArgumentError as e:
raise AssertionError(e.message)

def testStatesAfterLoop(self):
batch_size = 1
beam_size = 1
vocab_size = 2
decode_length = 3

initial_ids = tf.constant([0] * batch_size) # GO
probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])

def symbols_to_logits(ids, _, states):
pos = tf.shape(ids)[1] - 1
logits = tf.to_float(tf.log(probabilities[pos, :]))
states["state"] += 1
return logits, states

states = {
"state": tf.zeros((batch_size, 1)),
}
states["state"] = tf.placeholder_with_default(
states["state"], shape=(None, 1))

_, _, final_states = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
decode_length,
vocab_size,
0.0,
eos_id=1,
states=states)

with self.test_session() as sess:
final_states = sess.run(final_states)
self.assertAllEqual([[1]], final_states["state"])

def testStateBeamTwo(self):
batch_size = 1
beam_size = 2
Expand Down Expand Up @@ -352,7 +387,7 @@ def symbols_to_logits(ids, _, states):
states["state"] = tf.placeholder_with_default(
states["state"], shape=(None, 1))

final_ids, _ = beam_search.beam_search(
final_ids, _, _ = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/utils/t2t_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def symbols_to_logits_fn(ids):
inputs = features["inputs"]
decode_length = (common_layers.shape_list(inputs)[1] +
features.get("decode_length", decode_length))
ids, scores = beam_search.beam_search(
ids, scores, _ = beam_search.beam_search(
symbols_to_logits_fn,
initial_ids,
beam_size,
Expand Down