 from official.nlp.modeling.layers import kernel_attention as attention
 
 
-_FEATURE_TRANSFORM = ['relu', 'elu', 'exp', 'expplus']
+_FEATURE_TRANSFORM = ["relu", "elu", "exp", "expplus"]
 _REDRAW = [True, False]
 _TRAINING = [True, False]
 _IS_SHORT_SEQ = [True, False]
@@ -62,10 +62,10 @@ def test_attention_projection(
 
   @parameterized.parameters(
       itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
-                        [0]))
+                        [0], [None, "left", "right"]))
   def test_causal_windowed_attention_projection(
       self, feature_transform, num_random_features, training, redraw,
-      begin_kernel):
+      begin_kernel, causal_padding):
     num_heads = 12
     key_dim = 64
     seq_length = 1024
@@ -80,7 +80,8 @@ def test_causal_windowed_attention_projection(
         begin_kernel=begin_kernel,
         use_causal_windowed=True,
         causal_chunk_length=8,
-        causal_window_length=3)
+        causal_window_length=3,
+        causal_padding=causal_padding)
     query = tf.random.normal(
         shape=(batch_size, seq_length, key_dim))
     value = query
@@ -150,14 +151,14 @@ def test_attention_scale_by_length(self, seq_length):
     self.assertNotAllClose(output_scale_by_length, output_no_scale_by_length)
 
   def test_unsupported_feature_transform(self):
-    with self.assertRaisesRegex(ValueError, 'Unsupported feature_transform.*'):
-      _ = attention.KernelAttention(feature_transform='test')
+    with self.assertRaisesRegex(ValueError, "Unsupported feature_transform.*"):
+      _ = attention.KernelAttention(feature_transform="test")
 
   def test_redraw_true_no_projection(self):
     with self.assertRaisesRegex(
-        ValueError, 'There is nothing to redraw when num_random_features.*'):
+        ValueError, "There is nothing to redraw when num_random_features.*"):
       _ = attention.KernelAttention(
-          num_heads=2, key_dim=64, feature_transform='elu',
+          num_heads=2, key_dim=64, feature_transform="elu",
           num_random_features=0, redraw=True)
 
   def test_config(self):
@@ -166,13 +167,13 @@ def test_config(self):
     test_layer = attention.KernelAttention(
         num_heads=num_heads,
         key_dim=key_dim,
-        feature_transform='exp',
+        feature_transform="exp",
         num_random_features=128,
         is_short_seq=True)
     new_layer = attention.KernelAttention.from_config(
         test_layer.get_config())
     # If the serialization was successful, the new config should match the old.
     self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
 
-if __name__ == '__main__':
+if __name__ == "__main__":
   tf.test.main()
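
For reference, here is a minimal sketch of how the causal_padding option parameterized by this change could be exercised directly. The constructor arguments mirror those used in test_causal_windowed_attention_projection; the batch size of 2, the chosen padding value, and the query/value keyword call convention (as with tf.keras MultiHeadAttention) are assumptions for illustration, not taken from this diff.

import tensorflow as tf

from official.nlp.modeling.layers import kernel_attention as attention

# Build a causal windowed kernel attention layer; causal_padding is the
# option newly swept by the test (None, "left", or "right").
layer = attention.KernelAttention(
    num_heads=12,
    key_dim=64,
    feature_transform="exp",
    num_random_features=127,
    use_causal_windowed=True,
    causal_chunk_length=8,
    causal_window_length=3,
    causal_padding="left")  # value chosen for illustration

# Self-attention over a random sequence, mirroring the test's value = query.
query = tf.random.normal(shape=(2, 1024, 64))  # (batch, seq_length, key_dim); batch size assumed
output = layer(query=query, value=query)
print(output.shape)  # expected: (2, 1024, 64)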