Commit f5f40d1

Adding the same tuto to the HTML docs

blefaudeux committed Oct 21, 2021
1 parent 7a69bd1 commit f5f40d1
Showing 3 changed files with 99 additions and 2 deletions.
2 changes: 0 additions & 2 deletions HOWTO.md
@@ -188,7 +188,6 @@ Note that in practice exchanging all the attentions with a sparse alternative ma
- on creating [complex sparsity patterns](docs/source/2d_attention_patterns.ipynb)
- on a [SwinTransformers](docs/source/swin_transformers.ipynb)


## BlockSparseAttention

BlockSparse attention uses [Triton](https://github.com/openai/triton) to limit the attention computations to some tiles, which you define at construction time.
@@ -284,7 +283,6 @@ On a V100, with PyTorch 1.9, Triton 1.1 and xFormers 0.0.1 this reports somethin

Note that the pattern here is not that sparse (half of the matrix is empty); the sparser the pattern gets, the more the comparison will favor BlockSparseAttention.


## Extend the xFormers parts zoo locally

This can be done in a private fork of xFormers, if the work is in progress or not something that you would like to share yet, or directly in xFormers in order to submit a [pull request](https://github.com/fairinternal/xformers/pulls).
98 changes: 98 additions & 0 deletions docs/source/tutorials/blocksparse.rst
@@ -0,0 +1,98 @@

Using BlockSparseAttention
==========================

BlockSparse attention uses Triton_ to limit the attention computations to some tiles, which you define at construction time.
A simple example is causal attention: just compute the lower-triangular tiles! The tile size can be changed, the minimum being 16 coefficients along one dimension; a short sketch of what a tile layout stands for is shown below.

.. _Triton: https://github.com/openai/triton
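
To make the tile/layout correspondence concrete, here is a minimal sketch (``SEQ``, ``BLOCK_SIZE`` and the Kronecker expansion are illustrative, not an xFormers API):

.. code-block:: python

    import torch

    SEQ = 2048
    BLOCK_SIZE = 32
    blocks = SEQ // BLOCK_SIZE

    # One entry per tile, not per coefficient: this is the "layout"
    causal_layout = torch.tril(torch.ones(blocks, blocks, dtype=torch.long))

    # The dense pattern it stands for is the Kronecker expansion of the layout,
    # each entry becoming a BLOCK_SIZE x BLOCK_SIZE tile
    dense_equivalent = torch.kron(causal_layout, torch.ones(BLOCK_SIZE, BLOCK_SIZE, dtype=torch.long))
    assert dense_equivalent.shape == (SEQ, SEQ)

    # Note that this block-causal pattern is slightly denser than a strict
    # per-coefficient causal mask: the diagonal tiles are kept whole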

If you already have a per-coefficient pattern in mind and it is not a perfect match with a block pattern, this is probably fine:
BlockSparse is fast enough that dropping some of the computations after the fact with a fine-grained mask is still likely better than a dense computation.
We provide a small helper (it boils down to max pooling) to convert between a per-coefficient binary mask and the layout that you will need to build a block sparse attention; a minimal sketch of that conversion is shown below.
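
A hand-rolled version of that conversion could look as follows (a sketch assuming plain ``torch.nn.functional.max_pool2d``; the helper shipped with xFormers may differ in name and signature):

.. code-block:: python

    import torch
    import torch.nn.functional as F

    SEQ = 2048
    BLOCK_SIZE = 32

    # Any fine-grained binary mask (1 = keep this coefficient)
    fine_mask = torch.tril(torch.ones(SEQ, SEQ))

    # Max pooling over BLOCK_SIZE x BLOCK_SIZE tiles: a tile is kept
    # as soon as it contains at least one non-zero coefficient
    layout = F.max_pool2d(
        fine_mask.unsqueeze(0),  # max_pool2d expects a leading channel dim
        kernel_size=BLOCK_SIZE,
        stride=BLOCK_SIZE,
    ).squeeze(0).long()

    assert layout.shape == (SEQ // BLOCK_SIZE, SEQ // BLOCK_SIZE)
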
Now let's look at a full example:

.. code-block:: python

    import torch

    from xformers.components import MultiHeadDispatch
    from xformers.components.attention import BlockSparseAttention

    BATCH = 2
    HEADS = 8
    SEQ = 2048
    EMB = 1024
    BLOCK_SIZE = 32
    DROPOUT = 0.1
    dtype = torch.float16

    # Let's try out a causal mask, but really it could be anything "block sparse enough"
    causal_mask = torch.tril(torch.ones((SEQ, SEQ), device=torch.device("cuda"), dtype=dtype))

    blocks = SEQ // BLOCK_SIZE
    causal_layout = torch.tril(torch.ones([HEADS, blocks, blocks]))

    # Let's build our blocksparse attention. Please note that the layout can be
    # [SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE] or [HEADS, SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE]
    # so that _you can pass a different layout per head_
    attention = BlockSparseAttention(layout=causal_layout, block_size=BLOCK_SIZE, dropout=DROPOUT, num_heads=HEADS)

    # For convenience, let's build our multihead attention now.
    # "multi_head" will be responsible for the forward
    multi_head = (
        MultiHeadDispatch(
            seq_len=SEQ,
            dim_model=EMB,
            residual_dropout=DROPOUT,
            num_heads=HEADS,
            attention=attention,
        )
        .cuda()
        .half()
    )

    # Now forward some random data.
    # Note that passing a per-coefficient mask makes it possible to remove extra coefficients,
    # which were required by the blockification
    query = torch.randn((BATCH, SEQ, EMB), requires_grad=True, device=torch.device("cuda"), dtype=dtype)

    # Self attention in this particular example, no limitations really
    multi_head(query=query, key=query, value=query, att_mask=causal_mask)

    #########################################
    # Bonus: compare the memory use vs dense:
    def mem_use(fn, kwargs, title):
        # bookkeeping
        import time

        start = time.time()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # actually run the function
        fn(**kwargs)
        torch.cuda.synchronize()
        stop = time.time()

        # now report
        max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
        print(f"{title} - Peak memory use: {max_memory}MB - {round((stop-start)*1e6)/1e3}ms")


    pytorch_multihead = torch.nn.MultiheadAttention(
        EMB, HEADS, batch_first=True, device=torch.device("cuda"), dtype=torch.float16
    )

    # NB: this comparison is about memory and speed only; PyTorch treats a
    # floating-point attn_mask as additive, so the attention outputs differ
    mem_use(multi_head, {"query": query, "key": query, "value": query, "att_mask": causal_mask}, "Blocksparse")
    mem_use(pytorch_multihead, {"query": query, "key": query, "value": query, "attn_mask": causal_mask}, "PyTorch")

On a V100, with PyTorch 1.9, Triton 1.1 and xFormers 0.0.1 this reports something along the lines of:

.. code-block:: bash

    Blocksparse - Peak memory use: 151MB - 6.619ms
    PyTorch - Peak memory use: 393MB - 6.837ms

Note that the pattern here is not that sparse (half of the matrix is empty); the sparser the pattern gets, the more the comparison will favor BlockSparseAttention.
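
As a sketch of a much sparser, local-plus-global layout (the pattern and names here are illustrative, not something shipped with xFormers):

.. code-block:: python

    import torch

    HEADS = 8
    SEQ = 2048
    BLOCK_SIZE = 32
    blocks = SEQ // BLOCK_SIZE

    # Sliding-window layout: each query block only attends to its immediate
    # neighbours, plus a "global" first block, BigBird-style
    layout = torch.zeros(HEADS, blocks, blocks, dtype=torch.long)
    for i in range(blocks):
        layout[:, i, max(0, i - 1) : i + 2] = 1
    layout[:, :, 0] = 1  # every block can attend to the first block
    layout[:, 0, :] = 1  # and the first block attends to everything

    # Roughly 3/blocks density instead of 1/2 for the causal layout above;
    # this layout can be passed to BlockSparseAttention exactly as before
    print(layout.float().mean())
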
1 change: 1 addition & 0 deletions docs/source/tutorials/index.rst
@@ -5,6 +5,7 @@ Tutorials
   :maxdepth: 1

   sparse_vit
   blocksparse
   extend_attentions
   use_attention
   pytorch_encoder
