Commit 3b0d978

MultiHeadAttention Layer (#1062)
* Add MultiHeadAttention Layer
1 parent fee4710 commit 3b0d978

5 files changed: +651 additions, -0 deletions

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions

@@ -49,6 +49,7 @@
 /tensorflow_addons/layers/gelu*.py @aakashkumarnain
 /tensorflow_addons/layers/maxout*.py @failure-to-thrive
+/tensorflow_addons/layers/multihead_attention*.py @cgarciae
 /tensorflow_addons/layers/netvlad*.py @joel-shor
 /tensorflow_addons/layers/normalizations*.py @smokrow
 /tensorflow_addons/layers/optical_flow*.py @failure-to-thrive

tensorflow_addons/layers/BUILD

Lines changed: 13 additions & 0 deletions

@@ -8,6 +8,7 @@ py_library(
         "__init__.py",
         "gelu.py",
         "maxout.py",
+        "multihead_attention.py",
         "netvlad.py",
         "normalizations.py",
         "optical_flow.py",
@@ -147,3 +148,15 @@ py_test(
         ":layers",
     ],
 )
+
+py_test(
+    name = "multihead_attention_test",
+    size = "small",
+    srcs = [
+        "multihead_attention_test.py",
+    ],
+    main = "multihead_attention_test.py",
+    deps = [
+        ":layers",
+    ],
+)
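With the new `py_test` target in place, the test can presumably be run on its own with Bazel (target label assumed from the BUILD file above), e.g. `bazel test //tensorflow_addons/layers:multihead_attention_test`.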

tensorflow_addons/layers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 from tensorflow_addons.layers.gelu import GELU
 from tensorflow_addons.layers.maxout import Maxout
+from tensorflow_addons.layers.multihead_attention import MultiHeadAttention
 from tensorflow_addons.layers.normalizations import GroupNormalization
 from tensorflow_addons.layers.normalizations import InstanceNormalization
 from tensorflow_addons.layers.optical_flow import CorrelationCost
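With this export added to `tensorflow_addons/layers/__init__.py`, the layer becomes importable from the package namespace. A minimal sketch (shapes here are illustrative only), assuming a `tensorflow_addons` install that includes this commit:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# The new layer is exposed as tfa.layers.MultiHeadAttention.
mha = tfa.layers.MultiHeadAttention(head_size=64, num_heads=8)

query = tf.random.uniform((4, 10, 32))  # (batch_size, query_elements, query_depth)
key = tf.random.uniform((4, 12, 32))    # (batch_size, key_elements, key_depth)

attention = mha([query, key])  # value defaults to key
print(attention.shape)  # (4, 10, 32)
```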
tensorflow_addons/layers/multihead_attention.py

Lines changed: 295 additions & 0 deletions

@@ -0,0 +1,295 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import typing

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="Addons")
class MultiHeadAttention(tf.keras.layers.Layer):
    r"""
    MultiHead Attention layer.

    Defines the MultiHead Attention operation as described in
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762), which takes
    in `query`, `key` and `value` tensors and returns the dot-product attention
    between them:

    ```python
    mha = MultiHeadAttention(head_size=128, num_heads=128)

    query = tf.random.uniform((32, 20, 200)) # (batch_size, query_elements, query_depth)
    key = tf.random.uniform((32, 15, 300)) # (batch_size, key_elements, key_depth)
    value = tf.random.uniform((32, 15, 400)) # (batch_size, key_elements, value_depth)

    attention = mha([query, key, value]) # (batch_size, query_elements, value_depth)
    ```

    If `value` is not given then internally `value = key` will be used:

    ```python
    mha = MultiHeadAttention(head_size=128, num_heads=128)

    query = tf.random.uniform((32, 20, 200)) # (batch_size, query_elements, query_depth)
    key = tf.random.uniform((32, 15, 300)) # (batch_size, key_elements, key_depth)

    attention = mha([query, key]) # (batch_size, query_elements, key_depth)
    ```

    Arguments
        head_size: int, dimensionality of the `query`, `key` and `value` tensors
            after the linear transformation.
        num_heads: int, number of attention heads.
        output_size: int, dimensionality of the output space; if `None` then the
            input dimension of `value` or `key` will be used, default `None`.
        dropout: float, `rate` parameter for the dropout layer that is
            applied to the attention coefficients after softmax, default `0`.
        use_projection_bias: bool, whether to use a bias term after the linear
            output projection.
        return_attn_coef: bool, if `True`, return the attention coefficients as
            an additional output argument.
        kernel_initializer: initializer, initializer for the kernel weights.
        kernel_regularizer: regularizer, regularizer for the kernel weights.
        kernel_constraint: constraint, constraint for the kernel weights.
        bias_initializer: initializer, initializer for the bias weights.
        bias_regularizer: regularizer, regularizer for the bias weights.
        bias_constraint: constraint, constraint for the bias weights.

    Call Arguments
        inputs: List of the following tensors:
            * `query`: Tensor of shape `(..., query_elements, query_depth)`
            * `key`: Tensor of shape `(..., key_elements, key_depth)`
            * `value`: Tensor of shape `(..., key_elements, value_depth)` (optional)
        mask: a binary Tensor of shape `[batch_size?, num_heads?, query_elements, key_elements]`
            which specifies which query elements can attend to which key elements;
            `1` indicates attention and `0` indicates no attention.

    Output shape
        - `(..., query_elements, output_size)` if `output_size` is given, else
        - `(..., query_elements, value_depth)` if `value` is given, else
        - `(..., query_elements, key_depth)`
    """

    def __init__(
        self,
        head_size: int,
        num_heads: int,
        output_size: int = None,
        dropout: float = 0.0,
        use_projection_bias: bool = True,
        return_attn_coef: bool = False,
        kernel_initializer: typing.Union[str, typing.Callable] = "glorot_uniform",
        kernel_regularizer: typing.Union[str, typing.Callable] = None,
        kernel_constraint: typing.Union[str, typing.Callable] = None,
        bias_initializer: typing.Union[str, typing.Callable] = "zeros",
        bias_regularizer: typing.Union[str, typing.Callable] = None,
        bias_constraint: typing.Union[str, typing.Callable] = None,
        **kwargs
    ):
        super().__init__(**kwargs)

        if output_size is not None and output_size < 1:
            raise ValueError("output_size must be a positive number")

        self.head_size = head_size
        self.num_heads = num_heads
        self.output_size = output_size
        self.use_projection_bias = use_projection_bias
        self.return_attn_coef = return_attn_coef

        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        self.kernel_constraint = tf.keras.constraints.get(kernel_constraint)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self.bias_constraint = tf.keras.constraints.get(bias_constraint)

        self.dropout = tf.keras.layers.Dropout(dropout)
        self._dropout_rate = dropout

    def build(self, input_shape):

        num_query_features = input_shape[0][-1]
        num_key_features = input_shape[1][-1]
        num_value_features = (
            input_shape[2][-1] if len(input_shape) > 2 else num_key_features
        )
        output_size = (
            self.output_size if self.output_size is not None else num_value_features
        )

        self.query_kernel = self.add_weight(
            name="query_kernel",
            shape=[self.num_heads, num_query_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.key_kernel = self.add_weight(
            name="key_kernel",
            shape=[self.num_heads, num_key_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.value_kernel = self.add_weight(
            name="value_kernel",
            shape=[self.num_heads, num_value_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.projection_kernel = self.add_weight(
            name="projection_kernel",
            shape=[self.num_heads, self.head_size, output_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )

        if self.use_projection_bias:
            self.projection_bias = self.add_weight(
                name="projection_bias",
                shape=[output_size],
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        else:
            self.projection_bias = None

        super().build(input_shape)

    def call(self, inputs, training=None, mask=None):

        # einsum nomenclature
        # ------------------------
        # N = query elements
        # M = key/value elements
        # H = heads
        # I = input features
        # O = output features

        query = inputs[0]
        key = inputs[1]
        value = inputs[2] if len(inputs) > 2 else key

        # verify shapes
        if mask is not None:
            if len(mask.shape) < 2:
                raise ValueError("'mask' must have at least 2 dimensions")
            if query.shape[-2] != mask.shape[-2]:
                raise ValueError(
                    "mask's second to last dimension must be equal to the number of elements in 'query'"
                )
            if key.shape[-2] != mask.shape[-1]:
                raise ValueError(
                    "mask's last dimension must be equal to the number of elements in 'key'"
                )
        if key.shape[-2] != value.shape[-2]:
            raise ValueError(
                "the number of elements in 'key' must be equal to the number of elements in 'value'"
            )

        # Linear transformations
        query = tf.einsum("...NI , HIO -> ...NHO", query, self.query_kernel)
        key = tf.einsum("...MI , HIO -> ...MHO", key, self.key_kernel)
        value = tf.einsum("...MI , HIO -> ...MHO", value, self.value_kernel)

        # Scale the dot-product: dividing query (or key) before the product,
        # instead of scaling the product itself, saves some computation.
        depth = tf.constant(self.head_size, dtype=tf.float32)
        query /= tf.sqrt(depth)

        # Calculate dot product attention
        logits = tf.einsum("...NHO,...MHO->...HNM", query, key)

        # apply mask
        if mask is not None:
            mask = tf.cast(mask, tf.float32)

            # possibly expand on the head dimension so broadcasting works
            if len(mask.shape) != len(logits.shape):
                mask = tf.expand_dims(mask, -3)

            logits += -10e9 * (1.0 - mask)

        attn_coef = tf.nn.softmax(logits)

        # attention dropout
        attn_coef_dropout = self.dropout(attn_coef, training=training)

        # attention * value
        multihead_output = tf.einsum("...HNM,...MHI->...NHI", attn_coef_dropout, value)

        # Run the outputs through another linear projection layer. Recombining
        # heads is automatically done.
        output = tf.einsum(
            "...NHI,HIO->...NO", multihead_output, self.projection_kernel
        )

        if self.projection_bias is not None:
            output += self.projection_bias

        if self.return_attn_coef:
            return output, attn_coef
        else:
            return output

    def compute_output_shape(self, input_shape):
        num_value_features = (
            input_shape[2][-1] if len(input_shape) > 2 else input_shape[1][-1]
        )
        output_size = (
            self.output_size if self.output_size is not None else num_value_features
        )

        output_shape = input_shape[0][:-1] + (output_size,)

        if self.return_attn_coef:
            num_query_elements = input_shape[0][-2]
            num_key_elements = input_shape[1][-2]
            attn_coef_shape = input_shape[0][:-2] + (
                self.num_heads,
                num_query_elements,
                num_key_elements,
            )

            return output_shape, attn_coef_shape
        else:
            return output_shape

    def get_config(self):
        config = super().get_config()

        config.update(
            head_size=self.head_size,
            num_heads=self.num_heads,
            output_size=self.output_size,
            dropout=self._dropout_rate,
            use_projection_bias=self.use_projection_bias,
            return_attn_coef=self.return_attn_coef,
            kernel_initializer=tf.keras.initializers.serialize(self.kernel_initializer),
            kernel_regularizer=tf.keras.regularizers.serialize(self.kernel_regularizer),
            kernel_constraint=tf.keras.constraints.serialize(self.kernel_constraint),
            bias_initializer=tf.keras.initializers.serialize(self.bias_initializer),
            bias_regularizer=tf.keras.regularizers.serialize(self.bias_regularizer),
            bias_constraint=tf.keras.constraints.serialize(self.bias_constraint),
        )

        return config
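A short usage sketch (not part of the commit) exercising the `mask` and `return_attn_coef` options documented above; the causal mask here is only an illustrative choice:

```python
import tensorflow as tf
import tensorflow_addons as tfa

mha = tfa.layers.MultiHeadAttention(head_size=32, num_heads=4, return_attn_coef=True)

query = tf.random.uniform((2, 6, 16))  # (batch_size, query_elements, query_depth)
key = tf.random.uniform((2, 6, 16))    # (batch_size, key_elements, key_depth)

# Lower-triangular (causal) mask: each query element may attend only to
# key elements at the same or an earlier position.
mask = tf.linalg.band_part(tf.ones((6, 6)), -1, 0)

output, attn_coef = mha([query, key], mask=mask)
print(output.shape)     # (2, 6, 16)
print(attn_coef.shape)  # (2, 4, 6, 6) -> (batch, heads, query_elements, key_elements)
```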
