@@ -25,6 +25,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import math
 import gin
 
 import mesh_tensorflow as mtf
@@ -65,7 +66,10 @@ def __init__(self,
                word_embed_mode=None,
                use_second_place_expert_prob=None,
                use_second_place_expert_prob_temp=None,
-               top_n_num_experts_per_token=3):
+               top_n_num_experts_per_token=3,
+               rloo=False,
+               loss_type="load_balance",
+               p_dot_e=True):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -95,7 +99,10 @@ def __init__(self,
             use_second_place_expert_prob),
         moe_use_second_place_expert_prob_temp=(
             use_second_place_expert_prob_temp),
-        moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
+        moe_top_n_num_experts_per_token=top_n_num_experts_per_token,
+        moe_rloo=rloo,
+        loss_type=loss_type,
+        p_dot_e=p_dot_e)
     self._activation = activation
 
   def call(self, context, x, losses=None):
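The three new constructor arguments choose the auxiliary router loss ("load_balance" or "kl"), whether per-token expert log-probs are recorded for RLOO-style training, and whether expert outputs are scaled by the gate probability. A minimal usage sketch, assuming the MoE1D layer class that defines this constructor upstream; every value other than loss_type, rloo and p_dot_e is illustrative, not taken from this change:

# Illustrative configuration of the MoE layer with the new routing options
# (argument values other than the three new flags are placeholders).
moe_layer = MoE1D(
    num_experts=64,
    moe_gating="switch",
    loss_type="kl",    # KL-with-uniform auxiliary loss instead of load balancing
    rloo=True,         # append per-token expert log-probs to the Context
    p_dot_e=False)     # dispatch with unit weight; do not scale outputs by the gate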
@@ -127,7 +134,8 @@ def call(self, context, x, losses=None):
         nonpadding=context.nonpadding,
         activation=self._activation,
         num_microbatches=context.num_microbatches,
-        token_embeddings=context.input_embeddings)
+        token_embeddings=context.input_embeddings,
+        context=context)
     if context.losses is not None:
       context.losses.append(loss)
     if not has_length_dim:
@@ -202,7 +210,7 @@ def call(self, context, x, losses=None):
 def transformer_moe_layer_v1(
     inputs, output_dim, hparams, train, variable_dtype,
     layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None, context=None):
   """Local mixture of experts that works well on TPU.
 
   Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -281,6 +289,8 @@ def transformer_moe_layer_v1(
       [batch_dim(s), length_dim, input_dim]. These are the word embeddings
       that correspond to the inputs. These can optionally be used to make
       routing decisions.
+    context: a Context object containing extra information that layers need
+      at call time, as defined in transformer.py.
 
   Returns:
     outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -436,7 +446,8 @@ def transformer_moe_layer_v1(
         variable_dtype=variable_dtype,
         importance=nonpadding,
         num_microbatches=num_microbatches,
-        token_embeddings=token_embeddings)
+        token_embeddings=token_embeddings,
+        context=context)
   elif hparams.moe_gating == "ntlb":
     dispatch_tensor, combine_tensor, loss = _ntlb_gating(
         inputs=inputs,
@@ -1303,7 +1314,8 @@ def _expert_selection_gating(
 def _switch_gating(
     inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
     hparams, train, variable_dtype, importance=None, name="switch_gating",
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None,
+    context=None):
   """Compute Switch gating."""
   # SELECT EXPERT
   if train:
@@ -1351,6 +1363,11 @@ def _switch_gating(
     expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
   else:
     raise ValueError("Unknown Switch gating policy %s" % policy)
+  full_expert_gate_log_probs = gate_logits / hparams.moe_switch_temperature
+  full_expert_gate_log_probs -= mtf.reduce_logsumexp(full_expert_gate_log_probs,
+                                                     reduced_dim=experts_dim)
+  expert_gate_log_probs = mtf.gather(full_expert_gate_log_probs, expert_index,
+                                     dim=experts_dim)
 
   expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)
 
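The added lines turn the raw router logits into a normalized log-distribution: dividing by moe_switch_temperature and subtracting the logsumexp over the experts dimension is a tempered log-softmax, and the gather picks out the log-probability of the expert each position was actually routed to. A NumPy sketch of the same computation on a [group, experts] logit array; names and shapes are illustrative, not the mtf API:

import numpy as np

def tempered_log_probs(gate_logits, expert_index, temperature):
  # gate_logits: [group, experts] router logits; expert_index: [group] ints.
  scaled = gate_logits / temperature
  # Log-softmax over experts: subtracting the per-row logsumexp makes each row
  # a valid log-distribution.
  full_log_probs = scaled - np.logaddexp.reduce(scaled, axis=-1, keepdims=True)
  # Log-probability of the expert each position routed to.
  chosen = np.take_along_axis(full_log_probs, expert_index[:, None], axis=-1)
  return full_log_probs, chosen[:, 0]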
@@ -1363,21 +1380,40 @@ def _switch_gating(
     expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
     density_1_proxy *= mtf.cast(
         mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
-  loss = (
+  load_balance_loss = (
       mtf.reduce_mean(density_1_proxy * density_1) *
       float(experts_dim.size * experts_dim.size))
+
+  kl_with_uniform = (
+      -math.log(float(experts_dim.size))
+      - mtf.reduce_logsumexp(full_expert_gate_log_probs,
+                             reduced_dim=group_size_dim)
+      + math.log(float(group_size_dim.size)))
+  if importance:
+    kl_with_uniform *= mtf.cast(mtf.equal(importance, 1.0),
+                                dtype=raw_gates.dtype)
+  kl_with_uniform = mtf.reduce_mean(kl_with_uniform)
+
+  if hparams.loss_type.lower() == "kl":
+    loss = kl_with_uniform
+  else:
+    loss = load_balance_loss
+
   if num_microbatches and num_microbatches > 1:
     tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
         num_microbatches))
     loss /= num_microbatches
 
   # Logging
   if train:
-    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
-                             reduced_dim=experts_dim)
+    entropy = mtf.reduce_sum(
+        -mtf.exp(full_expert_gate_log_probs) * full_expert_gate_log_probs,
+        reduced_dim=experts_dim)
     batch_entropy = mtf.reduce_mean(entropy)
     mtf.scalar_summary(name + "/entropy", batch_entropy)
     mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
+    mtf.scalar_summary("tempered_expert_gate",
+                       mtf.reduce_mean(mtf.exp(expert_gate_log_probs)))
 
     mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
     total_routed = mtf.reduce_sum(mask_count_experts)
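When loss_type is "kl", the auxiliary loss is the KL divergence between the uniform distribution over experts and the router's mean tempered assignment distribution over the group; the -log(E) - logsumexp + log(G) expression above is that divergence written in log-space. A small NumPy check of the equivalence, ignoring the importance masking for simplicity:

import numpy as np

def kl_with_uniform(full_log_probs):
  # full_log_probs: [group, experts] log p(expert | position).
  num_groups, num_experts = full_log_probs.shape
  # Direct form: KL(uniform || p_bar), p_bar = mean routing distribution.
  p_bar = np.exp(full_log_probs).mean(axis=0)
  uniform = np.full(num_experts, 1.0 / num_experts)
  kl_direct = np.sum(uniform * (np.log(uniform) - np.log(p_bar)))
  # Log-space form used above: mean over experts of
  #   -log(E) - logsumexp_over_group(log p) + log(G).
  per_expert = (-np.log(num_experts)
                - np.logaddexp.reduce(full_log_probs, axis=0)
                + np.log(num_groups))
  assert np.allclose(kl_direct, per_expert.mean())
  return per_expert.mean()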
@@ -1389,7 +1425,25 @@ def _switch_gating(
     for fraction in split_fractions:
       mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                          mtf.reduce_mean(fraction))
-    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))
+    dead_expert_fraction = mtf.reduce_mean(
+        mtf.cast(mtf.equal(mask_count_experts, 0.),
+                 dtype=raw_gates.dtype))
+    mtf.scalar_summary("dead_expert_fraction",
+                       dead_expert_fraction)
+    mtf.scalar_summary("load_balancing_loss",
+                       mtf.reduce_mean(load_balance_loss))
+    mtf.scalar_summary("kl_with_uniform",
+                       mtf.reduce_mean(kl_with_uniform))
+
+    split_expert_index = mtf.rename_dimension(
+        expert_index, 'batch', 'batch_split')
+    first_expert_index, second_expert_index = mtf.split(
+        split_expert_index,
+        split_expert_index.shape.get_dim_by_name('batch_split'), 2)
+    duplicate_sample = mtf.reduce_mean(
+        mtf.cast(mtf.equal(first_expert_index, second_expert_index),
+                 dtype=raw_gates.dtype))
+    mtf.scalar_summary("duplicate_sample_fraction", duplicate_sample)
 
   # Add in the z_loss for router.
   if train and hparams.moe_z_loss is not None:
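dead_expert_fraction reports how many experts received no tokens at all, and duplicate_sample_fraction splits the batch dimension in half and measures how often the two halves routed matching positions to the same expert. The latter is only informative when the batch is built from duplicated sample pairs (for example for two-sample RLOO training); that pairing is an assumption about the training setup, not something this change enforces. A NumPy sketch of both diagnostics:

import numpy as np

def routing_diagnostics(expert_index, num_experts):
  # expert_index: [batch, length] ints; the batch is assumed to hold paired
  # samples (example i and example i + batch // 2 duplicate the same input).
  counts = np.bincount(expert_index.ravel(), minlength=num_experts)
  dead_expert_fraction = np.mean(counts == 0)
  first, second = np.split(expert_index, 2, axis=0)
  duplicate_sample_fraction = np.mean(first == second)
  return dead_expert_fraction, duplicate_sample_fraction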
@@ -1421,9 +1475,16 @@ def _switch_gating(
   # Mask out the experts that have overflowed expert capacity. Sparsify the
   # expert_gate.
   expert_gate *= expert_mask_flat
+  if hparams.moe_rloo:
+    expert_gate_log_probs *= expert_mask_flat
+    context.expert_gate_log_probs.append(expert_gate_log_probs)
 
-  combine_tensor = (
-      expert_gate * expert_mask_flat *
+  if hparams.p_dot_e:
+    combine_tensor = expert_gate
+  else:
+    combine_tensor = expert_mask_flat
+
+  combine_tensor *= (
       mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
       mtf.one_hot(
           mtf.to_int32(position_in_expert),
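With p_dot_e=True (the default) the combine tensor keeps the standard behaviour of weighting each expert's output by its gate probability; with p_dot_e=False tokens are dispatched with unit weight, so the router can only influence training through the log-probs appended to context.expert_gate_log_probs when moe_rloo is set. A hedged sketch of how those stored log-probs could feed a leave-one-out REINFORCE (RLOO) objective; the reward definition and where this term is added to the training loss are assumptions, not part of this change:

import numpy as np

def rloo_routing_loss(log_probs, rewards):
  # log_probs, rewards: [num_samples, batch] for k duplicated samples of each
  # input. Each sample's baseline is the mean reward of the other k - 1
  # samples (leave-one-out), which keeps the REINFORCE estimate unbiased.
  k = log_probs.shape[0]
  baseline = (rewards.sum(axis=0, keepdims=True) - rewards) / (k - 1)
  advantage = rewards - baseline
  # In a real implementation the advantage is treated as a constant and
  # gradients flow only through log_probs; negative sign because we minimize.
  return -(advantage * log_probs).mean()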