
Commit a9b89de

aeloyq authored and afrozenator committed
add caching mechanism support for fast decoding with relative_dot_product in transformer model (#1295)
1 parent 3f5ab98 commit a9b89de

2 files changed, with 37 additions and 26 deletions.

tensor2tensor/layers/common_attention.py

Lines changed: 36 additions & 25 deletions
```diff
@@ -1477,11 +1477,14 @@ def dot_product_attention(q,
     return tf.matmul(weights, v)
 
 
-def _generate_relative_positions_matrix(length, max_relative_position):
+def _generate_relative_positions_matrix(length, max_relative_position, cache=False):
   """Generates matrix of relative positions between inputs."""
-  range_vec = tf.range(length)
-  range_mat = tf.reshape(tf.tile(range_vec, [length]), [length, length])
-  distance_mat = range_mat - tf.transpose(range_mat)
+  if not cache:
+    range_vec = tf.range(length)
+    range_mat = tf.reshape(tf.tile(range_vec, [length]), [length, length])
+    distance_mat = range_mat - tf.transpose(range_mat)
+  else:
+    distance_mat = tf.expand_dims(tf.range(-length+1, 1, 1), 0)
   distance_mat_clipped = tf.clip_by_value(distance_mat, -max_relative_position,
                                           max_relative_position)
   # Shift values to be >= 0. Each integer still uniquely identifies a relative
```
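For intuition, here is a small NumPy sketch (not part of the commit; names are illustrative) showing that the new cache branch produces exactly the last row of the full distance matrix, which is all the single new query position needs during incremental decoding:

```python
import numpy as np

length, max_relative_position = 4, 2

# Training / no-cache path: full [length, length] matrix of pairwise offsets,
# where entry (i, j) is j - i.
range_vec = np.arange(length)
full_mat = range_vec[None, :] - range_vec[:, None]

# Decode / cache path: only the newest query position (the last row) matters,
# so a single row of offsets [-(length-1), ..., 0] is enough.
cached_row = np.arange(-length + 1, 1)[None, :]
np.testing.assert_array_equal(cached_row, full_mat[-1:, :])

# Both paths are then clipped and shifted to index the embedding table.
print(np.clip(cached_row, -max_relative_position, max_relative_position)
      + max_relative_position)   # [[0 0 1 2]]
```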
```diff
@@ -1491,11 +1494,15 @@ def _generate_relative_positions_matrix(length, max_relative_position):
 
 
 def _generate_relative_positions_embeddings(length, depth,
-                                            max_relative_position, name):
-  """Generates tensor of size [length, length, depth]."""
+                                            max_relative_position, name,
+                                            cache=False):
+  """
+  Generates tensor of size [length, length, depth] if not cache.
+  Generates tensor of size [1, length, depth] if cache.
+  """
   with tf.variable_scope(name):
     relative_positions_matrix = _generate_relative_positions_matrix(
-        length, max_relative_position)
+        length, max_relative_position, cache=cache)
     vocab_size = max_relative_position * 2 + 1
     # Generates embedding for each relative position of dimension depth.
     embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
```
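The shapes stated in the new docstring follow directly from gathering the clipped position indices out of the [vocab_size, depth] table; a rough NumPy analogue (illustrative only, index values are stand-ins):

```python
import numpy as np

length, depth, max_relative_position = 4, 8, 2
table = np.random.randn(2 * max_relative_position + 1, depth)

full_idx = np.zeros((length, length), dtype=int)   # stand-in for no-cache indices
cached_idx = np.zeros((1, length), dtype=int)      # stand-in for cache-mode indices

print(table[full_idx].shape)     # (4, 4, 8)  -> [length, length, depth]
print(table[cached_idx].shape)   # (1, 4, 8)  -> [1, length, depth]
```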
```diff
@@ -1509,9 +1516,9 @@ def _relative_attention_inner(x, y, z, transpose):
   This batches matrix multiply calculations to avoid unnecessary broadcasting.
 
   Args:
-    x: Tensor with shape [batch_size, heads, length, length or depth].
-    y: Tensor with shape [batch_size, heads, length, depth].
-    z: Tensor with shape [length, length, depth].
+    x: Tensor with shape [batch_size, heads, length or 1, length or depth].
+    y: Tensor with shape [batch_size, heads, length or 1, depth].
+    z: Tensor with shape [length or 1, length, depth].
     transpose: Whether to transpose inner matrices of y and z. Should be true if
         last dimension of x is depth, not length.
 
```
```diff
@@ -1522,17 +1529,17 @@ def _relative_attention_inner(x, y, z, transpose):
   heads = x.get_shape().as_list()[1]
   length = tf.shape(x)[2]
 
-  # xy_matmul is [batch_size, heads, length, length or depth]
+  # xy_matmul is [batch_size, heads, length or 1, length or depth]
   xy_matmul = tf.matmul(x, y, transpose_b=transpose)
-  # x_t is [length, batch_size, heads, length or depth]
+  # x_t is [length or 1, batch_size, heads, length or depth]
   x_t = tf.transpose(x, [2, 0, 1, 3])
-  # x_t_r is [length, batch_size * heads, length or depth]
+  # x_t_r is [length or 1, batch_size * heads, length or depth]
   x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])
-  # x_tz_matmul is [length, batch_size * heads, length or depth]
+  # x_tz_matmul is [length or 1, batch_size * heads, length or depth]
   x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
-  # x_tz_matmul_r is [length, batch_size, heads, length or depth]
+  # x_tz_matmul_r is [length or 1, batch_size, heads, length or depth]
   x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
-  # x_tz_matmul_r_t is [batch_size, heads, length, length or depth]
+  # x_tz_matmul_r_t is [batch_size, heads, length or 1, length or depth]
   x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])
   return xy_matmul + x_tz_matmul_r_t
 
```
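A NumPy mirror of the routine (illustrative, not the library code) makes it easy to check that the relaxed shapes still line up when the query axis is 1, as it is in cache mode:

```python
import numpy as np

def relative_attention_inner_np(x, y, z, transpose):
  """Shape-checking mirror of _relative_attention_inner."""
  batch, heads, x_len, _ = x.shape
  # Content term: ordinary batched matmul.
  xy = np.matmul(x, np.swapaxes(y, -1, -2) if transpose else y)
  # Relative-position term: fold batch and heads so every query row is
  # multiplied against its own slice of z.
  x_t = np.transpose(x, (2, 0, 1, 3)).reshape(x_len, batch * heads, -1)
  xz = np.matmul(x_t, np.swapaxes(z, -1, -2) if transpose else z)
  xz = xz.reshape(x_len, batch, heads, -1).transpose(1, 2, 0, 3)
  return xy + xz

batch, heads, length, depth = 2, 4, 5, 8
q_step = np.random.randn(batch, heads, 1, depth)   # single decoding step
k = np.random.randn(batch, heads, length, depth)   # cached keys so far
rel_k = np.random.randn(1, length, depth)          # cache-mode relations_keys
print(relative_attention_inner_np(q_step, k, rel_k, transpose=True).shape)
# (2, 4, 1, 5): one row of logits for the newly decoded position
```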

```diff
@@ -1545,7 +1552,8 @@ def dot_product_attention_relative(q,
                                    dropout_rate=0.0,
                                    image_shapes=None,
                                    name=None,
-                                   make_image_summary=True):
+                                   make_image_summary=True,
+                                   cache=False):
   """Calculate relative position-aware dot-product self-attention.
 
   The attention calculation is augmented with learned representations for the
```
```diff
@@ -1562,6 +1570,7 @@ def dot_product_attention_relative(q,
     image_shapes: optional tuple of integer scalars.
     name: an optional string.
     make_image_summary: Whether to make an attention image summary.
+    cache: whether to use cache mode.
 
   Returns:
     A Tensor.
```
```diff
@@ -1577,16 +1586,17 @@ def dot_product_attention_relative(q,
 
     # This calculation only works for self attention.
     # q, k and v must therefore have the same shape.
-    q.get_shape().assert_is_compatible_with(k.get_shape())
-    q.get_shape().assert_is_compatible_with(v.get_shape())
+    if not cache:
+      q.get_shape().assert_is_compatible_with(k.get_shape())
+      q.get_shape().assert_is_compatible_with(v.get_shape())
 
     # Use separate embeddings suitable for keys and values.
-    depth = q.get_shape().as_list()[3]
-    length = common_layers.shape_list(q)[2]
+    depth = k.get_shape().as_list()[3]
+    length = common_layers.shape_list(k)[2]
     relations_keys = _generate_relative_positions_embeddings(
-        length, depth, max_relative_position, "relative_positions_keys")
+        length, depth, max_relative_position, "relative_positions_keys", cache=cache)
     relations_values = _generate_relative_positions_embeddings(
-        length, depth, max_relative_position, "relative_positions_values")
+        length, depth, max_relative_position, "relative_positions_values", cache=cache)
 
     # Compute self attention considering the relative position embeddings.
     logits = _relative_attention_inner(q, k, relations_keys, True)
```
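Putting the pieces together, here is a hedged usage sketch (assuming TensorFlow 1.x and this version of tensor2tensor importable) of calling dot_product_attention_relative in cache mode, with q covering only the newly decoded position while k and v cover everything decoded so far:

```python
import tensorflow as tf
from tensor2tensor.layers import common_attention

batch, heads, cur_len, depth_per_head = 2, 4, 7, 16
q = tf.random_normal([batch, heads, 1, depth_per_head])        # new position only
k = tf.random_normal([batch, heads, cur_len, depth_per_head])  # cached keys
v = tf.random_normal([batch, heads, cur_len, depth_per_head])  # cached values

y = common_attention.dot_product_attention_relative(
    q, k, v, bias=None, max_relative_position=4,
    make_image_summary=False, cache=True)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(y).shape)   # expected: (2, 4, 1, 16)
```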
```diff
@@ -3389,7 +3399,7 @@ def multihead_attention(query_antecedent,
                             kv_filter_width, q_padding, kv_padding,
                             vars_3d_num_heads=vars_3d_num_heads)
     if cache is not None:
-      if attention_type != "dot_product":
+      if attention_type not in ["dot_product", "dot_product_relative"]:
         # TODO(petershaw): Support caching when using relative position
         # representations, i.e. "dot_product_relative" attention.
         raise NotImplementedError(
```
```diff
@@ -3456,7 +3466,8 @@ def multihead_attention(query_antecedent,
           max_relative_position,
           dropout_rate,
           image_shapes,
-          make_image_summary=make_image_summary)
+          make_image_summary=make_image_summary,
+          cache=cache is not None)
     elif attention_type == "dot_product_unmasked_relative_v2":
       x = dot_product_unmasked_self_attention_relative_v2(
           q,
```

tensor2tensor/models/transformer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -273,7 +273,7 @@ def _beam_decode(self,
           None if using greedy decoding (beam_size=1)
       }
     """
-    if self._hparams.self_attention_type != "dot_product":
+    if self._hparams.self_attention_type not in ["dot_product", "dot_product_relative"]:
       # Caching is not guaranteed to work with attention types other than
       # dot_product.
       # TODO(petershaw): Support fast decoding when using relative
```
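With that whitelist change, enabling the fast decode path for relative attention comes down to an hparams setting; a hedged sketch (hparams names as defined elsewhere in tensor2tensor, full problem and decode wiring omitted):

```python
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()
hparams.self_attention_type = "dot_product_relative"
hparams.max_relative_position = 20   # must be > 0 for relative attention

# Beam or greedy decoding through Transformer._beam_decode should now take the
# cached fast path instead of falling back to the slow decoding loop.
```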
