forked from rasbt/LLMs-from-scratch
Commit
Showing 10 changed files with 431 additions and 10 deletions.
@@ -0,0 +1,3 @@
# More Efficient Multi-Head Attention Implementations

- [mha-implementations.ipynb](mha-implementations.ipynb) contains and compares different implementations of multi-head attention
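The notebook linked above is where the actual comparison lives. Purely as an illustration of what one alternative implementation can look like, the sketch below wraps PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0+) behind the same constructor signature as the `MultiHeadAttention` class added in this commit. The class name `MHAScaledDotProduct` is made up for this sketch and is not taken from the notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MHAScaledDotProduct(nn.Module):
    # Same constructor signature as the manual implementation below;
    # block_size is kept only for interface parity, since is_causal=True
    # builds the causal mask internally.
    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.dropout = dropout
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)

    def forward(self, x):
        b, num_tokens, _ = x.shape

        # Project and split into heads: (b, num_heads, num_tokens, head_dim)
        queries = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = self.W_key(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = self.W_value(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # Fused scaled-dot-product attention; replaces the explicit mask,
        # scaling, softmax, and dropout of the manual version
        context_vec = F.scaled_dot_product_attention(
            queries, keys, values,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )

        # Merge heads back to (b, num_tokens, d_out) and project
        context_vec = context_vec.transpose(1, 2).contiguous().view(b, num_tokens, -1)
        return self.out_proj(context_vec)
```

Depending on the PyTorch build and hardware, this call can dispatch to a FlashAttention-style kernel, which is typically faster and more memory-efficient than materializing the full attention-score matrix.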
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        # Unsqueeze the mask to match dimensions
        mask_unsqueezed = mask_bool.unsqueeze(0)
        # Use the unsqueezed mask to fill attention scores
        attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
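A minimal smoke test for the class above, assuming illustrative dimensions that are not taken from this commit: a batch of 2 sequences of 6 tokens, a 768-dimensional embedding, and 12 heads.

```python
torch.manual_seed(123)

batch_size, context_len, d_in = 2, 6, 768
x = torch.randn(batch_size, context_len, d_in)

mha = MultiHeadAttention(
    d_in=d_in, d_out=768, block_size=context_len,
    dropout=0.0, num_heads=12, qkv_bias=False,
)

context_vecs = mha(x)
print(context_vecs.shape)  # torch.Size([2, 6, 768])
```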