@@ -65,7 +65,8 @@ def __init__(self,
                word_embed_mode=None,
                use_second_place_expert_prob=None,
                use_second_place_expert_prob_temp=None,
-               top_n_num_experts_per_token=3):
+               top_n_num_experts_per_token=3,
+               token_logging=False):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -97,6 +98,7 @@ def __init__(self,
             use_second_place_expert_prob_temp),
         moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
     self._activation = activation
+    self.token_logging = token_logging

   def call(self, context, x, losses=None):
     """Call the layer."""
@@ -116,7 +118,13 @@ def call(self, context, x, losses=None):
       output_dim = self._hparams.moe_output_dim
     else:
       output_dim = context.model.model_dim
-    y, loss = transformer_moe_layer_v1(
+    if self.token_logging:
+      tokens = _detokenize(context.inputs, context.model.vocabulary)
+      x = mtf.Print(x, [tokens], "tokens:", summarize=1000)
+      extras = _windows(context.inputs, context.length_dim)
+    else:
+      extras = None
+    y, loss, extras = transformer_moe_layer_v1(
         x,
         output_dim,
         self._hparams,
@@ -127,7 +135,16 @@ def call(self, context, x, losses=None):
         nonpadding=context.nonpadding,
         activation=self._activation,
         num_microbatches=context.num_microbatches,
-        token_embeddings=context.input_embeddings)
+        token_embeddings=context.input_embeddings,
+        extras=extras)
+
+    if extras:
+      extras = _detokenize(extras, context.model.vocabulary)
+      experts_dim = mtf.Dimension("experts", self._hparams.moe_num_experts)
+      extras = mtf.unstack(extras, experts_dim)
+      for i, t in enumerate(extras):
+        y = mtf.Print(y, [t], "EXPERT %s:" % i, summarize=1000)
+
     if context.losses is not None:
       context.losses.append(loss)
     if not has_length_dim:
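A minimal usage sketch of the new flag (illustration only, not part of the change; it assumes the layer is constructed directly rather than bound through gin, that the module path is mesh_tensorflow.transformer.moe, and that every other constructor argument keeps its default):

# Hypothetical example: build the layer with debug logging enabled, so the
# detokenized inputs and the per-expert token windows are printed at run time.
from mesh_tensorflow.transformer import moe

moe_layer = moe.MoE1D(num_experts=32, token_logging=True)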
@@ -139,6 +156,23 @@ def call(self, context, x, losses=None):
     return y


+@gin.configurable
+def _windows(ids, length_dim, window_start=0, window_end=0):
+  to_stack = []
+  for offset in range(window_start, window_end + 1):
+    to_stack.append(mtf.shift(ids, -offset, length_dim, wrap=False))
+  return mtf.stack(to_stack, "window", axis=ids.shape.ndims)
+
+
+def _detokenize(ids, vocabulary):
+  return mtf.slicewise(
+      vocabulary.decode_tf,
+      [ids],
+      output_shape=mtf.Shape(ids.shape.dims[:-1]),
+      output_dtype=tf.string,
+      splittable_dims=ids.shape.dims[:-1])
+
+
 class MoE2D(transformer.TransformerLayer):
   """Mixture of Experts Layer."""

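Roughly, _windows pairs each position with the token ids at offsets window_start..window_end along the length dimension (zero-filled past the end of the sequence, since wrap=False) and stacks them into a new trailing "window" dimension; with the defaults both offsets are 0, so the window is just the token itself, and wider windows can be bound through gin. _detokenize then maps id slices back to strings with the vocabulary's decode_tf. A toy NumPy stand-in for the windowing (illustration only; it ignores the mesh/batch dimensions and uses window_end=2 instead of the default):

# NumPy sketch of the windowing above (no mtf, single 1-D sequence):
# position i is paired with ids[i + offset] for each offset in
# [window_start, window_end], zero-filled past the end of the sequence.
import numpy as np


def windows_sketch(ids, window_start=0, window_end=2):
  to_stack = []
  for offset in range(window_start, window_end + 1):
    shifted = np.zeros_like(ids)
    if offset < len(ids):
      shifted[:len(ids) - offset] = ids[offset:]
    to_stack.append(shifted)
  return np.stack(to_stack, axis=-1)  # new trailing "window" axis


print(windows_sketch(np.array([11, 12, 13, 14])))
# [[11 12 13]
#  [12 13 14]
#  [13 14  0]
#  [14  0  0]]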
@@ -202,7 +236,7 @@ def call(self, context, x, losses=None):
 def transformer_moe_layer_v1(
     inputs, output_dim, hparams, train, variable_dtype,
     layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None, extras=None):
   """Local mixture of experts that works well on TPU.

   Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -281,6 +315,7 @@ def transformer_moe_layer_v1(
       [batch_dim(s), length_dim, input_dim]. These are the word embeddings for
       that correspond to the inputs. These can optionally be used to make
       routing decisions.
+    extras: a tensor to dispatch (for debugging purposes)

   Returns:
     outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -344,6 +379,10 @@ def transformer_moe_layer_v1(
   # over which those groups are split.
   batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                       orig_inputs.shape.dims[-1])
+
+  if extras:
+    extras_dims = extras.shape.dims[len(batch_and_length_dims):]
+
   # Hack: we assume that
   #   "outer_batch" == replication of experts
   #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
@@ -381,6 +420,11 @@ def transformer_moe_layer_v1(
     token_embeddings = mtf.cast(
         mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

+  if extras:
+    extras = mtf.reshape(
+        extras,
+        [outer_batch_dim, num_groups_dim, group_size_dim] + extras_dims)
+
   # Each sequence sends expert_capacity positions to each expert.
   if train:
     capacity_factor = hparams.moe_capacity_factor_train
@@ -503,6 +547,17 @@ def transformer_moe_layer_v1(
           input_dim
       ]))

+  if extras:
+    extras = mtf.einsum([extras, mtf.cast(dispatch_tensor, extras.dtype)],
+                        mtf.Shape([
+                            outer_batch_dim, experts_dim_unsplit,
+                            num_groups_dim, expert_capacity_dim] + extras_dims))
+    extras = mtf.reshape(
+        extras,
+        mtf.Shape([
+            outer_batch_dim, experts_dim, batch_dim_unsplit,
+            expert_capacity_dim] + extras_dims))
+
   # Now feed the expert inputs through the experts.
   h = mtf.layers.dense_product(
       expert_inputs,
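The added einsum reuses the same dispatch_tensor that routes expert_inputs, so each position's token window lands in exactly the expert/capacity slot occupied by its activations (positions dropped for lack of capacity simply come out as zeros). A toy NumPy sketch of that routing, with the outer-batch and group dimensions dropped and the slot assignments invented for illustration:

# NumPy sketch of dispatch-by-einsum (illustration only): a 0/1 dispatch
# tensor of shape [group_size, experts, capacity] copies each position's id
# into its assigned expert/capacity slot; unused slots stay 0.
import numpy as np

group_size, num_experts, capacity = 4, 2, 2
ids = np.array([11.0, 12.0, 13.0, 14.0])       # stands in for `extras`
dispatch = np.zeros((group_size, num_experts, capacity))
dispatch[0, 0, 0] = 1  # position 0 -> expert 0, slot 0
dispatch[1, 1, 0] = 1  # position 1 -> expert 1, slot 0
dispatch[2, 0, 1] = 1  # position 2 -> expert 0, slot 1
dispatch[3, 1, 1] = 1  # position 3 -> expert 1, slot 1

routed = np.einsum("s,sec->ec", ids, dispatch)
print(routed)
# [[11. 13.]
#  [12. 14.]]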
@@ -559,10 +614,15 @@ def _compute_output(hidden, layer_name):
     k = _compute_output(k_h, layer_name="k_wo")
     outputs.append(q)
     outputs.append(k)
-    return outputs, loss * hparams.moe_loss_coef
+    return outputs, loss * hparams.moe_loss_coef, None
   else:
     output = _compute_output(h, layer_name="wo")
-    return output, loss * hparams.moe_loss_coef
+    loss *= hparams.moe_loss_coef
+
+    if extras:
+      return output, loss, extras
+    else:
+      return output, loss, None


 def transformer_moe_layer_v2(