
Commit bc1268d

rllin-fathom authored and Copybara-Service committed
internal merge of PR #1192
PiperOrigin-RevId: 219664613
1 parent cb655f0 commit bc1268d

File tree

1 file changed: +4 -3 lines changed


tensor2tensor/models/research/universal_transformer_util.py

Lines changed: 4 additions & 3 deletions
@@ -125,6 +125,8 @@ def universal_transformer_encoder(encoder_input,
     x, extra_output = universal_transformer_layer(
         x, hparams, ffn_unit, attention_unit, pad_remover=pad_remover)
 
+    if hparams.get("use_memory_as_last_state", False):
+      x = extra_output  # which is memory
     return common_layers.layer_preprocess(x, hparams), extra_output
 
 
@@ -249,9 +251,8 @@ def add_vanilla_transformer_layer(x, num_layers):
     output, _, extra_output = tf.foldl(
         ut_function, tf.range(hparams.num_rec_steps), initializer=initializer)
 
-    # Right now, this is only possible when the transition function is an lstm
-    if (hparams.recurrence_type == "lstm" and
-        hparams.get("use_memory_as_final_state", False)):
+    # This is possible only when we are using lstm as transition function.
+    if hparams.get("use_memory_as_final_state", False):
       output = extra_output
 
     if hparams.mix_with_transformer == "after_ut":
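
For context, a minimal sketch of how the flags touched by this diff might be enabled on a Universal Transformer hparams set. The flag names (use_memory_as_last_state, use_memory_as_final_state) and the "lstm" recurrence_type come from the diff above; the universal_transformer_base() hparams function and the add_hparam() calls are ordinary tensor2tensor/HParams usage and are shown only as an assumed illustration, not as part of this commit.

# Sketch only; assumes the standard tensor2tensor hparams API, not part of this commit.
from tensor2tensor.models.research import universal_transformer

hparams = universal_transformer.universal_transformer_base()
hparams.recurrence_type = "lstm"  # memory (extra_output) is produced by the lstm transition function
# Register the flags that the code above reads via hparams.get(...).
hparams.add_hparam("use_memory_as_last_state", True)   # encoder returns memory as its last state
hparams.add_hparam("use_memory_as_final_state", True)  # universal_transformer_layer outputs memory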
