@@ -217,6 +217,170 @@ def transformer_encoder(encoder_input,
   return common_layers.layer_preprocess(x, hparams)
 
 
+def evolved_transformer_encoder(encoder_input,
+                                encoder_self_attention_bias,
+                                hparams,
+                                name="encoder",
+                                nonpadding=None,
+                                save_weights_to=None,
+                                make_image_summary=True,
+                                losses=None,
+                                attn_bias_for_padding=None):
+  """Evolved Transformer encoder. See arxiv.org/abs/1901.11117 for more details.
+
+  Note: Pad remover is not supported.
+
+  Args:
+    encoder_input: a Tensor.
+    encoder_self_attention_bias: bias Tensor for self-attention (see
+      common_attention.attention_bias()).
+    hparams: hyperparameters for model.
+    name: a string.
+    nonpadding: optional Tensor with shape [batch_size, encoder_length]
+      indicating what positions are not padding. This must either be passed in,
+      which we do for "packed" datasets, or inferred from
+      encoder_self_attention_bias. The knowledge about padding is used for
+      pad_remover(efficiency) and to mask out padding in convolutional layers.
+    save_weights_to: an optional dictionary to capture attention weights for
+      visualization; the weights tensor will be appended there under a string
+      key created from the variable scope (including name).
+    make_image_summary: Whether to make an attention image summary.
+    losses: Not used.
+    attn_bias_for_padding: Padded attention bias in case a unidirectional
+      encoder is being used where future attention is masked.
+
+  Returns:
+    Tensor encoder output.
+  """
+  del losses
+
+  hidden_state = encoder_input
+  attention_dropout_broadcast_dims = (
+      common_layers.comma_separated_string_to_integer_list(
+          getattr(hparams, "attention_dropout_broadcast_dims", "")))
+
+  with tf.variable_scope(name):
+    if nonpadding is not None:
+      padding = 1.0 - nonpadding
+    else:
+      attention_bias = encoder_self_attention_bias
+      if attn_bias_for_padding is not None:
+        attention_bias = attn_bias_for_padding
+      padding = common_attention.attention_bias_to_padding(attention_bias)
+      nonpadding = 1.0 - padding
+
+    for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers):
+      with tf.variable_scope("layer_%d" % layer):
+
+        with tf.variable_scope("gated_linear_unit"):
+
+          residual_state = hidden_state
+          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+
+          values = tf.layers.dense(hidden_state, hparams.hidden_size)
+          gates = tf.layers.dense(
+              hidden_state, hparams.hidden_size, activation=tf.nn.sigmoid)
+          hidden_state = values * gates
+
+          hidden_state = common_layers.layer_postprocess(
+              residual_state, hidden_state, hparams)
+
+        with tf.variable_scope("conv_branches"):
+
+          residual_state = hidden_state
+          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+          # Mask padding from conv layers.
+          mask = tf.tile(
+              tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
+          hidden_state *= mask
+
+          left_output_dim = int(hparams.hidden_size * 4)
+          left_state = tf.layers.dense(
+              hidden_state, left_output_dim, activation=tf.nn.relu)
+          left_state = tf.nn.dropout(left_state,
+                                     1 - hparams.layer_prepostprocess_dropout)
+
+          right_output_dim = int(hparams.hidden_size / 2)
+          right_state = tf.layers.conv1d(
+              hidden_state,
+              right_output_dim,
+              3,
+              padding="SAME",
+              name="standard_conv_3x1",
+              activation=tf.nn.relu)
+          right_state = tf.nn.dropout(right_state,
+                                      1 - hparams.layer_prepostprocess_dropout)
+
+          right_state = tf.pad(
+              right_state,
+              [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
+              constant_values=0)
+          hidden_state = left_state + right_state
+
+          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+          # Mask padding from conv layer.
+          mask = tf.tile(tf.expand_dims(nonpadding, 2), [1, 1, left_output_dim])
+          hidden_state *= mask
+
+          separable_conv_9x1 = tf.layers.SeparableConv1D(
+              right_output_dim, 9, padding="SAME", name="separable_conv_9x1")
+          hidden_state = separable_conv_9x1.apply(hidden_state)
+          hidden_state = tf.pad(
+              hidden_state,
+              [[0, 0], [0, 0], [0, hparams.hidden_size - right_output_dim]],
+              constant_values=0)
+
+          hidden_state = common_layers.layer_postprocess(
+              residual_state, hidden_state, hparams)
+
+        with tf.variable_scope("self_attention"):
+          residual_state = hidden_state
+          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+
+          hidden_state = common_attention.multihead_attention(
+              hidden_state,
+              None,
+              encoder_self_attention_bias,
+              hparams.attention_key_channels or hparams.hidden_size,
+              hparams.attention_value_channels or hparams.hidden_size,
+              hparams.hidden_size,
+              hparams.num_heads,
+              hparams.attention_dropout,
+              attention_type=hparams.self_attention_type,
+              max_relative_position=hparams.max_relative_position,
+              heads_share_relative_embedding=(
+                  hparams.heads_share_relative_embedding),
+              add_relative_to_values=hparams.add_relative_to_values,
+              save_weights_to=save_weights_to,
+              make_image_summary=make_image_summary,
+              dropout_broadcast_dims=attention_dropout_broadcast_dims,
+              max_length=hparams.get("max_length"),
+              vars_3d=hparams.get("attention_variables_3d"),
+              activation_dtype=hparams.get("activation_dtype", "float32"),
+              weight_dtype=hparams.get("weight_dtype", "float32"))
+
+          hidden_state = common_layers.layer_postprocess(
+              residual_state, hidden_state, hparams)
+
+        with tf.variable_scope("dense_layers"):
+          residual_state = hidden_state
+          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+
+          hidden_state = tf.layers.dense(
+              hidden_state, int(hparams.hidden_size * 4), activation=tf.nn.relu)
+          hidden_state = tf.nn.dropout(hidden_state,
+                                       1 - hparams.layer_prepostprocess_dropout)
+
+          hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size)
+          hidden_state = common_layers.layer_postprocess(
+              residual_state, hidden_state, hparams)
+
+  # If normalization is done in layer_preprocess, then it should also be done
+  # on the output, since the output can grow very large, being the sum of
+  # a whole stack of unnormalized layer outputs.
+  return common_layers.layer_preprocess(hidden_state, hparams)
+
+
 def transformer_ffn_layer(x,
                           hparams,
                           pad_remover=None,
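
For reference, a minimal usage sketch of the new encoder follows. It is not part of this commit: it assumes evolved_transformer_encoder is importable from the same module as transformer_encoder (tensor2tensor/models/transformer.py in the stock layout), uses the transformer_base hyperparameter set, and runs under TF1 graph mode, which the tf.variable_scope / tf.layers calls above already require.

# Hypothetical usage sketch -- not part of this diff. Assumes the function is
# importable from tensor2tensor.models.transformer and that TF1 graph mode is
# available.
import tensorflow as tf

from tensor2tensor.layers import common_attention
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()

batch_size, length = 2, 16
encoder_input = tf.random_normal([batch_size, length, hparams.hidden_size])

# Toy batch with no padding: an all-zero padding mask yields an all-zero bias.
padding = tf.zeros([batch_size, length])
encoder_self_attention_bias = common_attention.attention_bias_ignore_padding(
    padding)

encoder_output = transformer.evolved_transformer_encoder(
    encoder_input,
    encoder_self_attention_bias,
    hparams,
    nonpadding=1.0 - padding)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(encoder_output).shape)  # (batch_size, length, hidden_size)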