
Commit f7c6421

T2T Team authored and Copybara-Service committed
Evolved Transformer encoder.
PiperOrigin-RevId: 232009541
1 parent 512c9a2 commit f7c6421

1 file changed (+164, -0 lines changed)

tensor2tensor/layers/transformer_layers.py

Lines changed: 164 additions & 0 deletions
@@ -217,6 +217,170 @@ def transformer_encoder(encoder_input,
   return common_layers.layer_preprocess(x, hparams)
 
 
+def evolved_transformer_encoder(encoder_input,
+                                encoder_self_attention_bias,
+                                hparams,
+                                name="encoder",
+                                nonpadding=None,
+                                save_weights_to=None,
+                                make_image_summary=True,
+                                losses=None,
+                                attn_bias_for_padding=None):
+  """Evolved Transformer encoder. See arxiv.org/abs/1901.11117 for more details.
+
+  Note: Pad remover is not supported.
+
+  Args:
+    encoder_input: a Tensor.
+    encoder_self_attention_bias: bias Tensor for self-attention (see
+      common_attention.attention_bias()).
+    hparams: hyperparameters for model.
+    name: a string.
+    nonpadding: optional Tensor with shape [batch_size, encoder_length]
+      indicating what positions are not padding. This must either be passed in,
+      which we do for "packed" datasets, or inferred from
+      encoder_self_attention_bias. The knowledge about padding is used for
+      pad_remover(efficiency) and to mask out padding in convolutional layers.
+    save_weights_to: an optional dictionary to capture attention weights for
+      visualization; the weights tensor will be appended there under a string
+      key created from the variable scope (including name).
+    make_image_summary: Whether to make an attention image summary.
+    losses: Not used.
+    attn_bias_for_padding: Padded attention bias in case a unidirectional
+      encoder is being used where future attention is masked.
+
+  Returns:
+    Tensor encoder output.
+  """
+  del losses
+
+  hidden_state = encoder_input
+  attention_dropout_broadcast_dims = (
+      common_layers.comma_separated_string_to_integer_list(
+          getattr(hparams, "attention_dropout_broadcast_dims", "")))
+
+  with tf.variable_scope(name):
+    if nonpadding is not None:
+      padding = 1.0 - nonpadding
+    else:
+      attention_bias = encoder_self_attention_bias
+      if attn_bias_for_padding is not None:
+        attention_bias = attn_bias_for_padding
+      padding = common_attention.attention_bias_to_padding(attention_bias)
+      nonpadding = 1.0 - padding
+
+    for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers):
+      with tf.variable_scope("layer_%d" % layer):
+
+        with tf.variable_scope("gated_linear_unit"):
+
+          residual_state = hidden_state
+          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+
+          values = tf.layers.dense(hidden_state, hparams.hidden_size)
+          gates = tf.layers.dense(
+              hidden_state, hparams.hidden_size, activation=tf.nn.sigmoid)
+          hidden_state = values * gates
+
+          hidden_state = common_layers.layer_postprocess(
+              residual_state, hidden_state, hparams)
+
+        with tf.variable_scope("conv_branches"):
+
+          residual_state = hidden_state
+          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+          # Mask padding from conv layers.
+          mask = tf.tile(
+              tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
+          hidden_state *= mask
+
+          left_output_dim = int(hparams.hidden_size * 4)
+          left_state = tf.layers.dense(
+              hidden_state, left_output_dim, activation=tf.nn.relu)
+          left_state = tf.nn.dropout(left_state,
+                                     1 - hparams.layer_prepostprocess_dropout)
+
+          right_output_dim = int(hparams.hidden_size / 2)
+          right_state = tf.layers.conv1d(
+              hidden_state,
+              right_output_dim,
+              3,
+              padding="SAME",
+              name="standard_conv_3x1",
+              activation=tf.nn.relu)
+          right_state = tf.nn.dropout(right_state,
+                                      1 - hparams.layer_prepostprocess_dropout)
+
+          right_state = tf.pad(
+              right_state,
+              [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
+              constant_values=0)
+          hidden_state = left_state + right_state
+
+          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+          # Mask padding from conv layer.
+          mask = tf.tile(tf.expand_dims(nonpadding, 2), [1, 1, left_output_dim])
+          hidden_state *= mask
+
+          separable_conv_9x1 = tf.layers.SeparableConv1D(
+              right_output_dim, 9, padding="SAME", name="separable_conv_9x1")
+          hidden_state = separable_conv_9x1.apply(hidden_state)
+          hidden_state = tf.pad(
+              hidden_state,
+              [[0, 0], [0, 0], [0, hparams.hidden_size - right_output_dim]],
+              constant_values=0)
+
+          hidden_state = common_layers.layer_postprocess(
+              residual_state, hidden_state, hparams)
+
+        with tf.variable_scope("self_attention"):
+          residual_state = hidden_state
+          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+
+          hidden_state = common_attention.multihead_attention(
+              hidden_state,
+              None,
+              encoder_self_attention_bias,
+              hparams.attention_key_channels or hparams.hidden_size,
+              hparams.attention_value_channels or hparams.hidden_size,
+              hparams.hidden_size,
+              hparams.num_heads,
+              hparams.attention_dropout,
+              attention_type=hparams.self_attention_type,
+              max_relative_position=hparams.max_relative_position,
+              heads_share_relative_embedding=(
+                  hparams.heads_share_relative_embedding),
+              add_relative_to_values=hparams.add_relative_to_values,
+              save_weights_to=save_weights_to,
+              make_image_summary=make_image_summary,
+              dropout_broadcast_dims=attention_dropout_broadcast_dims,
+              max_length=hparams.get("max_length"),
+              vars_3d=hparams.get("attention_variables_3d"),
+              activation_dtype=hparams.get("activation_dtype", "float32"),
+              weight_dtype=hparams.get("weight_dtype", "float32"))
+
+          hidden_state = common_layers.layer_postprocess(
+              residual_state, hidden_state, hparams)
+
+        with tf.variable_scope("dense_layers"):
+          residual_state = hidden_state
+          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+
+          hidden_state = tf.layers.dense(
+              hidden_state, int(hparams.hidden_size * 4), activation=tf.nn.relu)
+          hidden_state = tf.nn.dropout(hidden_state,
+                                       1 - hparams.layer_prepostprocess_dropout)
+
+          hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size)
+          hidden_state = common_layers.layer_postprocess(
+              residual_state, hidden_state, hparams)
+
+    # If normalization is done in layer_preprocess, then it should also be done
+    # on the output, since the output can grow very large, being the sum of
+    # a whole stack of unnormalized layer outputs.
+    return common_layers.layer_preprocess(hidden_state, hparams)
+
+
 def transformer_ffn_layer(x,
                           hparams,
                           pad_remover=None,
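
For orientation only, here is a minimal sketch of how the new evolved_transformer_encoder might be wired up. It is not part of this commit: the toy batch_size/length values and the encoder_input and inputs_padding tensors are hypothetical placeholders, and it assumes a standard Transformer hparams set such as transformer.transformer_base(), which defines the attention and relative-position fields the encoder reads.

# Illustrative sketch, not part of the commit. Assumes TF 1.x graph mode and a
# standard tensor2tensor Transformer hparams set; encoder_input and
# inputs_padding below are hypothetical placeholder tensors.
import tensorflow as tf

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import transformer_layers
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()  # defines hidden_size, num_heads, etc.

batch_size, length = 8, 64
# Placeholder embedded inputs of shape [batch, length, hidden_size].
encoder_input = tf.random_normal([batch_size, length, hparams.hidden_size])
# Placeholder padding indicator: 1.0 at padded positions (none in this toy batch).
inputs_padding = tf.zeros([batch_size, length])

# Standard bias that masks padded positions in encoder self-attention.
encoder_self_attention_bias = common_attention.attention_bias_ignore_padding(
    inputs_padding)

encoder_output = transformer_layers.evolved_transformer_encoder(
    encoder_input,
    encoder_self_attention_bias,
    hparams,
    name="encoder",
    nonpadding=1.0 - inputs_padding)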
