mha variants
rasbt committed Mar 6, 2024
1 parent d4754f1 commit 87fcfd9
Showing 10 changed files with 431 additions and 10 deletions.
7 changes: 5 additions & 2 deletions README.md
@@ -41,12 +41,15 @@ Alternatively, you can view this and other files on GitHub at [https://github.co
| Ch 6: Finetuning for Text Classification | Q2 2024 | ... |
| Ch 7: Finetuning with Human Feedback | Q2 2024 | ... |
| Ch 8: Using Large Language Models in Practice | Q2/3 2024 | ... |
-| Appendix A: Introduction to PyTorch* | - [code-part1.ipynb](appendix-A/03_main-chapter-code/code-part1.ipynb)<br/>- [code-part2.ipynb](appendix-A/03_main-chapter-code/code-part2.ipynb)<br/>- [DDP-script.py](appendix-A/03_main-chapter-code/DDP-script.py)<br/>- [exercise-solutions.ipynb](appendix-A/03_main-chapter-code/exercise-solutions.ipynb) | [./appendix-A](./appendix-A) |
+| Appendix A: Introduction to PyTorch | - [code-part1.ipynb](appendix-A/03_main-chapter-code/code-part1.ipynb)<br/>- [code-part2.ipynb](appendix-A/03_main-chapter-code/code-part2.ipynb)<br/>- [DDP-script.py](appendix-A/03_main-chapter-code/DDP-script.py)<br/>- [exercise-solutions.ipynb](appendix-A/03_main-chapter-code/exercise-solutions.ipynb) | [./appendix-A](./appendix-A) |
| Appendix B: References and Further Reading | No code | |
| Appendix C: Exercises | No code | |


<br>

> [!TIP]
-> Please see [this](appendix-A/01_optional-python-setup-preferences) and [this](appendix-A/02_installing-python-libraries) folder if you need more guidance on installing Python and Python packages.)
+> Please see [this](appendix-A/01_optional-python-setup-preferences) and [this](appendix-A/02_installing-python-libraries) folder if you need more guidance on installing Python and Python packages.


4 changes: 2 additions & 2 deletions ch03/01_main-chapter-code/ch03.ipynb
@@ -1637,7 +1637,7 @@
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
"\n",
" self.d_out = d_out\n",
" self.num_heads = num_heads\n",
@@ -1865,7 +1865,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.12"
}
},
"nbformat": 4,
4 changes: 2 additions & 2 deletions ch03/01_main-chapter-code/multihead-attention.ipynb
@@ -243,7 +243,7 @@
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
"\n",
" self.d_out = d_out\n",
" self.num_heads = num_heads\n",
@@ -342,7 +342,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.12"
}
},
"nbformat": 4,
3 changes: 3 additions & 0 deletions ch03/02_bonus_efficient-multihead-attention/README.md
@@ -0,0 +1,3 @@
# More Efficient Multi-Head Attention Implementations

- [mha-implementations.ipynb](mha-implementations.ipynb) contains and compares different implementations of multi-head attention
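
As a hedged illustration of the kind of variant such a comparison typically includes (this is not the notebook's actual code), here is a minimal sketch that fuses the query/key/value projections and delegates the masked attention computation to PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`. The class name `MHAScaledDotProduct`, the fused-projection layout, and the PyTorch >= 2.0 requirement are assumptions made for this example:

```python
import torch
import torch.nn as nn


class MHAScaledDotProduct(nn.Module):
    # Hypothetical variant for illustration; assumes PyTorch >= 2.0.
    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.dropout = dropout
        # block_size is kept only for signature compatibility with the chapter's
        # MultiHeadAttention; the causal mask is handled via is_causal=True below.
        self.block_size = block_size

        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)  # fused q/k/v projection
        self.out_proj = nn.Linear(d_out, d_out)

    def forward(self, x):
        b, num_tokens, _ = x.shape

        # (b, num_tokens, 3 * d_out) -> (b, num_tokens, 3, num_heads, head_dim)
        qkv = self.qkv(x).view(b, num_tokens, 3, self.num_heads, self.head_dim)
        # -> (3, b, num_heads, num_tokens, head_dim), then unpack along the first dim
        queries, keys, values = qkv.permute(2, 0, 3, 1, 4)

        # Fused, causal scaled dot-product attention; dispatches to efficient
        # kernels (e.g., FlashAttention) where available
        context_vec = nn.functional.scaled_dot_product_attention(
            queries, keys, values,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, d_out)
        context_vec = context_vec.transpose(1, 2).contiguous().view(b, num_tokens, -1)
        return self.out_proj(context_vec)
```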
58 changes: 58 additions & 0 deletions ch03/02_bonus_efficient-multihead-attention/ch03.py
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))

def forward(self, x):
b, num_tokens, d_in = x.shape

keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)

# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)

# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Unsqueeze the mask to match dimensions
mask_unsqueezed = mask_bool.unsqueeze(0)
# Use the unsqueezed mask to fill attention scores
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)

# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)

# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection

return context_vec
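
For reference, a minimal usage sketch of the `MultiHeadAttention` class added above (assuming the class has been defined or imported; the tensor sizes are illustrative, not taken from the chapter):

```python
import torch

torch.manual_seed(123)

batch_size, block_size, d_in, d_out, num_heads = 2, 6, 16, 32, 4
x = torch.randn(batch_size, block_size, d_in)  # dummy input batch

mha = MultiHeadAttention(d_in, d_out, block_size, dropout=0.0, num_heads=num_heads)
context_vec = mha(x)
print(context_vec.shape)  # torch.Size([2, 6, 32])
```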