
Commit da43c27

William Fedus authored and Mesh TensorFlow Team committed
Option to use mtf.Print to log which tokens are sent to which experts when run on CPU.
PiperOrigin-RevId: 368137313
1 parent 57ed401 commit da43c27
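
In short, the MoE1D layer gains a token_logging constructor flag. When it is set, the layer detokenizes context.inputs and prints them with mtf.Print, and it also routes a window of raw token ids (the "extras" tensor) through the same dispatch as the activations, so the tokens each expert receives can be printed per expert. A minimal usage sketch, not part of the commit, assuming MoE1D's remaining constructor arguments keep their defaults:

from mesh_tensorflow.transformer import moe

# Hypothetical example: num_experts and hidden_size are pre-existing MoE1D
# arguments; token_logging is the flag added in this commit.
moe_layer = moe.MoE1D(num_experts=32, hidden_size=4096, token_logging=True)

As the commit message notes, the mtf.Print output only shows up when the ops actually run on CPU.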

File tree

3 files changed (+85, -11 lines changed)

mesh_tensorflow/transformer/moe.py
mesh_tensorflow/transformer/transformer.py
mesh_tensorflow/transformer/utils.py

mesh_tensorflow/transformer/moe.py

Lines changed: 66 additions & 6 deletions
@@ -65,7 +65,8 @@ def __init__(self,
                word_embed_mode=None,
                use_second_place_expert_prob=None,
                use_second_place_expert_prob_temp=None,
-               top_n_num_experts_per_token=3):
+               top_n_num_experts_per_token=3,
+               token_logging=False):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -97,6 +98,7 @@ def __init__(self,
             use_second_place_expert_prob_temp),
         moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
     self._activation = activation
+    self.token_logging = token_logging

   def call(self, context, x, losses=None):
     """Call the layer."""
@@ -116,7 +118,13 @@ def call(self, context, x, losses=None):
       output_dim = self._hparams.moe_output_dim
     else:
       output_dim = context.model.model_dim
-    y, loss = transformer_moe_layer_v1(
+    if self.token_logging:
+      tokens = _detokenize(context.inputs, context.model.vocabulary)
+      x = mtf.Print(x, [tokens], "tokens:", summarize=1000)
+      extras = _windows(context.inputs, context.length_dim)
+    else:
+      extras = None
+    y, loss, extras = transformer_moe_layer_v1(
         x,
         output_dim,
         self._hparams,
@@ -127,7 +135,16 @@ def call(self, context, x, losses=None):
         nonpadding=context.nonpadding,
         activation=self._activation,
         num_microbatches=context.num_microbatches,
-        token_embeddings=context.input_embeddings)
+        token_embeddings=context.input_embeddings,
+        extras=extras)
+
+    if extras:
+      extras = _detokenize(extras, context.model.vocabulary)
+      experts_dim = mtf.Dimension("experts", self._hparams.moe_num_experts)
+      extras = mtf.unstack(extras, experts_dim)
+      for i, t in enumerate(extras):
+        y = mtf.Print(y, [t], "EXPERT %s:" % i, summarize=1000)
+
     if context.losses is not None:
       context.losses.append(loss)
     if not has_length_dim:
@@ -139,6 +156,23 @@ def call(self, context, x, losses=None):
     return y


+@gin.configurable
+def _windows(ids, length_dim, window_start=0, window_end=0):
+  to_stack = []
+  for offset in range(window_start, window_end + 1):
+    to_stack.append(mtf.shift(ids, -offset, length_dim, wrap=False))
+  return mtf.stack(to_stack, "window", axis=ids.shape.ndims)
+
+
+def _detokenize(ids, vocabulary):
+  return mtf.slicewise(
+      vocabulary.decode_tf,
+      [ids],
+      output_shape=mtf.Shape(ids.shape.dims[:-1]),
+      output_dtype=tf.string,
+      splittable_dims=ids.shape.dims[:-1])
+
+
 class MoE2D(transformer.TransformerLayer):
   """Mixture of Experts Layer."""

@@ -202,7 +236,7 @@ def call(self, context, x, losses=None):
 def transformer_moe_layer_v1(
     inputs, output_dim, hparams, train, variable_dtype,
     layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None, extras=None):
   """Local mixture of experts that works well on TPU.

   Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -281,6 +315,7 @@ def transformer_moe_layer_v1(
       [batch_dim(s), length_dim, input_dim]. These are the word embeddings for
       that correspond to the inputs. These can optionally be used to make
       routing decisions.
+    extras: a tensor to dispatch (for debugging purposes)

   Returns:
     outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -344,6 +379,10 @@ def transformer_moe_layer_v1(
   # over which those groups are split.
   batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                       orig_inputs.shape.dims[-1])
+
+  if extras:
+    extras_dims = extras.shape.dims[len(batch_and_length_dims):]
+
   # Hack: we assume that
   # "outer_batch" == replication of experts
   # mesh_dim_size can be derived from mesh_shape and orig_batch_dim
@@ -381,6 +420,11 @@ def transformer_moe_layer_v1(
     token_embeddings = mtf.cast(
         mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

+  if extras:
+    extras = mtf.reshape(
+        extras,
+        [outer_batch_dim, num_groups_dim, group_size_dim] + extras_dims)
+
   # Each sequence sends expert_capacity positions to each expert.
   if train:
     capacity_factor = hparams.moe_capacity_factor_train
@@ -503,6 +547,17 @@ def transformer_moe_layer_v1(
           input_dim
       ]))

+  if extras:
+    extras = mtf.einsum([extras, mtf.cast(dispatch_tensor, extras.dtype)],
+                        mtf.Shape([
+                            outer_batch_dim, experts_dim_unsplit,
+                            num_groups_dim, expert_capacity_dim] + extras_dims))
+    extras = mtf.reshape(
+        extras,
+        mtf.Shape([
+            outer_batch_dim, experts_dim, batch_dim_unsplit,
+            expert_capacity_dim] + extras_dims))
+
   # Now feed the expert inputs through the experts.
   h = mtf.layers.dense_product(
       expert_inputs,
@@ -559,10 +614,15 @@ def _compute_output(hidden, layer_name):
     k = _compute_output(k_h, layer_name="k_wo")
     outputs.append(q)
     outputs.append(k)
-    return outputs, loss * hparams.moe_loss_coef
+    return outputs, loss * hparams.moe_loss_coef, None
   else:
     output = _compute_output(h, layer_name="wo")
-    return output, loss * hparams.moe_loss_coef
+    loss *= hparams.moe_loss_coef
+
+    if extras:
+      return output, loss, extras
+    else:
+      return output, loss, None


 def transformer_moe_layer_v2(
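
The new _windows helper stacks shifted copies of the token ids along a fresh "window" dimension, so each routed position carries a small window of neighboring tokens rather than just its own id. A standalone numpy sketch of the same idea, written for this note rather than taken from the repo, and only roughly mimicking the zero padding of mtf.shift(..., wrap=False):

import numpy as np

def windows(ids, window_start=0, window_end=0):
  """Stack token ids at relative offsets into a trailing "window" axis."""
  to_stack = []
  for offset in range(window_start, window_end + 1):
    shifted = np.roll(ids, -offset, axis=-1)
    if offset > 0:             # data shifted in from beyond the end
      shifted[..., -offset:] = 0
    elif offset < 0:           # data shifted in from before the start
      shifted[..., :-offset] = 0
    to_stack.append(shifted)
  return np.stack(to_stack, axis=-1)   # shape [..., length, window]

ids = np.array([[11, 12, 13, 14]])
print(windows(ids, window_start=0, window_end=2))
# [[[11 12 13]
#   [12 13 14]
#   [13 14  0]
#   [14  0  0]]]

With the defaults (window_start=0, window_end=0) this reduces to the ids themselves with a trailing window axis of size 1, which is how _windows is invoked from MoE1D.call above; since _windows is marked @gin.configurable, the window bounds can be overridden from a gin config.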

mesh_tensorflow/transformer/transformer.py

Lines changed: 4 additions & 1 deletion
@@ -722,7 +722,8 @@ def __init__(self,
                input_full_attention=False,
                loss_on_targets_only=False,
                loss_denominator=None,
-               token_dropout_rate=0.0):
+               token_dropout_rate=0.0,
+               vocabulary=None):
     """Create a Unitransformer.

     Args:
@@ -767,6 +768,7 @@ def __init__(self,
        same denominator as was used for the pretraining. This complication
        might be avoided by always using loss_denominator = 1.0.
       token_dropout_rate: an optional floating point value
+      vocabulary: an optional vocabularies.Vocabulary
     """
     self.layer_stack = layer_stack
     self.model_dim = mtf.Dimension("d_model", d_model)
@@ -807,6 +809,7 @@ def __init__(self,
       raise ValueError(
           "input_full_attention only makes sense with autoregressive")
     self.token_dropout_rate = token_dropout_rate
+    self.vocabulary = vocabulary

   @property
   def fully_autoregressive(self):

mesh_tensorflow/transformer/utils.py

Lines changed: 15 additions & 4 deletions
@@ -172,7 +172,9 @@ def build_model(model_type="bitransformer",
                 input_vocab_size=gin.REQUIRED,
                 output_vocab_size=gin.REQUIRED,
                 layout_rules=None,
-                mesh_shape=None):
+                mesh_shape=None,
+                input_vocabulary=None,
+                target_vocabulary=None):
   """Build a transformer model.

   Currently, four types of models are supported:
@@ -214,15 +216,21 @@
     output_vocab_size: an integer
     layout_rules: optional, input to mtf.convert_to_layout_rules
     mesh_shape: optional, an input to mtf.convert_to_shape()
+    input_vocabulary: optional, a vocubalaries.Vocabulary
+    target_vocabulary: optional, a vocubalaries.Vocabulary
+
   Returns:
     a Unitransformer or Bitransformer
   """
   if model_type == "bitransformer":
-    return transformer.make_bitransformer(
+    ret = transformer.make_bitransformer(
         input_vocab_size=input_vocab_size,
         output_vocab_size=output_vocab_size,
         mesh_shape=mesh_shape,
         layout=layout_rules)
+    ret.encoder.vocabulary = input_vocabulary
+    ret.decoder.vocabulary = target_vocabulary
+    return ret
   elif model_type == "bi_student_teacher":
     return transformer.make_bi_student_teacher(
         input_vocab_size=input_vocab_size,
@@ -236,7 +244,8 @@ def build_model(model_type="bitransformer",
         input_vocab_size=input_vocab_size,
         output_vocab_size=output_vocab_size,
         mesh_shape=mesh_shape,
-        layout=layout_rules)
+        layout=layout_rules,
+        vocabulary=input_vocabulary)
   else:
     raise ValueError("unknown model_type")

@@ -2067,7 +2076,9 @@ def get_estimator(model_type, vocabulary, mesh_shape,
       input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
       output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
       layout_rules=layout_rules,
-      mesh_shape=mesh_shape)
+      mesh_shape=mesh_shape,
+      input_vocabulary=inputs_vocabulary(vocabulary),
+      target_vocabulary=targets_vocabulary(vocabulary))

   model_fn = tpu_estimator_model_fn(
       model_type=model_type,
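
Taken together, the transformer.py and utils.py changes simply thread the Vocabulary objects down to the models so that a MoE layer can reach one as context.model.vocabulary when token_logging is on. A rough wiring sketch, with vocab standing in for whatever vocabularies.Vocabulary the run uses (illustrative only, not code from the commit):

from mesh_tensorflow.transformer import utils

model = utils.build_model(
    model_type="bitransformer",
    input_vocab_size=vocab.vocab_size,
    output_vocab_size=vocab.vocab_size,
    input_vocabulary=vocab,    # new argument in this commit
    target_vocabulary=vocab)   # new argument in this commit

# Per the diff above, the encoder and decoder now carry the vocabularies,
# so a MoE1D layer inside either stack can detokenize ids for logging.
assert model.encoder.vocabulary is vocab
assert model.decoder.vocabulary is vocab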
