This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit abca210

gcampax authored and Copybara-Service committed
internal merge of PR #999
PiperOrigin-RevId: 228622817
1 parent cef5491 commit abca210

File tree

3 files changed: +18 -14 lines changed
tensor2tensor/models/transformer.py

Lines changed: 4 additions & 3 deletions

@@ -57,7 +57,7 @@ class Transformer(t2t_model.T2TModel):

  def __init__(self, *args, **kwargs):
    super(Transformer, self).__init__(*args, **kwargs)
-    self.attention_weights = dict()  # For visualizing attention heads.
+    self.attention_weights = {}  # For visualizing attention heads.

  def encode(self, inputs, target_space, hparams, features=None, losses=None):
    """Encode transformer inputs.

@@ -824,7 +824,7 @@ def fast_decode_tpu(encoder_output,
        hparams=hparams)
  if beam_size > 1:  # Beam Search
    initial_ids = sos_id * tf.ones([batch_size], dtype=tf.int32)
-    decoded_ids, scores = beam_search.beam_search(
+    decoded_ids, scores, _ = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,

@@ -936,6 +936,7 @@ def fast_decode(encoder_output,
    force_decode_length: bool, whether to force the full decode length, or if
      False, stop when all beams hit eos_id.
    scope_prefix: str, prefix for decoder layer variable scopes.
+    cache: cache dictionary for additional predictions.

  Returns:
    A dict of decoding results {

@@ -959,7 +960,7 @@
          hparams.num_heads if hparams.get("attention_variables_3d") else 0)

  if cache is None:
-    cache = dict()
+    cache = {}
  cache.update({
      "layer_%d" % layer: {
          "k":

tensor2tensor/utils/beam_search.py

Lines changed: 10 additions & 7 deletions

@@ -751,6 +751,13 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
    return tf.logical_and(
        tf.less(i, decode_length), tf.logical_not(bound_is_met))

+  inner_shape = tf.TensorShape([None, None, None])
+  if use_tpu:
+    inner_shape = tf.TensorShape([batch_size, beam_size, decode_length + 1])
+  if use_tpu:
+    state_struc = nest.map_structure(lambda state: state.get_shape(), states)
+  else:
+    state_struc = nest.map_structure(get_state_shape_invariants, states)
  (_, alive_seq, alive_log_probs, finished_seq, finished_scores,
   finished_flags, states) = tf.while_loop(
       _is_finished,

@@ -760,16 +767,12 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
       ],
       shape_invariants=[
           tf.TensorShape([]),
-          (tf.TensorShape([batch_size, beam_size, decode_length + 1])
-           if use_tpu else tf.TensorShape([None, None, None])),
+          inner_shape,
           alive_log_probs.get_shape(),
-          (tf.TensorShape([batch_size, beam_size, decode_length + 1])
-           if use_tpu else tf.TensorShape([None, None, None])),
+          inner_shape,
           finished_scores.get_shape(),
           finished_flags.get_shape(),
-          (nest.map_structure(lambda state: state.get_shape(), states)
-           if use_tpu else
-           nest.map_structure(get_state_shape_invariants, states)),
+          state_struc
       ],
       parallel_iterations=1,
       back_prop=False)
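These hunks only hoist the shape-invariant expressions into the locals inner_shape and state_struc before the tf.while_loop call; behavior is unchanged. The following self-contained sketch shows the same pattern on a toy loop, under stated assumptions: the growing sequence, the dummy state dict, and the invariant-building lambdas are illustrative, not the repository's code.

import tensorflow as tf
from tensorflow.python.util import nest

use_tpu = False
batch_size, beam_size, decode_length = 2, 4, 5

i0 = tf.constant(0)
seq = tf.zeros([batch_size, beam_size, 1], dtype=tf.int32)
states = {"layer_0": {"k": tf.zeros([batch_size * beam_size, 1, 8])}}

# On TPU every loop variable needs a fully static shape; on CPU/GPU the
# time dimension is allowed to grow, so it is left as None.
inner_shape = tf.TensorShape([None, None, None])
if use_tpu:
  inner_shape = tf.TensorShape([batch_size, beam_size, decode_length + 1])
if use_tpu:
  state_struc = nest.map_structure(lambda s: s.get_shape(), states)
else:
  state_struc = nest.map_structure(
      lambda s: tf.TensorShape([None, None, s.get_shape()[-1]]), states)

def cond(i, unused_seq, unused_states):
  return tf.less(i, decode_length)

def body(i, seq, states):
  # Append one dummy decode step to the sequence and to the cached keys.
  seq = tf.concat([seq, tf.zeros([batch_size, beam_size, 1], tf.int32)], axis=2)
  states = nest.map_structure(
      lambda s: tf.concat([s, tf.zeros([batch_size * beam_size, 1, 8])], axis=1),
      states)
  return i + 1, seq, states

_, final_seq, final_states = tf.while_loop(
    cond, body, [i0, seq, states],
    shape_invariants=[tf.TensorShape([]), inner_shape, state_struc],
    parallel_iterations=1,
    back_prop=False)

with tf.Session() as sess:
  print(sess.run(final_seq).shape)  # (2, 4, 6) after decode_length steps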

tensor2tensor/utils/beam_search_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def symbols_to_logits(ids):
182182
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
183183
return logits
184184

185-
final_ids, final_probs = beam_search.beam_search(
185+
final_ids, final_probs, _ = beam_search.beam_search(
186186
symbols_to_logits,
187187
initial_ids,
188188
beam_size,
@@ -390,10 +390,10 @@ def symbols_to_logits(ids, _, states):
390390
0.0,
391391
eos_id=1,
392392
states=states)
393-
393+
394394
with self.test_session() as sess:
395395
final_states = sess.run(final_states)
396-
self.assertAllEqual([[1]], final_states["state"])
396+
self.assertAllEqual([[[2]]], final_states["state"])
397397

398398
def testStateBeamTwo(self):
399399
batch_size = 1
@@ -476,7 +476,7 @@ def symbols_to_logits(_, i, states):
476476
states["state"] = tf.placeholder_with_default(
477477
states["state"], shape=(None, 1))
478478

479-
final_ids, _ = beam_search.beam_search(
479+
final_ids, _, _ = beam_search.beam_search(
480480
symbols_to_logits,
481481
initial_ids,
482482
beam_size,
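The updated tests unpack three values as well, and the new expectation [[[2]]] suggests the returned states now carry a beam dimension. Below is a hedged sketch of the states-passing path, loosely modeled on these tests: when a states dict is supplied, symbols_to_logits_fn receives (ids, step, states) and returns both logits and updated states, and the final states come back as the third return value. The counter state, uniform logits, and sizes here are illustrative assumptions, not repository code.

import tensorflow as tf
from tensor2tensor.utils import beam_search

batch_size = 1
beam_size = 2
vocab_size = 3
decode_length = 3

initial_ids = tf.constant([0] * batch_size)  # GO
states = {"step_count": tf.zeros((batch_size, 1))}

def symbols_to_logits_fn(unused_ids, unused_i, states):
  # Illustrative: uniform logits, plus a per-beam counter bumped each step.
  logits = tf.zeros([batch_size * beam_size, vocab_size])
  states["step_count"] += 1.0
  return logits, states

final_ids, final_scores, final_states = beam_search.beam_search(
    symbols_to_logits_fn,
    initial_ids,
    beam_size,
    decode_length,
    vocab_size,
    0.0,  # alpha (length penalty)
    eos_id=1,
    states=states)

with tf.Session() as sess:
  print(sess.run(final_states["step_count"]))  # shape (batch, beam, 1)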
