Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 441ff47

Browse files
afrozenator authored and Mesh TensorFlow Team committed
[MTF] Maybe context.length_dim doesn't exist outside context.train?
PiperOrigin-RevId: 361157130
1 parent 44928c0 commit 441ff47

File tree

6 files changed

+34
-21
lines changed

6 files changed

+34
-21
lines changed

mesh_tensorflow/transformer/evolved_transformer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,10 @@ def _pad_channels_dim(tensor, size):
228228

229229

230230
def _dropout(x, context, dropout_rate):
231-
return mtf.dropout(
232-
x, context.train,
233-
rate=dropout_rate,
234-
noise_shape=mtf.Shape(context.batch_dims + x.shape.dims[-1:]))
231+
if context.train and dropout_rate > 0:
232+
return mtf.dropout(
233+
x, context.train,
234+
rate=dropout_rate,
235+
noise_shape=mtf.Shape(context.batch_dims + x.shape.dims[-1:]))
236+
else:
237+
return x

mesh_tensorflow/transformer/fixup_layers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,9 +316,10 @@ def call(self, context, x, losses=None):
316316
name="wi",
317317
kernel_initializer=self.upproject_initializer,
318318
expert_dims=context.model.ensemble_dims)
319-
h = mtf.dropout(
320-
h, context.train, 1.0 - self.dropout_rate,
321-
noise_shape=h.shape - context.length_dim)
319+
if context.train and self.dropout_rate != 0.0:
320+
h = mtf.dropout(
321+
h, context.train, 1.0 - self.dropout_rate,
322+
noise_shape=h.shape - context.length_dim)
322323
shift = get_single_scalar_bias(x, "shift")
323324
h_res = mtf.add(h, shift)
324325
h = mtf.reshape(h_res, h.shape)

mesh_tensorflow/transformer/transformer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -587,9 +587,12 @@ def sublayer_residual(x, layer_stack, context):
587587
@gin.configurable
588588
def sublayer_dropout(x, layer_stack, context, dropout_rate=0.0):
589589
del layer_stack
590-
return mtf.dropout(
591-
x, context.train, rate=dropout_rate,
592-
noise_shape=mtf.Shape(context.batch_dims + [context.model.model_dim]))
590+
if context.train and dropout_rate > 0:
591+
return mtf.dropout(
592+
x, context.train, rate=dropout_rate,
593+
noise_shape=mtf.Shape(context.batch_dims + [context.model.model_dim]))
594+
else:
595+
return x
593596

594597

595598
@gin.configurable

mesh_tensorflow/transformer/transformer_layers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ def call(self, context, x, losses=None):
6464
variable_dtype=context.variable_dtype,
6565
name="wi",
6666
expert_dims=context.model.ensemble_dims)
67-
h = mtf.dropout(h, context.train, 1.0 - self.dropout_rate,
68-
noise_shape=h.shape - context.length_dim)
67+
if context.train and self.dropout_rate != 0.0:
68+
h = mtf.dropout(h, context.train, keep_prob=1.0 - self.dropout_rate,
69+
noise_shape=h.shape - context.length_dim)
6970
return mtf.layers.dense(h, io_channels,
7071
use_bias=self.use_bias,
7172
activation=None,

mesh_tensorflow/transformer/universal_transformer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -697,10 +697,13 @@ def call(self, context, x):
697697
return x
698698

699699
def _dropout(self, context, x):
700-
return mtf.dropout(
701-
x, context.train,
702-
rate=self._dropout_rate,
703-
noise_shape=mtf.Shape(context.batch_dims + [context.model.model_dim]))
700+
if context.train and self._dropout_rate > 0:
701+
return mtf.dropout(
702+
x, context.train,
703+
rate=self._dropout_rate,
704+
noise_shape=mtf.Shape(context.batch_dims + [context.model.model_dim]))
705+
else:
706+
return x
704707

705708
def _layer_norm(self, context, x, name=None):
706709
"""Layer normalization.

mesh_tensorflow/transformer/vocab_embeddings.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -509,11 +509,13 @@ def _sigmoid_tree(self, tensor):
509509
], self._gates_dim.name)
510510

511511
def _dropout(self, tensor, context):
512-
return mtf.dropout(
513-
tensor,
514-
context.train,
515-
1.0 - self._dropout_rate,
516-
noise_shape=tensor.shape - context.length_dim)
512+
if context.train and self._dropout_rate != 0.0:
513+
return mtf.dropout(
514+
tensor,
515+
context.train,
516+
1.0 - self._dropout_rate,
517+
noise_shape=tensor.shape - context.length_dim)
518+
return tensor
517519

518520
def _rearrange_sentinels(self, logits):
519521
"""Reorder along the vocab dim so the last few tokens don't share gates."""

0 commit comments

Comments (0)