@@ -1477,11 +1477,14 @@ def dot_product_attention(q,
     return tf.matmul(weights, v)


-def _generate_relative_positions_matrix(length, max_relative_position):
+def _generate_relative_positions_matrix(length, max_relative_position, cache=False):
   """Generates matrix of relative positions between inputs."""
-  range_vec = tf.range(length)
-  range_mat = tf.reshape(tf.tile(range_vec, [length]), [length, length])
-  distance_mat = range_mat - tf.transpose(range_mat)
+  if not cache:
+    range_vec = tf.range(length)
+    range_mat = tf.reshape(tf.tile(range_vec, [length]), [length, length])
+    distance_mat = range_mat - tf.transpose(range_mat)
+  else:
+    distance_mat = tf.expand_dims(tf.range(-length + 1, 1, 1), 0)

   distance_mat_clipped = tf.clip_by_value(distance_mat, -max_relative_position,
                                           max_relative_position)
   # Shift values to be >= 0. Each integer still uniquely identifies a relative
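For intuition, the new cache branch produces exactly the last row of the full relative-position matrix: the clipped offsets from the newest decoding position back to every earlier key. A minimal standalone sketch of both branches (plain TensorFlow eager code, not part of the commit; relative_positions is a name chosen here for illustration):

import tensorflow as tf


def relative_positions(length, max_relative_position, cache=False):
  # Mirrors the two branches above: a full [length, length] matrix of clipped
  # offsets, or a single [1, length] row of offsets from the newest position.
  if not cache:
    range_vec = tf.range(length)
    range_mat = tf.reshape(tf.tile(range_vec, [length]), [length, length])
    distance_mat = range_mat - tf.transpose(range_mat)
  else:
    distance_mat = tf.expand_dims(tf.range(-length + 1, 1, 1), 0)
  clipped = tf.clip_by_value(distance_mat, -max_relative_position,
                             max_relative_position)
  return clipped + max_relative_position  # shift so every id is >= 0


full = relative_positions(4, 2)               # shape [4, 4]
step = relative_positions(4, 2, cache=True)   # shape [1, 4]
# step equals the last row of full: [[0, 0, 1, 2]]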
@@ -1491,11 +1494,15 @@ def _generate_relative_positions_matrix(length, max_relative_position):


 def _generate_relative_positions_embeddings(length, depth,
-                                            max_relative_position, name):
-  """Generates tensor of size [length, length, depth]."""
+                                            max_relative_position, name,
+                                            cache=False):
+  """
+  Generates tensor of size [length, length, depth] if not cache.
+  Generates tensor of size [1, length, depth] if cache.
+  """
   with tf.variable_scope(name):
     relative_positions_matrix = _generate_relative_positions_matrix(
-        length, max_relative_position)
+        length, max_relative_position, cache=cache)
     vocab_size = max_relative_position * 2 + 1
     # Generates embedding for each relative position of dimension depth.
     embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
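The rest of this function falls outside the hunk; presumably the ids are looked up in embeddings_table to give the [length, length, depth] result, or [1, length, depth] with cache. A hedged TF2-style sketch of that lookup (the library itself builds the table with tf.get_variable under the variable scope shown above; relative_position_embeddings here is an illustrative stand-in, not the library function):

import tensorflow as tf


def relative_position_embeddings(length, depth, max_relative_position,
                                 cache=False):
  # Standalone sketch: one trainable row per clipped offset id, gathered into
  # [length, length, depth], or [1, length, depth] when cache is set.
  vocab_size = max_relative_position * 2 + 1
  table = tf.Variable(tf.random.normal([vocab_size, depth]), name="embeddings")
  if not cache:
    rng = tf.range(length)
    distance = rng[None, :] - rng[:, None]        # [length, length]
  else:
    distance = tf.range(-length + 1, 1)[None, :]  # [1, length]
  ids = tf.clip_by_value(distance, -max_relative_position,
                         max_relative_position) + max_relative_position
  return tf.gather(table, ids)


print(relative_position_embeddings(4, 8, 2).shape)              # (4, 4, 8)
print(relative_position_embeddings(4, 8, 2, cache=True).shape)  # (1, 4, 8)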
@@ -1509,9 +1516,9 @@ def _relative_attention_inner(x, y, z, transpose):
   This batches matrix multiply calculations to avoid unnecessary broadcasting.

   Args:
-    x: Tensor with shape [batch_size, heads, length, length or depth].
-    y: Tensor with shape [batch_size, heads, length, depth].
-    z: Tensor with shape [length, length, depth].
+    x: Tensor with shape [batch_size, heads, length or 1, length or depth].
+    y: Tensor with shape [batch_size, heads, length or 1, depth].
+    z: Tensor with shape [length or 1, length, depth].
     transpose: Whether to transpose inner matrices of y and z. Should be true if
         last dimension of x is depth, not length.

@@ -1522,17 +1529,17 @@ def _relative_attention_inner(x, y, z, transpose):
   heads = x.get_shape().as_list()[1]
   length = tf.shape(x)[2]

-  # xy_matmul is [batch_size, heads, length, length or depth]
+  # xy_matmul is [batch_size, heads, length or 1, length or depth]
   xy_matmul = tf.matmul(x, y, transpose_b=transpose)
-  # x_t is [length, batch_size, heads, length or depth]
+  # x_t is [length or 1, batch_size, heads, length or depth]
   x_t = tf.transpose(x, [2, 0, 1, 3])
-  # x_t_r is [length, batch_size * heads, length or depth]
+  # x_t_r is [length or 1, batch_size * heads, length or depth]
   x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])
-  # x_tz_matmul is [length, batch_size * heads, length or depth]
+  # x_tz_matmul is [length or 1, batch_size * heads, length or depth]
   x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
-  # x_tz_matmul_r is [length, batch_size, heads, length or depth]
+  # x_tz_matmul_r is [length or 1, batch_size, heads, length or depth]
   x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
-  # x_tz_matmul_r_t is [batch_size, heads, length, length or depth]
+  # x_tz_matmul_r_t is [batch_size, heads, length or 1, length or depth]
   x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])
   return xy_matmul + x_tz_matmul_r_t

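The updated comments allow every intermediate to carry a query dimension of length or 1. Copying the same algebra into a standalone function and running it on random tensors confirms the cache-mode shapes, where x holds a single query position while y keeps the full key length:

import tensorflow as tf


def relative_attention_inner(x, y, z, transpose):
  # Same algebra as the function above, copied standalone so the shapes can
  # be checked in eager mode.
  batch_size = tf.shape(x)[0]
  heads = x.get_shape().as_list()[1]
  length = tf.shape(x)[2]
  xy_matmul = tf.matmul(x, y, transpose_b=transpose)
  x_t = tf.transpose(x, [2, 0, 1, 3])
  x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])
  x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
  x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
  return xy_matmul + tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])


# One cached decode step: a single query position against 6 cached keys.
q = tf.random.normal([2, 4, 1, 8])   # [batch, heads, 1, depth]
k = tf.random.normal([2, 4, 6, 8])   # [batch, heads, length, depth]
z = tf.random.normal([1, 6, 8])      # [1, length, depth] relative key embeddings
print(relative_attention_inner(q, k, z, transpose=True).shape)  # (2, 4, 1, 6)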
@@ -1545,7 +1552,8 @@ def dot_product_attention_relative(q,
                                    dropout_rate=0.0,
                                    image_shapes=None,
                                    name=None,
-                                   make_image_summary=True):
+                                   make_image_summary=True,
+                                   cache=False):
   """Calculate relative position-aware dot-product self-attention.

   The attention calculation is augmented with learned representations for the
@@ -1562,6 +1570,7 @@ def dot_product_attention_relative(q,
     image_shapes: optional tuple of integer scalars.
     name: an optional string.
     make_image_summary: Whether to make an attention image summary.
+    cache: whether to use cache mode.

   Returns:
     A Tensor.
@@ -1577,16 +1586,17 @@ def dot_product_attention_relative(q,

     # This calculation only works for self attention.
     # q, k and v must therefore have the same shape.
-    q.get_shape().assert_is_compatible_with(k.get_shape())
-    q.get_shape().assert_is_compatible_with(v.get_shape())
+    if not cache:
+      q.get_shape().assert_is_compatible_with(k.get_shape())
+      q.get_shape().assert_is_compatible_with(v.get_shape())

     # Use separate embeddings suitable for keys and values.
-    depth = q.get_shape().as_list()[3]
-    length = common_layers.shape_list(q)[2]
+    depth = k.get_shape().as_list()[3]
+    length = common_layers.shape_list(k)[2]
     relations_keys = _generate_relative_positions_embeddings(
-        length, depth, max_relative_position, "relative_positions_keys")
+        length, depth, max_relative_position, "relative_positions_keys", cache=cache)
     relations_values = _generate_relative_positions_embeddings(
-        length, depth, max_relative_position, "relative_positions_values")
+        length, depth, max_relative_position, "relative_positions_values", cache=cache)

     # Compute self attention considering the relative position embeddings.
     logits = _relative_attention_inner(q, k, relations_keys, True)
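Putting the pieces together, a cache-mode call attends from one query position to every cached key with relation tensors of shape [1, length, depth], which is why depth and length are now read from k rather than q. The einsum restatement below is an equivalent formulation written for this note, not the library's code; random tensors stand in for the learned relation embeddings, and the value-side combination mirrors the analogous _relative_attention_inner call that follows in the library:

import tensorflow as tf

# Shapes for one cache-mode step: a single query position, 5 cached keys.
B, H, L, D = 2, 4, 5, 8
q = tf.random.normal([B, H, 1, D])
k = tf.random.normal([B, H, L, D])
v = tf.random.normal([B, H, L, D])
rel_k = tf.random.normal([1, L, D])   # stand-in for relations_keys (cache mode)
rel_v = tf.random.normal([1, L, D])   # stand-in for relations_values (cache mode)

# Content term plus relative-key term, i.e. _relative_attention_inner(q, k, rel_k, True).
logits = (tf.einsum("bhmd,bhnd->bhmn", q, k)
          + tf.einsum("bhmd,mnd->bhmn", q, rel_k))
weights = tf.nn.softmax(logits)                      # [B, H, 1, L]

# Weighted values plus relative-value term (the value-side call in the library).
out = (tf.einsum("bhmn,bhnd->bhmd", weights, v)
       + tf.einsum("bhmn,mnd->bhmd", weights, rel_v))
print(out.shape)  # (2, 4, 1, 8)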
@@ -3389,7 +3399,7 @@ def multihead_attention(query_antecedent,
                             kv_filter_width, q_padding, kv_padding,
                             vars_3d_num_heads=vars_3d_num_heads)
     if cache is not None:
-      if attention_type != "dot_product":
+      if attention_type not in ["dot_product", "dot_product_relative"]:
         # TODO(petershaw): Support caching when using relative position
         # representations, i.e. "dot_product_relative" attention.
         raise NotImplementedError(
@@ -3456,7 +3466,8 @@ def multihead_attention(query_antecedent,
           max_relative_position,
           dropout_rate,
           image_shapes,
-          make_image_summary=make_image_summary)
+          make_image_summary=make_image_summary,
+          cache=cache is not None)
     elif attention_type == "dot_product_unmasked_relative_v2":
       x = dot_product_unmasked_self_attention_relative_v2(
           q,
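With this dispatch change, dot_product_attention_relative receives cache=True whenever multihead_attention is given a decoding cache. As context only, the sketch below shows one generic way such a key/value cache is driven during incremental decoding; the dict layout and names are assumptions for illustration, not tensor2tensor's exact cache structure:

import tensorflow as tf

# Hypothetical layout: per-layer "k"/"v" tensors grown along the length axis.
B, H, D = 2, 4, 8
cache = {"k": tf.zeros([B, H, 0, D]), "v": tf.zeros([B, H, 0, D])}


def decode_step(q_step, k_step, v_step, cache):
  # Append this step's key/value, then attend from the one new query position
  # to everything seen so far.
  cache["k"] = tf.concat([cache["k"], k_step], axis=2)
  cache["v"] = tf.concat([cache["v"], v_step], axis=2)
  logits = tf.matmul(q_step, cache["k"], transpose_b=True)   # [B, H, 1, t]
  weights = tf.nn.softmax(logits)
  return tf.matmul(weights, cache["v"]), cache               # [B, H, 1, D]


for _ in range(3):
  x = tf.random.normal([B, H, 1, D])   # stand-in for the projected q, k, v
  out, cache = decode_step(x, x, x, cache)
print(cache["k"].shape)  # (2, 4, 3, 8)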