|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# The Three Ways of Attention and Dot Product Attention: Ungraded Lab Notebook\n", |
| 8 | + "\n", |
| 9 | + "In this notebook you'll explore the three ways of attention (encoder-decoder attention, causal attention, and bi-directional self attention) and how to implement the latter two with dot product attention. \n", |
| 10 | + "\n", |
| 11 | + "## Background\n", |
| 12 | + "\n", |
| 13 | + "As you learned last week, **attention models** constitute powerful tools in the NLP practitioner's toolkit. Like LSTMs, they learn which words are most important to phrases, sentences, paragraphs, and so on. Moreover, they mitigate the vanishing gradient problem even better than LSTMs. You've already seen how to combine attention with LSTMs to build **encoder-decoder models** for applications such as machine translation. \n", |
| 14 | + "\n", |
| 15 | + "<img src=\"attention_lnb_figs/C4_W2_L3_dot-product-attention_S01_introducing-attention_stripped.png\" width=\"500\"/>\n", |
| 16 | + "\n", |
| 17 | + "This week, you'll see how to integrate attention into **transformers**. Because transformers are not sequence models, they are much easier to parallelize and accelerate. Beyond machine translation, applications of transformers include: \n", |
| 18 | + "* Auto-completion\n", |
| 19 | + "* Named Entity Recognition\n", |
| 20 | + "* Chatbots\n", |
| 21 | + "* Question-Answering\n", |
| 22 | + "* And more!\n", |
| 23 | + "\n", |
| 24 | + "Along with embedding, positional encoding, dense layers, and residual connections, attention is a crucial component of transformers. At the heart of any attention scheme used in a transformer is **dot product attention**, of which the figures below display a simplified picture:\n", |
| 25 | + "\n", |
| 26 | + "<img src=\"attention_lnb_figs/C4_W2_L3_dot-product-attention_S03_concept-of-attention_stripped.png\" width=\"500\"/>\n", |
| 27 | + "\n", |
| 28 | + "<img src=\"attention_lnb_figs/C4_W2_L3_dot-product-attention_S04_attention-math_stripped.png\" width=\"500\"/>\n", |
| 29 | + "\n", |
| 30 | + "With basic dot product attention, you capture the interactions between every word (embedding) in your query and every word in your key. If the queries and keys belong to the same sentences, this constitutes **bi-directional self-attention**. In some situations, however, it's more appropriate to consider only words which have come before the current one. Such cases, particularly when the queries and keys come from the same sentences, fall into the category of **causal attention**. \n", |
| 31 | + "\n", |
| 32 | + "<img src=\"attention_lnb_figs/C4_W2_L4_causal-attention_S02_causal-attention_stripped.png\" width=\"500\"/>\n", |
| 33 | + "\n", |
| 34 | + "For causal attention, we add a **mask** to the argument of our softmax function, as illustrated below: \n", |
| 35 | + "\n", |
| 36 | + "<img src=\"attention_lnb_figs/C4_W2_L4_causal-attention_S03_causal-attention-math_stripped.png\" width=\"500\"/>\n", |
| 37 | + "\n", |
| 38 | + "<img src=\"attention_lnb_figs/C4_W2_L4_causal-attention_S04_causal-attention-math-2_stripped.png\" width=\"500\"/>\n", |
| 39 | + "\n", |
| 40 | + "Now let's see how to implement attention with NumPy. When you integrate attention into a transformer network defined with Trax, you'll have to use `trax.fastmath.numpy` instead, since Trax's arrays are based on JAX DeviceArrays. Fortunately, the function interfaces are often identical." |
| 41 | + ] |
| 42 | + }, |
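|  | + {
|  | + "cell_type": "markdown",
|  | + "metadata": {},
|  | + "source": [
|  | + "Optional: here's a minimal sketch of how the swap to `trax.fastmath.numpy` might look, assuming the `trax` package is installed. If it isn't, the cell falls back to plain NumPy so it still runs."
|  | + ]
|  | + },
|  | + {
|  | + "cell_type": "code",
|  | + "execution_count": null,
|  | + "metadata": {},
|  | + "outputs": [],
|  | + "source": [
|  | + "# Optional sketch: use trax.fastmath.numpy when Trax is available,\n",
|  | + "# otherwise fall back to plain NumPy (the interfaces match for what we use here).\n",
|  | + "try:\n",
|  | + "    from trax import fastmath\n",
|  | + "    jnp = fastmath.numpy  # JAX-backed arrays with a NumPy-like interface\n",
|  | + "except ImportError:\n",
|  | + "    import numpy as jnp  # fallback so this cell runs without Trax installed\n",
|  | + "\n",
|  | + "print(jnp.dot(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])))"
|  | + ]
|  | + },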
| 43 | + { |
| 44 | + "cell_type": "markdown", |
| 45 | + "metadata": {}, |
| 46 | + "source": [ |
| 47 | + "## Imports" |
| 48 | + ] |
| 49 | + }, |
| 50 | + { |
| 51 | + "cell_type": "code", |
| 52 | + "execution_count": 1, |
| 53 | + "metadata": {}, |
| 54 | + "outputs": [], |
| 55 | + "source": [ |
| 56 | + "import sys\n", |
| 57 | + "\n", |
| 58 | + "import numpy as np\n", |
| 59 | + "import scipy.special\n", |
| 60 | + "\n", |
| 61 | + "import textwrap\n", |
| 62 | + "wrapper = textwrap.TextWrapper(width=70)\n", |
| 63 | + "\n", |
| 64 | + "# to print the entire np array\n", |
| 65 | + "np.set_printoptions(threshold=sys.maxsize)" |
| 66 | + ] |
| 67 | + }, |
| 68 | + { |
| 69 | + "cell_type": "markdown", |
| 70 | + "metadata": {}, |
| 71 | + "source": [ |
| 72 | + "Here are some helper functions that will help you create tensors and display useful information:\n", |
| 73 | + "\n", |
| 74 | + "* `create_tensor()` creates a numpy array from a list of lists.\n", |
| 75 | + "* `display_tensor()` prints out the shape and the actual tensor." |
| 76 | + ] |
| 77 | + }, |
| 78 | + { |
| 79 | + "cell_type": "code", |
| 80 | + "execution_count": 2, |
| 81 | + "metadata": {}, |
| 82 | + "outputs": [], |
| 83 | + "source": [ |
| 84 | + "def create_tensor(t):\n", |
| 85 | + " \"\"\"Create tensor from list of lists\"\"\"\n", |
| 86 | + " return np.array(t)\n", |
| 87 | + "\n", |
| 88 | + "\n", |
| 89 | + "def display_tensor(t, name):\n", |
| 90 | + " \"\"\"Display shape and tensor\"\"\"\n", |
| 91 | + " print(f'{name} shape: {t.shape}\\n')\n", |
| 92 | + " print(f'{t}\\n')" |
| 93 | + ] |
| 94 | + }, |
| 95 | + { |
| 96 | + "cell_type": "markdown", |
| 97 | + "metadata": {}, |
| 98 | + "source": [ |
| 99 | + "Create some tensors and display their shapes. Feel free to experiment with your own tensors. Keep in mind, though, that the query, key, and value arrays must all have the same embedding dimensions (number of columns), and the mask array must have the same shape as `np.dot(query, key.T)`. " |
| 100 | + ] |
| 101 | + }, |
| 102 | + { |
| 103 | + "cell_type": "code", |
| 104 | + "execution_count": 3, |
| 105 | + "metadata": {}, |
| 106 | + "outputs": [ |
| 107 | + { |
| 108 | + "name": "stdout", |
| 109 | + "output_type": "stream", |
| 110 | + "text": [ |
| 111 | + "query shape: (2, 3)\n", |
| 112 | + "\n", |
| 113 | + "[[1 0 0]\n", |
| 114 | + " [0 1 0]]\n", |
| 115 | + "\n", |
| 116 | + "key shape: (2, 3)\n", |
| 117 | + "\n", |
| 118 | + "[[1 2 3]\n", |
| 119 | + " [4 5 6]]\n", |
| 120 | + "\n", |
| 121 | + "value shape: (2, 3)\n", |
| 122 | + "\n", |
| 123 | + "[[0 1 0]\n", |
| 124 | + " [1 0 1]]\n", |
| 125 | + "\n", |
| 126 | + "mask shape: (2, 2)\n", |
| 127 | + "\n", |
| 128 | + "[[ 0.e+00 0.e+00]\n", |
| 129 | + " [-1.e+09 0.e+00]]\n", |
| 130 | + "\n" |
| 131 | + ] |
| 132 | + } |
| 133 | + ], |
| 134 | + "source": [ |
| 135 | + "q = create_tensor([[1, 0, 0], [0, 1, 0]])\n", |
| 136 | + "display_tensor(q, 'query')\n", |
| 137 | + "k = create_tensor([[1, 2, 3], [4, 5, 6]])\n", |
| 138 | + "display_tensor(k, 'key')\n", |
| 139 | + "v = create_tensor([[0, 1, 0], [1, 0, 1]])\n", |
| 140 | + "display_tensor(v, 'value')\n", |
| 141 | + "m = create_tensor([[0, 0], [-1e9, 0]])\n", |
| 142 | + "display_tensor(m, 'mask')" |
| 143 | + ] |
| 144 | + }, |
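|  | + {
|  | + "cell_type": "markdown",
|  | + "metadata": {},
|  | + "source": [
|  | + "As a quick sanity check, you can verify the two shape constraints described above before moving on: the query, key, and value arrays share an embedding dimension, and the mask has the same shape as `np.dot(query, key.T)`."
|  | + ]
|  | + },
|  | + {
|  | + "cell_type": "code",
|  | + "execution_count": null,
|  | + "metadata": {},
|  | + "outputs": [],
|  | + "source": [
|  | + "# Sanity checks for the shape constraints stated above\n",
|  | + "assert q.shape[-1] == k.shape[-1] == v.shape[-1], 'q, k, v must share the embedding dimension'\n",
|  | + "assert m.shape == np.dot(q, k.T).shape, 'mask must match the shape of the query-key dot product'\n",
|  | + "print('embedding dimension:', q.shape[-1])\n",
|  | + "print('mask shape:', m.shape)"
|  | + ]
|  | + },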
| 145 | + { |
| 146 | + "cell_type": "markdown", |
| 147 | + "metadata": {}, |
| 148 | + "source": [ |
| 149 | + "## Dot product attention\n", |
| 150 | + "\n", |
| 151 | + "Here we come to the crux of this lab, in which we compute \n", |
| 152 | + "$\\textrm{softmax} \\left(\\frac{Q K^T}{\\sqrt{d}} + M \\right) V$, where the (optional, but default) scaling factor $\\sqrt{d}$ is the square root of the embedding dimension." |
| 153 | + ] |
| 154 | + }, |
| 155 | + { |
| 156 | + "cell_type": "code", |
| 157 | + "execution_count": 4, |
| 158 | + "metadata": {}, |
| 159 | + "outputs": [], |
| 160 | + "source": [ |
| 161 | + "def DotProductAttention(query, key, value, mask, scale=True):\n", |
| 162 | + " \"\"\"Dot product self-attention.\n", |
| 163 | + " Args:\n", |
| 164 | + " query (numpy.ndarray): array of query representations with shape (L_q by d)\n", |
| 165 | + " key (numpy.ndarray): array of key representations with shape (L_k by d)\n", |
| 166 | + " value (numpy.ndarray): array of value representations with shape (L_k by d) where L_v = L_k\n", |
| 167 | + "        mask (numpy.ndarray): attention mask with shape (L_q by L_k), or broadcastable to it; nonzero/True entries mark positions that may be attended to\n",
| 168 | + " scale (bool): whether to scale the dot product of the query and transposed key\n", |
| 169 | + "\n", |
| 170 | + " Returns:\n", |
| 171 | + "        numpy.ndarray: Self-attention array for q, k, v arrays. (L_q by d)\n",
| 172 | + " \"\"\"\n", |
| 173 | + "\n", |
| 174 | + " assert query.shape[-1] == key.shape[-1] == value.shape[-1], \"Embedding dimensions of q, k, v aren't all the same\"\n", |
| 175 | + "\n", |
| 176 | + " # Save depth/dimension of the query embedding for scaling down the dot product\n", |
| 177 | + " if scale: \n", |
| 178 | + " depth = query.shape[-1]\n", |
| 179 | + " else:\n", |
| 180 | + " depth = 1\n", |
| 181 | + "\n", |
| 182 | + " # Calculate scaled query key dot product according to formula above\n", |
| 183 | + " dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth) \n", |
| 184 | + " \n", |
| 185 | + " # Apply the mask\n", |
| 186 | + " if mask is not None:\n", |
| 187 | + " dots = np.where(mask, dots, np.full_like(dots, -1e9)) \n", |
| 188 | + " \n", |
| 189 | + " # Softmax formula implementation\n", |
| 190 | + "    # Use scipy.special.logsumexp on the masked scores so the softmax is numerically stable (avoids overflow when exponentiating large values)\n",
| 191 | + "    # Note: softmax(dots) = e^(dots - logsumexp(dots)) = e^dots / sumexp(dots)\n",
| 192 | + " logsumexp = scipy.special.logsumexp(dots, axis=-1, keepdims=True)\n", |
| 193 | + "\n", |
| 194 | + " # Take exponential of dots minus logsumexp to get softmax\n", |
| 195 | + " # Use np.exp()\n", |
| 196 | + " dots = np.exp(dots - logsumexp)\n", |
| 197 | + "\n", |
| 198 | + " # Multiply dots by value to get self-attention\n", |
| 199 | + " # Use np.matmul()\n", |
| 200 | + " attention = np.matmul(dots, value)\n", |
| 201 | + " \n", |
| 202 | + " return attention" |
| 203 | + ] |
| 204 | + }, |
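|  | + {
|  | + "cell_type": "markdown",
|  | + "metadata": {},
|  | + "source": [
|  | + "As a quick check of the formula, with no mask `DotProductAttention` should match `softmax(Q K^T / sqrt(d)) V` computed directly. The comparison below assumes `scipy.special.softmax` is available (SciPy 1.2 or later)."
|  | + ]
|  | + },
|  | + {
|  | + "cell_type": "code",
|  | + "execution_count": null,
|  | + "metadata": {},
|  | + "outputs": [],
|  | + "source": [
|  | + "# Compare DotProductAttention (no mask) against a direct\n",
|  | + "# softmax(QK^T / sqrt(d)) V computation using scipy.special.softmax (SciPy >= 1.2).\n",
|  | + "d = q.shape[-1]\n",
|  | + "scores = np.dot(q, k.T) / np.sqrt(d)\n",
|  | + "expected = np.dot(scipy.special.softmax(scores, axis=-1), v)\n",
|  | + "actual = DotProductAttention(q, k, v, mask=None)\n",
|  | + "print('matches direct computation:', np.allclose(expected, actual))"
|  | + ]
|  | + },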
| 205 | + { |
| 206 | + "cell_type": "markdown", |
| 207 | + "metadata": {}, |
| 208 | + "source": [ |
| 209 | + "Now let's implement *masked* dot product self-attention (at the heart of causal attention) as a special case of dot product attention."
| 210 | + ] |
| 211 | + }, |
| 212 | + { |
| 213 | + "cell_type": "code", |
| 214 | + "execution_count": 5, |
| 215 | + "metadata": {}, |
| 216 | + "outputs": [], |
| 217 | + "source": [ |
| 218 | + "def dot_product_self_attention(q, k, v, scale=True):\n", |
| 219 | + " \"\"\" Masked dot product self attention.\n", |
| 220 | + " Args:\n", |
| 221 | + " q (numpy.ndarray): queries.\n", |
| 222 | + " k (numpy.ndarray): keys.\n", |
| 223 | + " v (numpy.ndarray): values.\n", |
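|  | + "        scale (bool): whether to scale the dot product of the query and transposed key.\n",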
| 224 | + " Returns:\n", |
| 225 | + " numpy.ndarray: masked dot product self attention tensor.\n", |
| 226 | + " \"\"\"\n", |
| 227 | + " \n", |
| 228 | + " # Size of the penultimate dimension of the query\n", |
| 229 | + " mask_size = q.shape[-2]\n", |
| 230 | + "\n", |
| 231 | + "    # Create a matrix with ones on and below the diagonal and zeros above. It should have shape (1, mask_size, mask_size)\n",
| 232 | + " # Use np.tril() - Lower triangle of an array and np.ones()\n", |
| 233 | + " mask = np.tril(np.ones((1, mask_size, mask_size), dtype=np.bool_), k=0) \n", |
| 234 | + " \n", |
| 235 | + " return DotProductAttention(q, k, v, mask, scale=scale)" |
| 236 | + ] |
| 237 | + }, |
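|  | + {
|  | + "cell_type": "markdown",
|  | + "metadata": {},
|  | + "source": [
|  | + "To visualize the causal mask, here's the lower-triangular boolean mask for a hypothetical sequence of length 4: each position can attend to itself and to earlier positions, but not to later ones."
|  | + ]
|  | + },
|  | + {
|  | + "cell_type": "code",
|  | + "execution_count": null,
|  | + "metadata": {},
|  | + "outputs": [],
|  | + "source": [
|  | + "# The lower-triangular boolean mask used for causal attention,\n",
|  | + "# shown here for a sequence of length 4\n",
|  | + "example_mask = np.tril(np.ones((1, 4, 4), dtype=np.bool_), k=0)\n",
|  | + "display_tensor(example_mask, 'causal mask (length 4)')"
|  | + ]
|  | + },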
| 238 | + { |
| 239 | + "cell_type": "code", |
| 240 | + "execution_count": 6, |
| 241 | + "metadata": {}, |
| 242 | + "outputs": [ |
| 243 | + { |
| 244 | + "data": { |
| 245 | + "text/plain": [ |
| 246 | + "array([[[0. , 1. , 0. ],\n", |
| 247 | + " [0.84967455, 0.15032545, 0.84967455]]])" |
| 248 | + ] |
| 249 | + }, |
| 250 | + "execution_count": 6, |
| 251 | + "metadata": {}, |
| 252 | + "output_type": "execute_result" |
| 253 | + } |
| 254 | + ], |
| 255 | + "source": [ |
| 256 | + "dot_product_self_attention(q, k, v)" |
| 257 | + ] |
| 258 | + } |
| 259 | + ], |
| 260 | + "metadata": { |
| 261 | + "kernelspec": { |
| 262 | + "display_name": "Python 3", |
| 263 | + "language": "python", |
| 264 | + "name": "python3" |
| 265 | + }, |
| 266 | + "language_info": { |
| 267 | + "codemirror_mode": { |
| 268 | + "name": "ipython", |
| 269 | + "version": 3 |
| 270 | + }, |
| 271 | + "file_extension": ".py", |
| 272 | + "mimetype": "text/x-python", |
| 273 | + "name": "python", |
| 274 | + "nbconvert_exporter": "python", |
| 275 | + "pygments_lexer": "ipython3", |
| 276 | + "version": "3.7.6" |
| 277 | + } |
| 278 | + }, |
| 279 | + "nbformat": 4, |
| 280 | + "nbformat_minor": 4 |
| 281 | +} |