
Transformer building blocks tutorial #3075

Merged: 25 commits, Nov 14, 2024
Changes from 1 commit
pyspelling + linkcheck + flex
mikaylagawarecki committed Oct 30, 2024
commit c11927a3a064aa30f80b0ccff465aeb2971b225c
14 changes: 14 additions & 0 deletions en-wordlist.txt
@@ -128,6 +128,7 @@ Kihyuk
Kiuk
Kubernetes
Kuei
KV-Caching
LRSchedulers
LSTM
LSTMs
@@ -276,6 +277,7 @@ Xcode
Xeon
Yidong
YouTube
Zipf
accelerometer
accuracies
activations
@@ -305,6 +307,7 @@ bbAP
benchmarked
benchmarking
bitwise
bool
boolean
breakpoint
broadcasted
@@ -333,6 +336,7 @@ csv
cuDNN
cuda
customizable
customizations
datafile
dataflow
dataframe
@@ -377,6 +381,7 @@ fbgemm
feedforward
finetune
finetuning
FlexAttention
fp
frontend
functionalized
@@ -431,6 +436,7 @@ mAP
macos
manualSeed
matmul
matmuls
matplotlib
memcpy
memset
@@ -446,6 +452,7 @@ modularized
mpp
mucosa
multihead
MultiheadAttention
multimodal
multimodality
multinode
@@ -456,7 +463,10 @@ multithreading
namespace
natively
ndarrays
nheads
nightlies
NJT
NJTs
num
numericalize
numpy
@@ -532,6 +542,7 @@ runtime
runtime
runtimes
scalable
SDPA
sharded
softmax
sparsified
@@ -591,12 +602,14 @@ tradeoff
tradeoffs
triton
uint
UX
umap
uncomment
uncommented
underflowing
unfused
unimodal
unigram
unnormalized
unoptimized
unparametrized
@@ -618,6 +631,7 @@ warmstarted
warmstarting
warmup
webp
wikitext
wsi
wsis
Meta's
113 changes: 57 additions & 56 deletions intermediate_source/transformer_building_blocks.py
@@ -40,7 +40,7 @@
Please head there instead!

If you are only interested in performant attention score modifications, please
head to the `FlexAttention blog <https://flexattention.com/blog/>`_ that
head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.

If you are wondering about what building blocks the ``torch`` library provides
@@ -63,7 +63,7 @@
* `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_

``scaled_dot_product_attention`` is a primitive for
$\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V$ that dispatches into either fused
:math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
implementations of the operator or a fallback implementation. It works out of
the box in eager mode (i.e. the default mode of using PyTorch where operations
are executed on the fly as they are encountered) and also integrates seamlessly
@@ -118,7 +118,7 @@
# The improvements are threefold:
#
# * User Experience
# Recall that `nn.MultiheadAttention` requires ``query```, ``key`` and
# Recall that ``nn.MultiheadAttention`` requires ``query``, ``key`` and
# ``value`` to be dense ``torch.Tensor``s. It also provides a
# ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
# that arise due to different sequence lengths within a batch. Since there is
@@ -202,10 +202,10 @@ def forward(self,
4. Apply output projection

Args:
query (torch.Tensor): query of shape (N, L_q, E_qk)
key (torch.Tensor): key of shape (N, L_kv, E_qk)
value (torch.Tensor): value of shape (N, L_kv, E_v)
attn_mask (torch.Tensor, optional): attention mask of shape (N, L_q, L_kv) to pass to sdpa. Default: None
query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)
key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)
value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)
attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None
is_causal (bool, optional): Whether to apply causal mask. Default: False

Returns:
@@ -251,11 +251,10 @@ def forward(self,

return attn_output
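###############################################################################
# A brief usage sketch for a layer with the signature documented above: nested
# jagged tensors carry per-sequence lengths, so no padding (and hence no
# ``key_padding_mask``) is needed. The dimensions below are illustrative, and
# ``mha_layer`` stands in for an instance of the layer defined in this section.

sketch_lengths = [5, 8, 3, 7]  # varying sequence lengths within a batch of 4
sketch_query = torch.nested.nested_tensor(
    [torch.randn(seq_len, 64) for seq_len in sketch_lengths], layout=torch.jagged
)  # logical shape (4, L_q, 64) with jagged L_q
# Self-attention call (illustrative): out = mha_layer(sketch_query, sketch_query, sketch_query, is_causal=True)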

# TODO: Check whether there is a way to collapse this section by default
# sphinx.collapse?

###############################################################################
# Utilities
# =========
# ========================
# In this section, we include a utility to generate semi-realistic data using
# Zipf distribution for sentence lengths. This is used to generate the nested
# query, key and value tensors. We also include a benchmark utility.
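#
# As a rough illustration of the idea (the tutorial's own ``gen_batch`` utility,
# which is used later in this file, may differ in its details), Zipf-sampled
# lengths can be packed into a nested jagged batch like this:

import numpy as np

def zipf_nested_batch(batch_size, embed_dim, alpha=1.2, max_len=512):
    # Sample sentence lengths from a Zipf distribution and clamp them to a sane range.
    lengths = np.clip(np.random.zipf(alpha, size=batch_size), 2, max_len)
    # Pack one (length, embed_dim) tensor per sentence into a nested jagged tensor.
    return torch.nested.nested_tensor(
        [torch.randn(int(length), embed_dim) for length in lengths], layout=torch.jagged
    )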
@@ -343,7 +342,7 @@ def benchmark(func, *args, **kwargs):
torch.manual_seed(6)
vanilla_mha_layer = nn.MultiheadAttention(E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device='cuda')

# nn.MultiheadAttention uses a non conventional init for layers, so do this for exact parity :(
# ``nn.MultiheadAttention`` uses a non-conventional initialization for layers, so do this for exact parity :(
mha_layer.out_proj.weight = nn.Parameter(vanilla_mha_layer.out_proj.weight.clone().detach())
mha_layer.packed_proj.weight = nn.Parameter(vanilla_mha_layer.in_proj_weight.clone().detach())
mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
@@ -357,8 +356,8 @@ def benchmark(func, *args, **kwargs):
nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, query, query, is_causal=True)
padded_nested_result = nested_result.to_padded_tensor(0.0)

# For the vanilla nn.MHA, we need to construct the key_padding_mask
# Further, nn.MultiheadAttention forces one to materialize the attn_mask even if using is_causal
# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``
# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``
src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]
attn_mask = torch.empty((N, S, S), device=device).fill_(float('-inf'))
for i, s in enumerate(sentence_lengths):
@@ -431,14 +430,14 @@ def benchmark(func, *args, **kwargs):
# classification of modifications to the transformer architecture, recall that we
# classified the modifications into layer type, layer ordering, and modifications
# to the attention score. We trust that changing layer type and layer ordering
# (e.g. swapping LayerNorm for RMSNorm) is fairly straightforward.
# (e.g. swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
#
# In this section, we will discuss various functionalities using the
# aforementioned building blocks. In particular,
#
# * Packed Projection
# * Cross Attention
# * Fully masked rows no longer cause NaNs
# * Fully masked rows no longer cause ``NaN``s
# * [TODO] Modifying attention score: Relative Positional Embedding with NJT
# * [TODO] KV-Caching with NJT
# * [TODO] Grouped Query Attention with NJT
@@ -448,13 +447,13 @@ def benchmark(func, *args, **kwargs):
# -----------------
#
# Packed projection is a technique that makes use of the fact that when the input
# for projection (matmul) are the same (self-attention), we can pack the projection
# for projection (matrix multiplications) are the same (self-attention), we can pack the projection
# weights and biases into single tensors. It is especially useful when the individual
# projections (matmuls) are memory bound rather than compute bound. There are
# projections are memory bound rather than compute bound. There are
# two examples that we will demonstrate here:
#
# * Input projection for MultiheadAttention
# * SwiGLU activation in FFN of Transformer Layer
# * SwiGLU activation in feed-forward network of Transformer Layer
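#
# Before looking at these two examples, here is a minimal standalone sketch of
# the packing idea itself (illustrative dimensions; the tutorial's own
# implementations follow below):

x_demo = torch.randn(2, 10, 64)  # (batch, seq, E)

# Three separate projections read the input three times ...
q_proj_demo, k_proj_demo, v_proj_demo = (nn.Linear(64, 64, bias=False) for _ in range(3))
q_demo, k_demo, v_demo = q_proj_demo(x_demo), k_proj_demo(x_demo), v_proj_demo(x_demo)

# ... while a single packed projection reads it once and splits the result.
packed_proj_demo = nn.Linear(64, 3 * 64, bias=False)
q_packed, k_packed, v_packed = packed_proj_demo(x_demo).chunk(3, dim=-1)

###############################################################################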
#
# Input projection for MultiheadAttention
# ----------------------------------------
@@ -505,7 +504,7 @@ def forward(self, query):
# SwiGLU feed forward network of Transformer Layer
# ------------------------------------------------
# SwiGLU is a non-linear activation function that is increasingly popular in the feed-forward
# network of the transformer layer (e.g. Llama). A FFN with SwiGLU activation is defined as
# network of the transformer layer (e.g. Llama). A feed-forward network with SwiGLU activation is defined as

class SwiGLUFFN(nn.Module):
def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None):
@@ -601,45 +600,47 @@ def forward(self, x):


################################################################################
# [PENDING] KV-Caching with NJT
# ------------------------------
# During decoding in inference, the query consists of the current token. However,
# the key and value consist of all the previous keys and values in addition to
# the current token.
#
# When we do batched inference, each batch item will be at a different stage of
# decoding, so we expect the keys and values to have different sequence lengths.
# The query is a dense tensor of shape ``[B, 1, E_qk]`` and the keys and values
# will be of shapes ``[B, *, E_qk]`` and ``[B, *, E_v]`` where ``B`` represents
# batch size, ``*`` represents varying sequence lengths and ``E_qk`` and ``E_v``
# are embedding dimensions for query/key and value respectively.

# Directly related to the above point is the idea of KV-Caching. This is a technique
# used in inference to reduce the latency of decoding. The idea is to cache the
# key and value tensors computed for previous tokens and reuse them when attending
# from the current token. This is especially useful when the sequence length is long.

# FIXME: Pending https://github.com/pytorch/pytorch/pull/135722
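###############################################################################
# While the NJT-based version is pending, the core idea can be sketched with
# dense tensors and ``scaled_dot_product_attention`` (the single-head setup and
# shapes below are purely illustrative):

import torch.nn.functional as F

kv_B, kv_E = 2, 64
k_cache = torch.zeros(kv_B, 0, kv_E)  # caches grow along the sequence dimension
v_cache = torch.zeros(kv_B, 0, kv_E)

def decode_step(q_t, k_t, v_t):
    """q_t, k_t, v_t: (B, 1, E) projections for the current token."""
    global k_cache, v_cache
    k_cache = torch.cat([k_cache, k_t], dim=1)  # (B, T, E) after T decode steps
    v_cache = torch.cat([v_cache, v_t], dim=1)
    # The single query token attends over all cached keys and values.
    return F.scaled_dot_product_attention(q_t, k_cache, v_cache)

step_out = decode_step(
    torch.randn(kv_B, 1, kv_E), torch.randn(kv_B, 1, kv_E), torch.randn(kv_B, 1, kv_E)
)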


################################################################################
# [PENDING] Relative Positional Embedding with NJT (FlexAttention + NJT)
# ALiBi with NJT (FlexAttention + NJT)
# ---------------------------------------------------------------------
#
# FIXME: Pending https://github.com/pytorch/pytorch/pull/136792
# NJT also composes with ``FlexAttention``. This is a generalization of
# ``scaled_dot_product_attention`` that allows arbitrary, user-defined modifications
# to the attention score. The example below takes the ``alibi_mod`` score modification
# from `attention gym <https://github.com/pytorch-labs/attention-gym>`_ and uses it
# with nested input tensors.

from torch.nn.attention.flex_attention import flex_attention

def generate_alibi_bias(H: int):
"""Returns an alibi bias score_mod given the number of heads H
Args:
H: number of heads
Returns:
alibi_bias: alibi bias score_mod
"""
def alibi_mod(score, b, h, q_idx, kv_idx):
scale = torch.exp2(-((h + 1) * 8.0 / H))
bias = (q_idx - kv_idx) * scale
return score + bias
return alibi_mod

query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
n_heads, D = 8, E_q // 8
alibi_score_mod = generate_alibi_bias(n_heads)
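# Reshape the (B, S, E) nested inputs to (B, n_heads, S, head_dim), the layout
# that ``flex_attention`` expects.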
query = (
query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = (
value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
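###############################################################################
# Note that, unlike a precomputed bias or mask tensor, the ``score_mod``
# callable is defined in terms of the batch, head, query and key-value indices,
# so a distance-based bias such as ALiBi can be expressed programmatically
# rather than materialized up front as a full ``(B, H, L_q, L_kv)`` tensor.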

################################################################################
# [PENDING] Grouped Query Attention with NJT
# ------------------------------------------
#
# Grouped Query Attention refers to using a number of key/value heads that is
# less than the number of query heads. Compared to MultiheadAttention, this
# decreases the size of the kv-cache during inference.
#
# We can implement this using nested tensors as follows
# And more
# --------
#
# FIXME: Pending FlexAttention/testing for NJT with grouped query attention
# We intend to update this tutorial to demonstrate more examples of how to use
# the various performant building blocks, such as KV-Caching and Grouped Query
# Attention.
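#
# In the meantime, here is a minimal dense-tensor sketch of the grouped-query
# idea (illustrative shapes, not the pending NJT version): fewer key/value heads
# than query heads, with each key/value head repeated to serve a group of query
# heads before calling ``scaled_dot_product_attention``.

import torch.nn.functional as F

gqa_B, gqa_T, n_q_heads, n_kv_heads, gqa_head_dim = 2, 16, 8, 2, 32
gqa_q = torch.randn(gqa_B, n_q_heads, gqa_T, gqa_head_dim)
gqa_k = torch.randn(gqa_B, n_kv_heads, gqa_T, gqa_head_dim)  # smaller KV cache: 2 heads instead of 8
gqa_v = torch.randn(gqa_B, n_kv_heads, gqa_T, gqa_head_dim)

# Repeat each key/value head so that every group of query heads has a match.
gqa_k = gqa_k.repeat_interleave(n_q_heads // n_kv_heads, dim=1)
gqa_v = gqa_v.repeat_interleave(n_q_heads // n_kv_heads, dim=1)
gqa_out = F.scaled_dot_product_attention(gqa_q, gqa_k, gqa_v, is_causal=True)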


################################################################################
@@ -649,7 +650,7 @@
# There are several projects that use these performant building blocks to
# implement transformer architectures. Some examples include:
#
# * `gpt_fast <https://github.com/pytorch-labs/gpt-fast>`_
# * `sam_fast <https://github.com/pytorch-labs/sam-fast>`_
# * `lucidrains implementation of ViT with nested tensors <https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/nested_tensor.py>`_
# * `gpt-fast <https://github.com/pytorch-labs/gpt-fast>`_
# * `segment-anything-fast <https://github.com/pytorch-labs/segment-anything-fast>`_
# * `lucidrains implementation of NaViT with nested tensors <https://github.com/lucidrains/vit-pytorch/blob/73199ab486e0fad9eced2e3350a11681db08b61b/vit_pytorch/na_vit_nested_tensor.py>`_
# * `torchtune's implementation of VisionTransformer <https://github.com/pytorch/torchtune/blob/a8a64ec6a99a6ea2be4fdaf0cd5797b03a2567cf/torchtune/modules/vision_transformer.py#L16>`_