
Commit 572f1e7

MostafaDehghani authored and Copybara-Service committed
internal merge of PR #1194
PiperOrigin-RevId: 219705888
1 parent ee3794b commit 572f1e7

File tree: 1 file changed (+3 / −4 lines)

tensor2tensor/models/research/universal_transformer_util.py

Lines changed: 3 additions & 4 deletions
@@ -125,8 +125,6 @@ def universal_transformer_encoder(encoder_input,
    x, extra_output = universal_transformer_layer(
        x, hparams, ffn_unit, attention_unit, pad_remover=pad_remover)
 
-    if hparams.get("use_memory_as_last_state", False):
-      x = extra_output  # which is memory
     return common_layers.layer_preprocess(x, hparams), extra_output
 
 
@@ -251,8 +249,9 @@ def add_vanilla_transformer_layer(x, num_layers):
       output, _, extra_output = tf.foldl(
           ut_function, tf.range(hparams.num_rec_steps), initializer=initializer)
 
-      # This is possible only when we are using lstm as transition function.
-      if hparams.get("use_memory_as_final_state", False):
+      # Right now, this is only possible when the transition function is an lstm
+      if (hparams.recurrence_type == "lstm" and
+          hparams.get("use_memory_as_final_state", False)):
         output = extra_output
 
     if hparams.mix_with_transformer == "after_ut":
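
For context, the change tightens the guard so that the memory tensor carried in extra_output is only used as the final state when the recurrence (transition) function is an LSTM, rather than whenever the use_memory_as_final_state flag happens to be set. The following minimal sketch illustrates the new gating pattern in isolation; the HParams stand-in class, the pick_final_state helper, and the example values are hypothetical and are not part of tensor2tensor.

# Hedged sketch of the gate this commit introduces; HParams and
# pick_final_state are illustrative stand-ins, not tensor2tensor code.
class HParams(object):
  """Tiny stand-in for the hparams object; only what the sketch needs."""

  def __init__(self, **kwargs):
    self.__dict__.update(kwargs)

  def get(self, key, default=None):
    return getattr(self, key, default)


def pick_final_state(output, extra_output, hparams):
  # Mirror of the new guard: the memory in extra_output is only meaningful
  # when the recurrence (transition) function is an LSTM, so the flag alone
  # no longer swaps in the memory.
  if (hparams.recurrence_type == "lstm" and
      hparams.get("use_memory_as_final_state", False)):
    return extra_output
  return output


# With a non-LSTM transition, the flag by itself has no effect.
hparams = HParams(recurrence_type="basic", use_memory_as_final_state=True)
assert pick_final_state("output", "memory", hparams) == "output"

# With an LSTM transition and the flag set, the memory is used.
hparams = HParams(recurrence_type="lstm", use_memory_as_final_state=True)
assert pick_final_state("output", "memory", hparams) == "memory"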
