
Commit f027fb2

hoogendong authored and Mesh TensorFlow Team committed

Adding a new Gradient Estimator for Routing using REINFORCE with a leave-one-out baseline.

PiperOrigin-RevId: 435129337

1 parent 39f4bd6 commit f027fb2
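In outline: the router samples a discrete expert per token, and REINFORCE converts the downstream loss into a gradient for the router's log-probabilities, with a leave-one-out baseline (the loss of the other samples drawn for the same input) to reduce variance. A minimal numpy sketch of the idea, separate from the code below (all names here are illustrative):

import numpy as np

def rloo_surrogate(losses, log_probs):
  """REINFORCE with a leave-one-out (RLOO) baseline.

  losses: [k] losses for k sampled routings of the same input.
  log_probs: [k] log-probabilities of those sampled routings.
  Returns a surrogate whose gradient w.r.t. the router parameters
  (flowing through log_probs) is the RLOO estimator.
  """
  k = len(losses)
  baselines = (losses.sum() - losses) / (k - 1)  # mean loss of the others
  advantages = losses - baselines  # held constant (stop-gradient) in TF
  return (advantages * log_probs).mean()

In this commit k = 2: call_simple duplicates the batch, so each example's baseline is simply its twin's loss.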

File tree

2 files changed (+114, -14 lines)

mesh_tensorflow/transformer/moe.py

Lines changed: 73 additions & 12 deletions
@@ -25,6 +25,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import math
 import gin
 
 import mesh_tensorflow as mtf
@@ -65,7 +66,10 @@ def __init__(self,
                word_embed_mode=None,
                use_second_place_expert_prob=None,
                use_second_place_expert_prob_temp=None,
-               top_n_num_experts_per_token=3):
+               top_n_num_experts_per_token=3,
+               rloo=False,
+               loss_type="load_balance",
+               p_dot_e=True):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -95,7 +99,10 @@ def __init__(self,
             use_second_place_expert_prob),
         moe_use_second_place_expert_prob_temp=(
             use_second_place_expert_prob_temp),
-        moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
+        moe_top_n_num_experts_per_token=top_n_num_experts_per_token,
+        moe_rloo=rloo,
+        loss_type=loss_type,
+        p_dot_e=p_dot_e)
     self._activation = activation
 
   def call(self, context, x, losses=None):
@@ -127,7 +134,8 @@ def call(self, context, x, losses=None):
         nonpadding=context.nonpadding,
         activation=self._activation,
         num_microbatches=context.num_microbatches,
-        token_embeddings=context.input_embeddings)
+        token_embeddings=context.input_embeddings,
+        context=context)
     if context.losses is not None:
       context.losses.append(loss)
     if not has_length_dim:
@@ -202,7 +210,7 @@ def call(self, context, x, losses=None):
 def transformer_moe_layer_v1(
     inputs, output_dim, hparams, train, variable_dtype,
     layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None, context=None):
   """Local mixture of experts that works well on TPU.
 
   Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -281,6 +289,8 @@ def transformer_moe_layer_v1(
       [batch_dim(s), length_dim, input_dim]. These are the word embeddings
       that correspond to the inputs. These can optionally be used to make
       routing decisions.
+    context: a Context object that contains extra information that layers
+      need at call time, as defined in transformer.py.
 
   Returns:
     outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -436,7 +446,8 @@ def transformer_moe_layer_v1(
         variable_dtype=variable_dtype,
         importance=nonpadding,
         num_microbatches=num_microbatches,
-        token_embeddings=token_embeddings)
+        token_embeddings=token_embeddings,
+        context=context)
   elif hparams.moe_gating == "ntlb":
     dispatch_tensor, combine_tensor, loss = _ntlb_gating(
         inputs=inputs,
@@ -1303,7 +1314,8 @@ def _expert_selection_gating(
 def _switch_gating(
     inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
     hparams, train, variable_dtype, importance=None, name="switch_gating",
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None,
+    context=None):
   """Compute Switch gating."""
   # SELECT EXPERT
   if train:
@@ -1351,6 +1363,11 @@ def _switch_gating(
     expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
   else:
     raise ValueError("Unknown Switch gating policy %s" % policy)
+  full_expert_gate_log_probs = gate_logits / hparams.moe_switch_temperature
+  full_expert_gate_log_probs -= mtf.reduce_logsumexp(full_expert_gate_log_probs,
+                                                     reduced_dim=experts_dim)
+  expert_gate_log_probs = mtf.gather(full_expert_gate_log_probs, expert_index,
+                                     dim=experts_dim)
 
   expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)
 
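The added lines turn the raw gate logits into temperature-scaled log-probabilities by subtracting the logsumexp over the experts dimension (a log-softmax), then gather the log-probability of the expert actually sampled for each token. The same computation in plain numpy (a sketch, assuming gate_logits is a [group_size, experts] array):

import numpy as np
from scipy.special import logsumexp

def sampled_expert_log_probs(gate_logits, temperature, expert_index):
  scaled = gate_logits / temperature  # temperature-scaled logits
  log_probs = scaled - logsumexp(scaled, axis=-1, keepdims=True)  # log-softmax
  # Log-probability of the sampled expert, per token.
  return np.take_along_axis(log_probs, expert_index[:, None], axis=-1)[:, 0]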
@@ -1363,21 +1380,40 @@ def _switch_gating(
     expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
     density_1_proxy *= mtf.cast(
         mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
-  loss = (
+  load_balance_loss = (
       mtf.reduce_mean(density_1_proxy * density_1) *
       float(experts_dim.size * experts_dim.size))
+
+  kl_with_uniform = (
+      - math.log(float(experts_dim.size))
+      - mtf.reduce_logsumexp(full_expert_gate_log_probs,
+                             reduced_dim=group_size_dim)
+      + math.log(float(group_size_dim.size)))
+  if importance:
+    kl_with_uniform *= mtf.cast(mtf.equal(importance, 1.0),
+                                dtype=raw_gates.dtype)
+  kl_with_uniform = mtf.reduce_mean(kl_with_uniform)
+
+  if hparams.loss_type.lower() == "kl":
+    loss = kl_with_uniform
+  else:
+    loss = load_balance_loss
+
   if num_microbatches and num_microbatches > 1:
     tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
         num_microbatches))
     loss /= num_microbatches
 
   # Logging
   if train:
-    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
-                             reduced_dim=experts_dim)
+    entropy = mtf.reduce_sum(
+        -mtf.exp(full_expert_gate_log_probs) * full_expert_gate_log_probs,
+        reduced_dim=experts_dim)
     batch_entropy = mtf.reduce_mean(entropy)
     mtf.scalar_summary(name + "/entropy", batch_entropy)
     mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
+    mtf.scalar_summary("tempered_expert_gate",
+                       mtf.reduce_mean(mtf.exp(expert_gate_log_probs)))
 
     mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
     total_routed = mtf.reduce_sum(mask_count_experts)
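Written out per expert e, the kl_with_uniform term above is -log(E) - log(sum_g p(e|g)) + log(G) = log((1/E) / mean_g p(e|g)), so the reduce_mean over the experts dimension equals KL(u || p_bar): the divergence between the uniform distribution u over the E experts and the router distribution p_bar averaged over the group. A quick numpy check of that equivalence (importance masking omitted):

import numpy as np
from scipy.special import logsumexp

log_probs = np.log(np.random.dirichlet(np.ones(4), size=8))  # [group=8, experts=4]
g, e = log_probs.shape

# As in the diff: -log(E) - logsumexp over the group dim + log(G),
# then a mean over the experts dim.
per_expert = -np.log(e) - logsumexp(log_probs, axis=0) + np.log(g)
kl_diff_style = per_expert.mean()

# Closed form: KL(uniform || group-averaged router distribution).
mean_router = np.exp(log_probs).mean(axis=0)
kl_closed = np.sum((1.0 / e) * np.log((1.0 / e) / mean_router))

assert np.isclose(kl_diff_style, kl_closed)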
@@ -1389,7 +1425,25 @@ def _switch_gating(
     for fraction in split_fractions:
       mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                          mtf.reduce_mean(fraction))
-    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))
+    dead_expert_fraction = mtf.reduce_mean(
+        mtf.cast(mtf.equal(mask_count_experts, 0.),
+                 dtype=raw_gates.dtype))
+    mtf.scalar_summary("dead_expert_fraction",
+                       dead_expert_fraction)
+    mtf.scalar_summary("load_balancing_loss",
+                       mtf.reduce_mean(load_balance_loss))
+    mtf.scalar_summary("kl_with_uniform",
+                       mtf.reduce_mean(kl_with_uniform))
+
+    split_expert_index = mtf.rename_dimension(
+        expert_index, 'batch', 'batch_split')
+    first_expert_index, second_expert_index = mtf.split(
+        split_expert_index,
+        split_expert_index.shape.get_dim_by_name('batch_split'), 2)
+    duplicate_sample = mtf.reduce_mean(
+        mtf.cast(mtf.equal(first_expert_index, second_expert_index),
+                 dtype=raw_gates.dtype))
+    mtf.scalar_summary("duplicate_sample_fraction", duplicate_sample)
 
   # Add in the z_loss for router.
   if train and hparams.moe_z_loss is not None:
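The new duplicate_sample_fraction summary relies on the training batch holding two copies of every example (see the duplicate_batch change in transformer.py below): splitting expert_index in half along the batch pairs each token with its duplicate, so the metric reports how often the stochastic router sends both copies to the same expert. In numpy terms (a sketch):

import numpy as np

def duplicate_sample_fraction(expert_index):
  # expert_index: [batch, ...] sampled expert ids, where the two halves
  # of the batch are aligned copies of the same examples.
  first, second = np.split(expert_index, 2, axis=0)
  return (first == second).mean()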
@@ -1421,9 +1475,16 @@ def _switch_gating(
   # Mask out the experts that have overflowed expert capacity. Sparsify the
   # expert_gate.
   expert_gate *= expert_mask_flat
+  if hparams.moe_rloo:
+    expert_gate_log_probs *= expert_mask_flat
+    context.expert_gate_log_probs.append(expert_gate_log_probs)
 
-  combine_tensor = (
-      expert_gate * expert_mask_flat *
+  if hparams.p_dot_e:
+    combine_tensor = expert_gate
+  else:
+    combine_tensor = expert_mask_flat
+
+  combine_tensor *= (
       mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
       mtf.one_hot(
           mtf.to_int32(position_in_expert),
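The p_dot_e flag controls how the gate probability enters the combine tensor. With p_dot_e=True (the default, matching standard top-1 routing) each expert's output is scaled by its gate probability, so the router also receives a differentiable pathwise gradient; with p_dot_e=False outputs are combined with weight 1 and the router learns only through the REINFORCE term enabled by moe_rloo. Schematically, per routed token (illustrative only, not from the commit):

def combine_weight(gate_prob, p_dot_e):
  # True: output scaled by p(expert | token); the router gets a
  # pathwise gradient in addition to any REINFORCE signal.
  # False: output combined unweighted; the router is trained purely
  # via the REINFORCE (moe_rloo) term.
  return gate_prob if p_dot_e else 1.0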

mesh_tensorflow/transformer/transformer.py

Lines changed: 41 additions & 2 deletions
@@ -144,7 +144,8 @@ def __init__(self,
                read_priority=None,
                inputs=None,
                encoder_inputs=None,
-               num_microbatches=1):
+               num_microbatches=1,
+               expert_gate_log_probs=None):
     """Create a context.
 
     Args:
@@ -201,6 +202,8 @@ def __init__(self,
         decoder.
       num_microbatches: integer - greater than one if the step has been
         serialized into multiple microbatches to save memory.
+      expert_gate_log_probs: an optional list of Tensors of expert gate log
+        probs. This will be used to compute REINFORCE gradients.
     """
     self.model = model
     self.mesh = mesh
@@ -235,6 +238,7 @@ def __init__(self,
     self.encoder_inputs = encoder_inputs
     self.num_microbatches = num_microbatches
     self.input_embeddings = None
+    self.expert_gate_log_probs = expert_gate_log_probs
 
   @property
   def train(self):
@@ -848,6 +852,19 @@ def _compute_loss(self, context, logits, targets, output_vocab_dim):
     if self.loss_on_targets_only:
       weights *= mtf.cast(mtf.logical_not(delimited_lm_inputs_mask(targets)),
                           dtype=context.activation_dtype)
+
+    # Compute REINFORCE loss, using the batch twin's loss as baseline.
+    if context.expert_gate_log_probs:
+      log_probs = mtf.reshape(
+          mtf.add_n(context.expert_gate_log_probs), loss.shape)
+      split_loss = mtf.rename_dimension(loss, "batch", "batch_unsplit")
+      first_loss, second_loss = mtf.split(
+          split_loss, split_loss.shape.get_dim_by_name("batch_unsplit"), 2)
+      baseline = mtf.concat([second_loss, first_loss], "batch_unsplit")
+      baseline = mtf.rename_dimension(baseline, "batch_unsplit", "batch")
+      loss += mtf.stop_gradient(loss - baseline) * mtf.cast(
+          log_probs, loss.dtype)
+
     return (mtf.reduce_sum(loss * weights) /
             self.loss_denominator(targets, context.num_microbatches))
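Because the batch was duplicated upstream, swapping the two halves of the per-example loss gives every example the loss its twin incurred, which serves as the two-sample leave-one-out baseline. stop_gradient keeps the advantage out of the backward pass, so the added term's gradient flows only through log_probs, yielding the REINFORCE estimator; the baseline comes from an independent sample, so it reduces variance without biasing the gradient. The forward computation in numpy terms (a sketch, ignoring the weights mask):

import numpy as np

def add_rloo_term(loss, log_probs):
  # loss, log_probs: [batch, ...]; the two halves of the batch axis are
  # aligned copies of the same examples.
  first, second = np.split(loss, 2, axis=0)
  baseline = np.concatenate([second, first], axis=0)  # the twin's loss
  advantage = loss - baseline  # wrapped in stop_gradient in the real code
  return loss + advantage * log_probs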

@@ -1007,6 +1024,27 @@ def call_simple(self,
       logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
       loss: an optional Scalar (if compute_loss=True)
     """
+    if mode == tf.estimator.ModeKeys.TRAIN:
+
+      def duplicate_batch(t, batch_dim_name="batch"):
+        if t:
+          # Assumes that the batch size is divisible by 2.
+          half_batch_size = t.shape.get_dim_by_name(batch_dim_name).size // 2
+          t = mtf.rename_dimension(t, batch_dim_name, batch_dim_name + "_slice")
+          half_batch = mtf.slice(t, 0, half_batch_size,
+                                 batch_dim_name + "_slice")
+          t = mtf.concat([half_batch, half_batch], batch_dim_name + "_slice")
+          return mtf.rename_dimension(t, batch_dim_name + "_slice",
+                                      batch_dim_name)
+        else:
+          return t
+
+      inputs = duplicate_batch(inputs)
+      targets = duplicate_batch(targets)
+      sequence_id = duplicate_batch(sequence_id)
+      position = duplicate_batch(position)
+      encoder_sequence_id = duplicate_batch(encoder_sequence_id)
+
     batch_dims = inputs.shape.dims[:-1]
     length_dim = inputs.shape.dims[-1]
     length_range = mtf.range(inputs.mesh, length_dim, dtype=tf.int32)
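duplicate_batch keeps only the first half of the batch and stacks it twice, so rows i and i + batch/2 hold identical examples that are then routed independently; these are the pairs that the RLOO baseline in _compute_loss and the duplicate_sample_fraction summary in moe.py depend on. Its effect in numpy terms (a toy sketch):

import numpy as np

batch = np.arange(8).reshape(4, 2)  # toy [batch=4, length=2] input
half = batch[:2]                    # keep the first half only
duplicated = np.concatenate([half, half], axis=0)
assert (duplicated[0] == duplicated[2]).all()
assert (duplicated[1] == duplicated[3]).all()

Note this halves the number of distinct examples seen per training step, trading data throughput for the second routing sample.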
@@ -1061,7 +1099,8 @@ def call_simple(self,
         read_priority=read_priority,
         inputs=inputs,
         encoder_inputs=encoder_inputs,
-        num_microbatches=num_microbatches)
+        num_microbatches=num_microbatches,
+        expert_gate_log_probs=[])
     with tf.variable_scope(self.name):
       logits = self._call_internal(context, inputs, targets)
     if compute_loss:
