|
| 1 | +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================= |
| 15 | + |
| 16 | +import typing |
| 17 | + |
| 18 | +import tensorflow as tf |
| 19 | + |
| 20 | + |
| 21 | +@tf.keras.utils.register_keras_serializable(package="Addons") |
| 22 | +class MultiHeadAttention(tf.keras.layers.Layer): |
| 23 | + r""" |
| 24 | + MultiHead Attention layer. |
| 25 | +
|
| 26 | + Defines the MultiHead Attention operation as defined in |
| 27 | + [Attention Is All You Need](https://arxiv.org/abs/1706.03762) which takes |
| 28 | + in a `query`, `key` and `value` tensors returns the dot-product attention |
| 29 | + between them: |
| 30 | +
|
| 31 | + ```python |
| 32 | + mha = MultiHeadAttention(head_size=128, num_heads=128) |
| 33 | +
|
| 34 | + query = tf.random.uniform((32, 20, 200)) # (batch_size, query_elements, query_depth) |
| 35 | + key = tf.random.uniform((32, 15, 300)) # (batch_size, key_elements, key_depth) |
| 36 | + value = tf.random.uniform((32, 15, 400)) # (batch_size, key_elements, value_depth) |
| 37 | +
|
| 38 | + attention = mha([query, key, value]) # (batch_size, query_elements, value_depth) |
| 39 | + ``` |
| 40 | +
|
| 41 | + If `value` is not given then internally `value = key` will be used: |
| 42 | +
|
| 43 | + ```python |
| 44 | + mha = MultiHeadAttention(head_size=128, num_heads=128) |
| 45 | +
|
| 46 | + query = tf.random.uniform((32, 20, 200)) # (batch_size, query_elements, query_depth) |
| 47 | + key = tf.random.uniform((32, 15, 300)) # (batch_size, key_elements, key_depth) |
| 48 | +
|
| 49 | + attention = mha([query, key]) # (batch_size, query_elements, key_depth) |
| 50 | + ``` |
| 51 | +
|
| 52 | + Arguments |
| 53 | + head_size: int, dimensionality of the `query`, `key` and `value` tensors |
| 54 | + after the linear transformation. |
| 55 | + num_heads: int, number of attention heads. |
| 56 | + output_size: int, dimensionality of the output space, if `None` then the |
| 57 | + input dimension of |
| 58 | + `value` or `key` will be used, default `None`. |
| 59 | + dropout: float, `rate` parameter for the dropout layer that is |
| 60 | + applied to attention after softmax, |
| 61 | + default `0`. |
| 62 | + use_projection_bias: bool, whether to use a bias term after the linear |
| 63 | + output projection. |
| 64 | + return_attn_coef: bool, if `True`, return the attention coefficients as |
| 65 | + an additional output argument. |
| 66 | + kernel_initializer: initializer, initializer for the kernel weights. |
| 67 | + kernel_regularizer: regularizer, regularizer for the kernel weights. |
| 68 | + kernel_constraint: constraint, constraint for the kernel weights. |
| 69 | + bias_initializer: initializer, initializer for the bias weights. |
| 70 | + bias_regularizer: regularizer, regularizer for the bias weights. |
| 71 | + bias_constraint: constraint, constraint for the bias weights. |
| 72 | +
|
| 73 | + Call Arguments |
| 74 | + inputs: List of the following tensors: |
| 75 | + * `query`: Tensor of shape `(..., query_elements, query_depth)` |
| 76 | + * `key`: `Tensor of shape '(..., key_elements, key_depth)` |
| 77 | + * `value`: Tensor of shape `(..., key_elements, value_depth)` (optional) |
| 78 | + mask: a binary Tensor of shape `[batch_size?, num_heads?, query_elements, key_elements]` |
| 79 | + which specifies which query elements can attendo to which key elements, |
| 80 | + `1` indicates attention and `0` indicates no attention. |
| 81 | +
|
| 82 | + Output shape |
| 83 | + - `(..., query_elements, output_size)` if `output_size` is given, else |
| 84 | + - `(..., query_elements, value_depth)` if `value` is given, else |
| 85 | + - `(..., query_elements, key_depth)` |
| 86 | + """ |
| 87 | + |
| 88 | + def __init__( |
| 89 | + self, |
| 90 | + head_size: int, |
| 91 | + num_heads: int, |
| 92 | + output_size: int = None, |
| 93 | + dropout: float = 0.0, |
| 94 | + use_projection_bias: bool = True, |
| 95 | + return_attn_coef: bool = False, |
| 96 | + kernel_initializer: typing.Union[str, typing.Callable] = "glorot_uniform", |
| 97 | + kernel_regularizer: typing.Union[str, typing.Callable] = None, |
| 98 | + kernel_constraint: typing.Union[str, typing.Callable] = None, |
| 99 | + bias_initializer: typing.Union[str, typing.Callable] = "zeros", |
| 100 | + bias_regularizer: typing.Union[str, typing.Callable] = None, |
| 101 | + bias_constraint: typing.Union[str, typing.Callable] = None, |
| 102 | + **kwargs |
| 103 | + ): |
| 104 | + super().__init__(**kwargs) |
| 105 | + |
| 106 | + if output_size is not None and output_size < 1: |
| 107 | + raise ValueError("output_size must be a positive number") |
| 108 | + |
| 109 | + self.head_size = head_size |
| 110 | + self.num_heads = num_heads |
| 111 | + self.output_size = output_size |
| 112 | + self.use_projection_bias = use_projection_bias |
| 113 | + self.return_attn_coef = return_attn_coef |
| 114 | + |
| 115 | + self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) |
| 116 | + self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) |
| 117 | + self.kernel_constraint = tf.keras.constraints.get(kernel_constraint) |
| 118 | + self.bias_initializer = tf.keras.initializers.get(bias_initializer) |
| 119 | + self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer) |
| 120 | + self.bias_constraint = tf.keras.constraints.get(bias_constraint) |
| 121 | + |
| 122 | + self.dropout = tf.keras.layers.Dropout(dropout) |
| 123 | + self._droput_rate = dropout |
| 124 | + |
| 125 | + def build(self, input_shape): |
| 126 | + |
| 127 | + num_query_features = input_shape[0][-1] |
| 128 | + num_key_features = input_shape[1][-1] |
| 129 | + num_value_features = ( |
| 130 | + input_shape[2][-1] if len(input_shape) > 2 else num_key_features |
| 131 | + ) |
| 132 | + output_size = ( |
| 133 | + self.output_size if self.output_size is not None else num_value_features |
| 134 | + ) |
| 135 | + |
| 136 | + self.query_kernel = self.add_weight( |
| 137 | + name="query_kernel", |
| 138 | + shape=[self.num_heads, num_query_features, self.head_size], |
| 139 | + initializer=self.kernel_initializer, |
| 140 | + regularizer=self.kernel_regularizer, |
| 141 | + constraint=self.kernel_constraint, |
| 142 | + ) |
| 143 | + self.key_kernel = self.add_weight( |
| 144 | + name="key_kernel", |
| 145 | + shape=[self.num_heads, num_key_features, self.head_size], |
| 146 | + initializer=self.kernel_initializer, |
| 147 | + regularizer=self.kernel_regularizer, |
| 148 | + constraint=self.kernel_constraint, |
| 149 | + ) |
| 150 | + self.value_kernel = self.add_weight( |
| 151 | + name="value_kernel", |
| 152 | + shape=[self.num_heads, num_value_features, self.head_size], |
| 153 | + initializer=self.kernel_initializer, |
| 154 | + regularizer=self.kernel_regularizer, |
| 155 | + constraint=self.kernel_constraint, |
| 156 | + ) |
| 157 | + self.projection_kernel = self.add_weight( |
| 158 | + name="projection_kernel", |
| 159 | + shape=[self.num_heads, self.head_size, output_size], |
| 160 | + initializer=self.kernel_initializer, |
| 161 | + regularizer=self.kernel_regularizer, |
| 162 | + constraint=self.kernel_constraint, |
| 163 | + ) |
| 164 | + |
| 165 | + if self.use_projection_bias: |
| 166 | + self.projection_bias = self.add_weight( |
| 167 | + name="projection_bias", |
| 168 | + shape=[output_size], |
| 169 | + initializer=self.bias_initializer, |
| 170 | + regularizer=self.bias_regularizer, |
| 171 | + constraint=self.bias_constraint, |
| 172 | + ) |
| 173 | + else: |
| 174 | + self.projection_bias = None |
| 175 | + |
| 176 | + super().build(input_shape) |
| 177 | + |
| 178 | + def call(self, inputs, training=None, mask=None): |
| 179 | + |
| 180 | + # einsum nomenclature |
| 181 | + # ------------------------ |
| 182 | + # N = query elements |
| 183 | + # M = key/value elements |
| 184 | + # H = heads |
| 185 | + # I = input features |
| 186 | + # O = output features |
| 187 | + |
| 188 | + query = inputs[0] |
| 189 | + key = inputs[1] |
| 190 | + value = inputs[2] if len(inputs) > 2 else key |
| 191 | + |
| 192 | + # verify shapes |
| 193 | + if mask is not None: |
| 194 | + if len(mask.shape) < 2: |
| 195 | + raise ValueError("'mask' must have atleast 2 dimensions") |
| 196 | + if query.shape[-2] != mask.shape[-2]: |
| 197 | + raise ValueError( |
| 198 | + "mask's second to last dimension must be equal to the number of elements in 'query'" |
| 199 | + ) |
| 200 | + if key.shape[-2] != mask.shape[-1]: |
| 201 | + raise ValueError( |
| 202 | + "mask's last dimension must be equal to the number of elements in 'key'" |
| 203 | + ) |
| 204 | + if key.shape[-2] != value.shape[-2]: |
| 205 | + raise ValueError( |
| 206 | + "the number of elements in 'key' must be equal to the same as the number of elements in 'value'" |
| 207 | + ) |
| 208 | + |
| 209 | + # Linear transformations |
| 210 | + query = tf.einsum("...NI , HIO -> ...NHO", query, self.query_kernel) |
| 211 | + key = tf.einsum("...MI , HIO -> ...MHO", key, self.key_kernel) |
| 212 | + value = tf.einsum("...MI , HIO -> ...MHO", value, self.value_kernel) |
| 213 | + |
| 214 | + # Scale dot-product, doing the division to either query or key |
| 215 | + # instead of their product saves some computation |
| 216 | + depth = tf.constant(self.head_size, dtype=tf.float32) |
| 217 | + query /= tf.sqrt(depth) |
| 218 | + |
| 219 | + # Calculate dot product attention |
| 220 | + logits = tf.einsum("...NHO,...MHO->...HNM", query, key) |
| 221 | + |
| 222 | + # apply mask |
| 223 | + if mask is not None: |
| 224 | + mask = tf.cast(mask, tf.float32) |
| 225 | + |
| 226 | + # possibly expand on the head dimension so broadcasting works |
| 227 | + if len(mask.shape) != len(logits.shape): |
| 228 | + mask = tf.expand_dims(mask, -3) |
| 229 | + |
| 230 | + logits += -10e9 * (1.0 - mask) |
| 231 | + |
| 232 | + attn_coef = tf.nn.softmax(logits) |
| 233 | + |
| 234 | + # attention dropout |
| 235 | + attn_coef_dropout = self.dropout(attn_coef, training=training) |
| 236 | + |
| 237 | + # attention * value |
| 238 | + multihead_output = tf.einsum("...HNM,...MHI->...NHI", attn_coef_dropout, value) |
| 239 | + |
| 240 | + # Run the outputs through another linear projection layer. Recombining heads |
| 241 | + # is automatically done. |
| 242 | + output = tf.einsum( |
| 243 | + "...NHI,HIO->...NO", multihead_output, self.projection_kernel |
| 244 | + ) |
| 245 | + |
| 246 | + if self.projection_bias is not None: |
| 247 | + output += self.projection_bias |
| 248 | + |
| 249 | + if self.return_attn_coef: |
| 250 | + return output, attn_coef |
| 251 | + else: |
| 252 | + return output |
| 253 | + |
| 254 | + def compute_output_shape(self, input_shape): |
| 255 | + num_value_features = ( |
| 256 | + input_shape[2][-1] if len(input_shape) > 2 else input_shape[1][-1] |
| 257 | + ) |
| 258 | + output_size = ( |
| 259 | + self.output_size if self.output_size is not None else num_value_features |
| 260 | + ) |
| 261 | + |
| 262 | + output_shape = input_shape[0][:-1] + (output_size,) |
| 263 | + |
| 264 | + if self.return_attn_coef: |
| 265 | + num_query_elements = input_shape[0][-2] |
| 266 | + num_key_elements = input_shape[1][-2] |
| 267 | + attn_coef_shape = input_shape[0][:-2] + ( |
| 268 | + self.num_heads, |
| 269 | + num_query_elements, |
| 270 | + num_key_elements, |
| 271 | + ) |
| 272 | + |
| 273 | + return output_shape, attn_coef_shape |
| 274 | + else: |
| 275 | + return output_shape |
| 276 | + |
| 277 | + def get_config(self): |
| 278 | + config = super().get_config() |
| 279 | + |
| 280 | + config.update( |
| 281 | + head_size=self.head_size, |
| 282 | + num_heads=self.num_heads, |
| 283 | + output_size=self.output_size, |
| 284 | + dropout=self._droput_rate, |
| 285 | + use_projection_bias=self.use_projection_bias, |
| 286 | + return_attn_coef=self.return_attn_coef, |
| 287 | + kernel_initializer=tf.keras.initializers.serialize(self.kernel_initializer), |
| 288 | + kernel_regularizer=tf.keras.regularizers.serialize(self.kernel_regularizer), |
| 289 | + kernel_constraint=tf.keras.constraints.serialize(self.kernel_constraint), |
| 290 | + bias_initializer=tf.keras.initializers.serialize(self.bias_initializer), |
| 291 | + bias_regularizer=tf.keras.regularizers.serialize(self.bias_regularizer), |
| 292 | + bias_constraint=tf.keras.constraints.serialize(self.bias_constraint), |
| 293 | + ) |
| 294 | + |
| 295 | + return config |
0 commit comments