
Commit e0d9825

Internal change
PiperOrigin-RevId: 464067452
Parent commit: ee4f707

2 files changed: +27 -27 lines changed


official/nlp/modeling/layers/kernel_attention.py

Lines changed: 16 additions & 17 deletions
@@ -58,6 +58,8 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
   Returns:
     Padded tensor with shape[axis] divisible by chunk_length.
   """
+  if padding is None:
+    return tensor
   shape = tf.shape(tensor)
   rank = tf.rank(tensor)
   if axis < 0:
@@ -68,14 +70,9 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
     axis_paddings = [[0, pad_length]]
   elif padding == "left":
     axis_paddings = [[pad_length, 0]]
-  elif padding is None:
-    if pad_length != 0:
-      raise ValueError("When padding is None, the axis dimension"
-                       "has to be divisible by the chunk_length.")
-    return tensor
   else:
-    raise ValueError("Illegal padding value; must be one of \"left\""
-                     "\"right\" or None.")
+    raise ValueError(
+        "Illegal padding value; must be one of \"left\", \"right\" or None.")
   paddings = tf.concat(
       [tf.zeros([axis, 2], dtype=tf.int32),
        axis_paddings,
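The net effect of the two hunks above: `padding=None` now short-circuits and returns the input unchanged, instead of raising when the axis length is not divisible by `chunk_length`, and the ValueError message for illegal values gains its missing comma. Below is a minimal standalone sketch of the patched branch structure, assuming a non-negative `axis`; the modulo-based pad-length formula and the final `tf.pad` call are assumptions that do not appear in the diff.

import tensorflow as tf

def pad_to_chunk_length_sketch(tensor, axis, chunk_length, padding=None):
  """Illustrative sketch of the patched branching; not the library function."""
  if padding is None:
    # New behavior: hand the tensor back untouched instead of requiring
    # shape[axis] to be divisible by chunk_length.
    return tensor
  shape = tf.shape(tensor)
  rank = tf.rank(tensor)
  # Assumed pad-length formula: zero when the axis is already divisible.
  pad_length = -shape[axis] % chunk_length
  if padding == "right":
    axis_paddings = [[0, pad_length]]
  elif padding == "left":
    axis_paddings = [[pad_length, 0]]
  else:
    raise ValueError(
        "Illegal padding value; must be one of \"left\", \"right\" or None.")
  paddings = tf.concat(
      [tf.zeros([axis, 2], dtype=tf.int32),
       axis_paddings,
       tf.zeros([rank - axis - 1, 2], dtype=tf.int32)], axis=0)
  return tf.pad(tensor, paddings)

# A [2, 7] tensor chunked along axis=1 with chunk_length=3:
x = tf.ones([2, 7])
print(pad_to_chunk_length_sketch(x, 1, 3, padding="right").shape)  # (2, 9)
print(pad_to_chunk_length_sketch(x, 1, 3, padding=None).shape)     # (2, 7)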
@@ -109,16 +106,18 @@ def causal_windowed_performer_attention(query_matrix,
                                         padding=None):
   """Applies windowed causal kernel attention with query, key, value tensors.

-  We partition the T-length input sequence into N chunks, each of chunk_length
-  tokens (thus: T = N * chunk_length). Within each chunk, we apply bidirectional
-  (non-causal) Performers’ implicit attention and we model relationships between
-  different chunks using Performers’ causal attention. We consider windowed
-  causal variant of performer, where the current chunk attends only to the
-  window of window_length of the most recent chunks.
+  We partition the T-length input sequence into N chunks, each of
+  chunk_length tokens (thus: T = N * chunk_length). Within each chunk,
+  we apply bidirectional (non-causal) Performers’ implicit attention
+  and we model relationships between different chunks using
+  Performers’ causal attention. We consider windowed causal variant of
+  performer, where the current chunk attends only to the window of
+  window_length of the most recent chunks.
+
+  Below is an example with T=9, chunk_length=3, window_length=2. In
+  this example 1 indicates attention is computed between the pair
+  while 0 indicates attention is not computed between the pairs:

-  Below is an example with T=9, chunk_length=3, window_length=1. 1 indicates
-  attention is computed between the pair while 0 indicates attention is not
-  computed between the pairs:
   111000000
   111000000
   111000000
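To visualize the 0/1 pattern the docstring describes, the small NumPy sketch below builds the token-level mask for the windowed causal scheme. The layer itself applies Performers' implicit attention and does not materialize such a matrix, so the helper (its name and construction are mine) is purely illustrative.

import numpy as np

def windowed_causal_block_mask(num_chunks, chunk_length, window_length):
  # Hypothetical helper, for visualization only: chunk i attends to chunk j
  # when j is among the window_length most recent chunks, i.e.
  # 0 <= i - j < window_length.
  chunk_mask = np.zeros((num_chunks, num_chunks), dtype=np.int32)
  for i in range(num_chunks):
    for j in range(num_chunks):
      if 0 <= i - j < window_length:
        chunk_mask[i, j] = 1
  # Blow each chunk-level entry up to a chunk_length x chunk_length block.
  return np.kron(chunk_mask,
                 np.ones((chunk_length, chunk_length), dtype=np.int32))

# T=9, chunk_length=3, window_length=2: the first three rows print as
# [1 1 1 0 0 0 0 0 0], matching the 111000000 excerpt in the docstring.
print(windowed_causal_block_mask(num_chunks=3, chunk_length=3, window_length=2))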
@@ -454,7 +453,7 @@ def __init__(self,
                scale_by_length=False,
                use_causal_windowed=False,
                causal_chunk_length=1,
-               causal_window_length=1,
+               causal_window_length=3,
                causal_padding=None,
                **kwargs):
     r"""Constructor of KernelAttention.

official/nlp/modeling/layers/kernel_attention_test.py

Lines changed: 11 additions & 10 deletions
@@ -21,7 +21,7 @@
 from official.nlp.modeling.layers import kernel_attention as attention


-_FEATURE_TRANSFORM = ['relu', 'elu', 'exp', 'expplus']
+_FEATURE_TRANSFORM = ["relu", "elu", "exp", "expplus"]
 _REDRAW = [True, False]
 _TRAINING = [True, False]
 _IS_SHORT_SEQ = [True, False]
@@ -62,10 +62,10 @@ def test_attention_projection(

   @parameterized.parameters(
       itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
-                        [0]))
+                        [0], [None, "left", "right"]))
   def test_causal_windowed_attention_projection(
       self, feature_transform, num_random_features, training, redraw,
-      begin_kernel):
+      begin_kernel, causal_padding):
     num_heads = 12
     key_dim = 64
     seq_length = 1024
@@ -80,7 +80,8 @@ def test_causal_windowed_attention_projection(
         begin_kernel=begin_kernel,
         use_causal_windowed=True,
         causal_chunk_length=8,
-        causal_window_length=3)
+        causal_window_length=3,
+        causal_padding=causal_padding)
     query = tf.random.normal(
         shape=(batch_size, seq_length, key_dim))
     value = query
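Parameterizing over `causal_padding` triples the number of generated test cases. A quick sketch of how the `itertools.product` in the decorator expands, using only the lists visible in this diff (the count is the sole point here):

import itertools

_FEATURE_TRANSFORM = ["relu", "elu", "exp", "expplus"]
_TRAINING = [True, False]

# Same product as the decorator above: transform x [127] x training x redraw
# x begin_kernel x causal_padding.
cases = list(itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING,
                               [True, False], [0], [None, "left", "right"]))
print(len(cases))  # 4 * 1 * 2 * 2 * 1 * 3 = 48 parameterized test cases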
@@ -150,14 +151,14 @@ def test_attention_scale_by_length(self, seq_length):
     self.assertNotAllClose(output_scale_by_length, output_no_scale_by_length)

   def test_unsupported_feature_transform(self):
-    with self.assertRaisesRegex(ValueError, 'Unsupported feature_transform.*'):
-      _ = attention.KernelAttention(feature_transform='test')
+    with self.assertRaisesRegex(ValueError, "Unsupported feature_transform.*"):
+      _ = attention.KernelAttention(feature_transform="test")

   def test_redraw_true_no_projection(self):
     with self.assertRaisesRegex(
-        ValueError, 'There is nothing to redraw when num_random_features.*'):
+        ValueError, "There is nothing to redraw when num_random_features.*"):
       _ = attention.KernelAttention(
-          num_heads=2, key_dim=64, feature_transform='elu',
+          num_heads=2, key_dim=64, feature_transform="elu",
           num_random_features=0, redraw=True)

   def test_config(self):
@@ -166,13 +167,13 @@ def test_config(self):
     test_layer = attention.KernelAttention(
         num_heads=num_heads,
         key_dim=key_dim,
-        feature_transform='exp',
+        feature_transform="exp",
         num_random_features=128,
         is_short_seq=True)
     new_layer = attention.KernelAttention.from_config(
         test_layer.get_config())
     # If the serialization was successful, the new config should match the old.
     self.assertAllEqual(test_layer.get_config(), new_layer.get_config())

-if __name__ == '__main__':
+if __name__ == "__main__":
   tf.test.main()
