[docs] Adding a blocksparse example + numbers #20

Merged (3 commits) on Oct 21, 2021
98 changes: 97 additions & 1 deletion HOWTO.md
@@ -11,6 +11,7 @@ Let's present here a couple of code snippets on how to solve a couple of questio
- [Create complex sparsity patterns with xFormers](#create-complex-sparsity-patterns-with-xformers)
- [Replace all attentions from an existing ViT model with a sparse equivalent ?](#replace-all-attentions-from-an-existing-vit-model-with-a-sparse-equivalent-)
- [Some more examples](#some-more-examples)
- [BlockSparseAttention](#blocksparseattention)
- [Extend the xFormers parts zoo locally](#extend-the-xformers-parts-zoo-locally)
- [Contributing an extension to the xFormers repository](#contributing-an-extension-to-the-xformers-repository)
- [Per component cherry picking](#per-component-cherry-picking)
@@ -187,6 +188,100 @@ Note that in practice exchanging all the attentions with a sparse alternative ma
- on creating [complex sparsity patterns](docs/source/2d_attention_patterns.ipynb)
- on a [SwinTransformers](docs/source/swin_transformers.ipynb)

## BlockSparseAttention

BlockSparse attention uses [Triton](https://github.com/openai/triton) to limit the attention computation to a set of tiles, which you define at construction time.
A simple example is causal attention: just compute the lower-triangular tiles! The tile size can be changed, with a minimum of 16 coefficients per dimension.

If the per-coefficient pattern you have in mind is not a perfect match for a block pattern, that is usually fine:
BlockSparse is fast enough that computing a slightly coarser block layout and dropping the extra coefficients afterwards with a fine-grained mask is still likely to beat a dense computation.
We provide a small helper (essentially a max-pooling over the mask) to convert between a per-coefficient binary mask and the block layout needed to build a block-sparse attention.
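
A minimal sketch of how that helper can be used (assuming `pattern_to_layout` is importable from `xformers.components.attention.attention_patterns`, as in the file touched by this PR; the exact location may vary across versions):

```python
import torch

from xformers.components.attention.attention_patterns import pattern_to_layout

SEQ = 2048
BLOCK_SIZE = 32

# Any per-coefficient binary mask, here a causal one for the sake of the example
fine_mask = torch.tril(torch.ones(SEQ, SEQ)).to(torch.bool)

# Max-pool the fine-grained mask down to a block layout: a block is kept
# as soon as it contains at least one active coefficient
layout = pattern_to_layout(fine_mask, BLOCK_SIZE)
assert layout.shape == (SEQ // BLOCK_SIZE, SEQ // BLOCK_SIZE)

# `layout` can be passed to BlockSparseAttention at construction time, while the
# original fine-grained mask can still be passed at runtime to drop the extra coefficients
```
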
Let's look at an example:

```python
import torch

from xformers.components import MultiHeadDispatch
from xformers.components.attention import BlockSparseAttention

BATCH = 2
HEADS = 8
SEQ = 2048
EMB = 1024
BLOCK_SIZE = 32
DROPOUT = 0.1
dtype = torch.float16

# Let's try out a causal mask, but really it could be anything "block sparse enough"
causal_mask = torch.tril(torch.ones((SEQ, SEQ), device=torch.device("cuda"), dtype=dtype))

blocks = SEQ // BLOCK_SIZE
causal_layout = torch.tril(torch.ones([HEADS, blocks, blocks]))

# Let's build our blocksparse attention. Please note that the layout can be
# [SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE] or [HEADS, SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE]
# so that _you can pass a different layout per head_
attention = BlockSparseAttention(layout=causal_layout, block_size=BLOCK_SIZE, dropout=DROPOUT, num_heads=HEADS)

# For convenience, let's build our multi-head attention now
# "multi_head" will be responsible for the forward
multi_head = (
MultiHeadDispatch(
seq_len=SEQ,
dim_model=EMB,
residual_dropout=DROPOUT,
num_heads=HEADS,
attention=attention,
)
.cuda()
.half()
)

# Now run a forward pass with some random data
# Note that passing a per-coefficient mask makes it possible to mask out the extra coefficients
# which were added by the blockification
query = torch.randn((BATCH, SEQ, EMB), requires_grad=True, device=torch.device("cuda"), dtype=dtype)

# Self-attention in this particular example, but there is no real restriction here
att_val = multi_head(query=query, key=query, value=query, att_mask=causal_mask)

#########################################
# Bonus: compare the memory use vs dense:
def mem_use(fn, kwargs, title):
# bookkeeping
import time

start = time.time()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# actually run the function
fn(**kwargs)
torch.cuda.synchronize()
stop = time.time()

# now report
max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
print(f"{title} - Peak memory use: {max_memory}MB - {round((stop-start)*1e6)/1e3}ms")


pytorch_multihead = torch.nn.MultiheadAttention(
EMB, HEADS, batch_first=True, device=torch.device("cuda"), dtype=torch.float16
)

mem_use(multi_head, {"query": query, "key": query, "value": query, "att_mask": causal_mask}, "Blocksparse")
mem_use(pytorch_multihead, {"query": query, "key": query, "value": query, "attn_mask": causal_mask}, "PyTorch")
```

On a V100, with PyTorch 1.9, Triton 1.1 and xFormers 0.0.1, this reports something along the lines of:

```bash
Blocksparse - Peak memory use: 151MB - 6.619ms
PyTorch - Peak memory use: 393MB - 6.837ms
```

Note that the pattern here is not that sparse (half of the matrix is masked out); the sparser the pattern, the more the comparison will lean towards BlockSparseAttention.
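
To push the comparison further, the same construction works for much sparser layouts. As a rough sketch (the window size below is hypothetical, and the layout is built exactly like the causal one above):

```python
import torch

from xformers.components.attention import BlockSparseAttention

HEADS, SEQ, BLOCK_SIZE, DROPOUT = 8, 2048, 32, 0.1
blocks = SEQ // BLOCK_SIZE

# Causal sliding-window layout: each block attends to itself and to the
# WINDOW - 1 preceding blocks, which is far sparser than the full lower triangle
WINDOW = 4
local_layout = torch.tril(torch.ones([HEADS, blocks, blocks])) - torch.tril(
    torch.ones([HEADS, blocks, blocks]), diagonal=-WINDOW
)

# Same call as in the causal example (cast the layout with .long() if your
# version expects an integer layout)
local_attention = BlockSparseAttention(
    layout=local_layout, block_size=BLOCK_SIZE, dropout=DROPOUT, num_heads=HEADS
)
```
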

## Extend the xFormers parts zoo locally

This can be done in a private fork of xFormers, if this is a work in progress or not something that you would like to share at this point, or directly in xFormers in order to submit a [pull request](https://github.com/fairinternal/xformers/pulls).
@@ -306,7 +401,8 @@ Any of the other attention mechanisms can be instantiated and called in a simila
"name": attention_name, # you can easily make this dependent on a file, sweep,..
"dropout": DROPOUT,
"seq_len": SEQ,
"attention_query_mask": torch.rand((SEQ, 1)) < 0.3, # some dummy mask
# add any extra parameter that this specific attention would require
# this can be a superset of all the parameters if you're sweeping; unused parameters will be ignored
}

attention = build_attention(my_config)
98 changes: 98 additions & 0 deletions docs/source/tutorials/blocksparse.rst
@@ -0,0 +1,98 @@

Using BlockSparseAttention
==========================

BlockSparse attention uses Triton_ to limit the attention computation to a set of tiles, which you define at construction time.
A simple example is causal attention: just compute the lower-triangular tiles! The tile size can be changed, with a minimum of 16 coefficients per dimension.

.. _Triton: https://github.com/openai/triton

If the per-coefficient pattern you have in mind is not a perfect match for a block pattern, that is usually fine:
BlockSparse is fast enough that computing a slightly coarser block layout and dropping the extra coefficients afterwards with a fine-grained mask is still likely to beat a dense computation.
We provide a small helper (essentially a max-pooling over the mask) to convert between a per-coefficient binary mask and the block layout needed to build a block-sparse attention.
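
A minimal sketch of how that helper can be used (assuming ``pattern_to_layout`` is importable from ``xformers.components.attention.attention_patterns``; the exact location may vary across versions):

.. code-block:: python

    import torch

    from xformers.components.attention.attention_patterns import pattern_to_layout

    SEQ = 2048
    BLOCK_SIZE = 32

    # Any per-coefficient binary mask, here a causal one for the sake of the example
    fine_mask = torch.tril(torch.ones(SEQ, SEQ)).to(torch.bool)

    # Max-pool the fine-grained mask down to a block layout: a block is kept
    # as soon as it contains at least one active coefficient
    layout = pattern_to_layout(fine_mask, BLOCK_SIZE)
    assert layout.shape == (SEQ // BLOCK_SIZE, SEQ // BLOCK_SIZE)
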
Let's look at an example:

.. code-block:: python

import torch

from xformers.components import MultiHeadDispatch
from xformers.components.attention import BlockSparseAttention

BATCH = 2
HEADS = 8
SEQ = 2048
EMB = 1024
BLOCK_SIZE = 32
DROPOUT = 0.1
dtype = torch.float16

# Let's try out a causal mask, but really it could be anything "block sparse enough"
causal_mask = torch.tril(torch.ones((SEQ, SEQ), device=torch.device("cuda"), dtype=dtype))

blocks = SEQ // BLOCK_SIZE
causal_layout = torch.tril(torch.ones([HEADS, blocks, blocks]))

# Let's build our blocksparse attention. Please note that the layout can be
# [SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE] or [HEADS, SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE]
# so that _you can pass a different layout per head_
attention = BlockSparseAttention(layout=causal_layout, block_size=BLOCK_SIZE, dropout=DROPOUT, num_heads=HEADS)

# For convenience, let's build our multi-head attention now
# "multi_head" will be responsible for the forward
multi_head = (
MultiHeadDispatch(
seq_len=SEQ,
dim_model=EMB,
residual_dropout=DROPOUT,
num_heads=HEADS,
attention=attention,
)
.cuda()
.half()
)

# Now run a forward pass with some random data
# Note that passing a per-coefficient mask makes it possible to mask out the extra coefficients
# which were added by the blockification
query = torch.randn((BATCH, SEQ, EMB), requires_grad=True, device=torch.device("cuda"), dtype=dtype)

# Self-attention in this particular example, but there is no real restriction here
att_val = multi_head(query=query, key=query, value=query, att_mask=causal_mask)


#########################################
# Bonus: compare the memory use vs dense:
def mem_use(fn, kwargs, title):
# bookkeeping
import time

start = time.time()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# actually run the function
fn(**kwargs)
torch.cuda.synchronize()
stop = time.time()

# now report
max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
print(f"{title} - Peak memory use: {max_memory}MB - {round((stop-start)*1e6)/1e3}ms")


pytorch_multihead = torch.nn.MultiheadAttention(
EMB, HEADS, batch_first=True, device=torch.device("cuda"), dtype=torch.float16
)

mem_use(multi_head, {"query": query, "key": query, "value": query, "att_mask": causal_mask}, "Blocksparse")
mem_use(pytorch_multihead, {"query": query, "key": query, "value": query, "attn_mask": causal_mask}, "PyTorch")

On a V100, with PyTorch 1.9, Triton 1.1 and xFormers 0.0.1, this reports something along the lines of:

.. code-block:: bash

Blocksparse - Peak memory use: 151MB - 6.619ms
PyTorch - Peak memory use: 393MB - 6.837ms

Note that the pattern here is not that sparse (half of the matrix is masked out); the sparser the pattern, the more the comparison will lean towards BlockSparseAttention.
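
To push the comparison further, the same construction works for much sparser layouts. As a rough sketch (the window size below is hypothetical, and the layout is built exactly like the causal one above):

.. code-block:: python

    import torch

    from xformers.components.attention import BlockSparseAttention

    HEADS, SEQ, BLOCK_SIZE, DROPOUT = 8, 2048, 32, 0.1
    blocks = SEQ // BLOCK_SIZE

    # Causal sliding-window layout: each block attends to itself and to the
    # WINDOW - 1 preceding blocks, which is far sparser than the full lower triangle
    WINDOW = 4
    local_layout = torch.tril(torch.ones([HEADS, blocks, blocks])) - torch.tril(
        torch.ones([HEADS, blocks, blocks]), diagonal=-WINDOW
    )

    # Same call as in the causal example (cast the layout with .long() if your
    # version expects an integer layout)
    local_attention = BlockSparseAttention(
        layout=local_layout, block_size=BLOCK_SIZE, dropout=DROPOUT, num_heads=HEADS
    )
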
1 change: 1 addition & 0 deletions docs/source/tutorials/index.rst
@@ -5,6 +5,7 @@ Tutorials
:maxdepth: 1

sparse_vit
blocksparse
extend_attentions
use_attention
pytorch_encoder
6 changes: 3 additions & 3 deletions tests/test_attention_patterns.py
@@ -175,20 +175,20 @@ def test_pattern_to_layout():
# All ones
mask1 = torch.ones((SIZE, SIZE), dtype=torch.bool)
layout1 = AP.pattern_to_layout(mask1, BLOCK)
ref1 = torch.ones((LAYOUT_SIZE, LAYOUT_SIZE), dtype=torch.int)
ref1 = torch.ones((LAYOUT_SIZE, LAYOUT_SIZE), dtype=torch.long)
assert torch.allclose(layout1, ref1)

# Diagonal -> expect block diagonal
mask2 = torch.eye(SIZE, dtype=torch.bool)
layout2 = AP.pattern_to_layout(mask2, BLOCK)
ref2 = torch.eye(LAYOUT_SIZE, dtype=torch.int)
ref2 = torch.eye(LAYOUT_SIZE, dtype=torch.long)
assert torch.allclose(layout2, ref2)

# Lower triangular, without the diagonal
# note that the layout needs to include the diagonal, otherwise the coefficients close to it would not be computed
mask3 = torch.tril(torch.ones((SIZE, SIZE)), diagonal=-1).to(torch.bool)
layout3 = AP.pattern_to_layout(mask3, BLOCK)
ref3 = torch.tril(torch.ones((LAYOUT_SIZE, LAYOUT_SIZE)), diagonal=0).to(torch.int)
ref3 = torch.tril(torch.ones((LAYOUT_SIZE, LAYOUT_SIZE)), diagonal=0).to(torch.long)
assert torch.allclose(layout3, ref3)

# Handle heads properly
2 changes: 1 addition & 1 deletion xformers/components/attention/attention_patterns.py
@@ -203,7 +203,7 @@ def pattern_to_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor:
layout = torch.nn.functional.max_pool2d(
mask.to(torch.float), kernel_size=block_size, stride=block_size
)
layout = layout.to(torch.int)
layout = layout.to(torch.long)

if _should_squeeze:
layout.squeeze_(0)