Skip to content

Commit 948f8ef

Browse files
gcampaxkpe
authored and committed
transformer_fast_decode, beam search: take an optional cache and return it (tensorflow#999)
Some models, eg. semantic parsing models with copying mechanisms, want to use the output of Transformer for multiple predictions. One way to do so is to modify the symbols_to_logits_fn to generate the additional predictions and save it in the cache dictionary. To do so, though, fast_decode() must allow an externally supplied cache, and must return it to the caller after the loop.
1 parent 9c36ecf commit 948f8ef

File tree

7 files changed

+58
-20
lines changed

7 files changed

+58
-20
lines changed

tensor2tensor/layers/latent_layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def symbols_to_logits_fn(ids):
168168

169169
initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
170170
length = tf.shape(latents_dense_in)[1]
171-
ids, _ = beam_search.beam_search(
171+
ids, _, _ = beam_search.beam_search(
172172
symbols_to_logits_fn,
173173
initial_ids,
174174
1,

tensor2tensor/models/research/transformer_nat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def symbols_to_logits_fn(ids):
228228

229229
initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
230230
length = tf.shape(latents_dense_in)[1]
231-
ids, _ = beam_search.beam_search(
231+
ids, _, _ = beam_search.beam_search(
232232
symbols_to_logits_fn,
233233
initial_ids,
234234
beam_size=1,

tensor2tensor/models/research/transformer_vae.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def symbols_to_logits_fn(ids):
290290

291291
initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
292292
length = tf.shape(latents_dense_in)[1]
293-
ids, _ = beam_search.beam_search(
293+
ids, _, _ = beam_search.beam_search(
294294
symbols_to_logits_fn, initial_ids, beam_size, length,
295295
vocab_size, alpha=0.0, eos_id=-1, stop_early=False)
296296

tensor2tensor/models/transformer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,8 @@ def fast_decode(encoder_output,
910910
eos_id=beam_search.EOS_ID,
911911
batch_size=None,
912912
force_decode_length=False,
913-
scope_prefix="body/"):
913+
scope_prefix="body/",
914+
cache=None):
914915
"""Given encoder output and a symbols to logits function, does fast decoding.
915916
916917
Implements both greedy and beam search decoding, uses beam search iff
@@ -957,7 +958,9 @@ def fast_decode(encoder_output,
957958
vars_3d_num_heads = (
958959
hparams.num_heads if hparams.get("attention_variables_3d") else 0)
959960

960-
cache = {
961+
if cache is None:
962+
cache = dict()
963+
cache.update({
961964
"layer_%d" % layer: {
962965
"k":
963966
common_attention.split_heads(
@@ -966,7 +969,7 @@ def fast_decode(encoder_output,
966969
common_attention.split_heads(
967970
tf.zeros([batch_size, 0, value_channels]), hparams.num_heads),
968971
} for layer in range(num_layers)
969-
}
972+
})
970973

971974
# If `ffn_layer` is in `["dense_relu_dense" or "conv_hidden_relu"]`, then the
972975
# cache key "f" won't be used, which means that the` shape of cache["f"]`
@@ -1000,7 +1003,7 @@ def fast_decode(encoder_output,
10001003

10011004
if beam_size > 1: # Beam Search
10021005
initial_ids = sos_id * tf.ones([batch_size], dtype=tf.int32)
1003-
decoded_ids, scores = beam_search.beam_search(
1006+
decoded_ids, scores, cache = beam_search.beam_search(
10041007
symbols_to_logits_fn,
10051008
initial_ids,
10061009
beam_size,
@@ -1047,7 +1050,7 @@ def is_not_finished(i, hit_eos, *_):
10471050
hit_eos = tf.fill([batch_size], False)
10481051
next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
10491052
initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
1050-
_, _, _, decoded_ids, _, log_prob = tf.while_loop(
1053+
_, _, _, decoded_ids, cache, log_prob = tf.while_loop(
10511054
is_not_finished,
10521055
inner_loop, [
10531056
tf.constant(0), hit_eos, next_id, decoded_ids, cache,
@@ -1063,7 +1066,7 @@ def is_not_finished(i, hit_eos, *_):
10631066
])
10641067
scores = log_prob
10651068

1066-
return {"outputs": decoded_ids, "scores": scores}
1069+
return {"outputs": decoded_ids, "scores": scores, "cache": cache}
10671070

10681071

10691072
@registry.register_model

tensor2tensor/utils/beam_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
752752
tf.less(i, decode_length), tf.logical_not(bound_is_met))
753753

754754
(_, alive_seq, alive_log_probs, finished_seq, finished_scores,
755-
finished_flags, _) = tf.while_loop(
755+
finished_flags, states) = tf.while_loop(
756756
_is_finished,
757757
inner_loop, [
758758
tf.constant(0), alive_seq, alive_log_probs, finished_seq,
@@ -786,4 +786,4 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
786786
tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
787787
finished_scores = tf.where(
788788
tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
789-
return finished_seq, finished_scores
789+
return finished_seq, finished_scores, states

tensor2tensor/utils/beam_search_test.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def symbols_to_logits(_):
3838
# Just return random logits
3939
return tf.random_uniform((batch_size * beam_size, vocab_size))
4040

41-
final_ids, final_probs = beam_search.beam_search(
41+
final_ids, final_probs, _ = beam_search.beam_search(
4242
symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
4343
0.)
4444

@@ -114,7 +114,7 @@ def symbols_to_logits(ids):
114114
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
115115
return logits
116116

117-
final_ids, final_probs = beam_search.beam_search(
117+
final_ids, final_probs, _ = beam_search.beam_search(
118118
symbols_to_logits,
119119
initial_ids,
120120
beam_size,
@@ -145,7 +145,7 @@ def symbols_to_logits(ids):
145145
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
146146
return logits
147147

148-
final_ids, final_probs = beam_search.beam_search(
148+
final_ids, final_probs, _ = beam_search.beam_search(
149149
symbols_to_logits,
150150
initial_ids,
151151
beam_size,
@@ -214,7 +214,7 @@ def symbols_to_logits(ids):
214214
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
215215
return logits
216216

217-
final_ids, final_probs = beam_search.beam_search(
217+
final_ids, final_probs, _ = beam_search.beam_search(
218218
symbols_to_logits,
219219
initial_ids,
220220
beam_size,
@@ -254,7 +254,7 @@ def symbols_to_logits(ids):
254254
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
255255
return logits
256256

257-
final_ids, final_scores = beam_search.beam_search(
257+
final_ids, final_scores, _ = beam_search.beam_search(
258258
symbols_to_logits,
259259
initial_ids,
260260
beam_size,
@@ -297,7 +297,7 @@ def symbols_to_logits(ids):
297297
return logits
298298

299299
# Disable early stopping
300-
final_ids, final_scores = beam_search.beam_search(
300+
final_ids, final_scores, _ = beam_search.beam_search(
301301
symbols_to_logits,
302302
initial_ids,
303303
beam_size,
@@ -343,7 +343,7 @@ def symbols_to_logits(ids, _, states):
343343
states["state"] = tf.placeholder_with_default(
344344
states["state"], shape=(None, 1))
345345

346-
final_ids, _ = beam_search.beam_search(
346+
final_ids, _, _ = beam_search.beam_search(
347347
symbols_to_logits,
348348
initial_ids,
349349
beam_size,
@@ -360,6 +360,41 @@ def symbols_to_logits(ids, _, states):
360360
except tf.errors.InvalidArgumentError as e:
361361
raise AssertionError(e.message)
362362

363+
def testStatesAfterLoop(self):
364+
batch_size = 1
365+
beam_size = 1
366+
vocab_size = 2
367+
decode_length = 3
368+
369+
initial_ids = tf.constant([0] * batch_size) # GO
370+
probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])
371+
372+
def symbols_to_logits(ids, _, states):
373+
pos = tf.shape(ids)[1] - 1
374+
logits = tf.to_float(tf.log(probabilities[pos, :]))
375+
states["state"] += 1
376+
return logits, states
377+
378+
states = {
379+
"state": tf.zeros((batch_size, 1)),
380+
}
381+
states["state"] = tf.placeholder_with_default(
382+
states["state"], shape=(None, 1))
383+
384+
_, _, final_states = beam_search.beam_search(
385+
symbols_to_logits,
386+
initial_ids,
387+
beam_size,
388+
decode_length,
389+
vocab_size,
390+
0.0,
391+
eos_id=1,
392+
states=states)
393+
394+
with self.test_session() as sess:
395+
final_states = sess.run(final_states)
396+
self.assertAllEqual([[1]], final_states["state"])
397+
363398
def testStateBeamTwo(self):
364399
batch_size = 1
365400
beam_size = 2
@@ -393,7 +428,7 @@ def symbols_to_logits(ids, _, states):
393428
states["state"] = tf.placeholder_with_default(
394429
states["state"], shape=(None, 1))
395430

396-
final_ids, _ = beam_search.beam_search(
431+
final_ids, _, _ = beam_search.beam_search(
397432
symbols_to_logits,
398433
initial_ids,
399434
beam_size,

tensor2tensor/utils/t2t_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ def symbols_to_logits_fn(ids, i=None):
891891
inputs = features["inputs"]
892892
decode_length = (common_layers.shape_list(inputs)[1] +
893893
features.get("decode_length", decode_length))
894-
ids, scores = beam_search.beam_search(
894+
ids, scores, _ = beam_search.beam_search(
895895
symbols_to_logits_fn,
896896
initial_ids,
897897
beam_size,

0 commit comments

Comments (0)