Adding the same tutorial to the HTML docs
1 parent 7a69bd1 · commit f5f40d1
Showing 3 changed files with 99 additions and 2 deletions.
@@ -0,0 +1,98 @@

Using BlockSparseAttention
==========================

BlockSparse attention uses Triton_ to limit the attention computation to a set of tiles, which you define at construction time.
A simple example is causal attention: just compute the lower-triangular tiles! The tile size can be changed, the minimum being 16 coefficients on one dimension.

.. _Triton: https://github.com/openai/triton
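
As a toy illustration (the sizes below are arbitrary, picked only for this sketch): with a sequence of 64 coefficients and the minimum block size of 16, a causal layout is nothing more than a small lower-triangular matrix of tile flags.

.. code-block:: python

    import torch

    SEQ, BLOCK = 64, 16      # 64 coefficients, split into 16x16 tiles
    blocks = SEQ // BLOCK    # 4 tiles per dimension
    layout = torch.tril(torch.ones(blocks, blocks, dtype=torch.long))
    print(layout)
    # tensor([[1, 0, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 0],
    #         [1, 1, 1, 1]])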

If you already have a per-coefficient pattern in mind and it is not a perfect match for a block pattern, that is probably fine:
BlockSparse is fast enough that dropping some of the computations after the fact with a fine-grained mask is still likely to beat a dense computation.
We provide a small helper (this is just max-pooling) to convert between a per-coefficient binary mask and the layout you will need to build a block-sparse attention.
Let's look at an example; first a rough sketch of that mask-to-layout conversion, then the full setup.
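
The sketch below uses plain PyTorch pooling and a hypothetical function name; the helper actually shipped with xFormers may expose a different name and signature.

.. code-block:: python

    import torch

    def mask_to_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor:
        # Hypothetical helper: max-pool a [SEQ, SEQ] binary mask into a
        # [SEQ // block_size, SEQ // block_size] layout, keeping every tile
        # that contains at least one allowed coefficient.
        pooled = torch.nn.functional.max_pool2d(
            mask[None, None].float(), kernel_size=block_size, stride=block_size
        )
        return pooled[0, 0].long()

    # A 2048x2048 causal mask becomes a 64x64 lower-triangular layout
    layout = mask_to_layout(torch.tril(torch.ones(2048, 2048)), block_size=32)

Now for the full example: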

.. code-block:: python

    import torch

    from xformers.components import MultiHeadDispatch
    from xformers.components.attention import BlockSparseAttention

    BATCH = 2
    HEADS = 8
    SEQ = 2048
    EMB = 1024
    BLOCK_SIZE = 32
    DROPOUT = 0.1
    dtype = torch.float16

    # Let's try out a causal mask, but really it could be anything "block sparse enough"
    causal_mask = torch.tril(torch.ones((SEQ, SEQ), device=torch.device("cuda"), dtype=dtype))

    blocks = SEQ // BLOCK_SIZE
    causal_layout = torch.tril(torch.ones([HEADS, blocks, blocks]))

    # Let's build our blocksparse attention. Please note that the layout can be
    # [SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE] or [HEADS, SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE]
    # so that _you can pass a different layout per head_
    attention = BlockSparseAttention(layout=causal_layout, block_size=BLOCK_SIZE, dropout=DROPOUT, num_heads=HEADS)

    # For convenience, let's build our multi-head attention now.
    # "multi_head" will be responsible for the forward
    multi_head = (
        MultiHeadDispatch(
            seq_len=SEQ,
            dim_model=EMB,
            residual_dropout=DROPOUT,
            num_heads=HEADS,
            attention=attention,
        )
        .cuda()
        .half()
    )

    # Now run a forward pass on some random data.
    # Note that passing a per-coefficient mask makes it possible to remove the extra coefficients
    # which were required by the blockification.
    query = torch.randn((BATCH, SEQ, EMB), requires_grad=True, device=torch.device("cuda"), dtype=dtype)

    # Self-attention in this particular example, no limitation really
    multi_head(query=query, key=query, value=query, att_mask=causal_mask)

    #########################################
    # Bonus: compare the memory use vs dense:
    def mem_use(fn, kwargs, title):
        # bookkeeping
        import time

        start = time.time()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # actually run the function
        fn(**kwargs)
        torch.cuda.synchronize()
        stop = time.time()

        # now report
        max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
        print(f"{title} - Peak memory use: {max_memory}MB - {round((stop - start) * 1e6) / 1e3}ms")

    pytorch_multihead = torch.nn.MultiheadAttention(
        EMB, HEADS, batch_first=True, device=torch.device("cuda"), dtype=torch.float16
    )

    mem_use(multi_head, {"query": query, "key": query, "value": query, "att_mask": causal_mask}, "Blocksparse")
    mem_use(pytorch_multihead, {"query": query, "key": query, "value": query, "attn_mask": causal_mask}, "PyTorch")
On a V100, with PyTorch 1.9, Triton 1.1 and xFormers 0.0.1, this reports something along the lines of:

.. code-block:: bash

    Blocksparse - Peak memory use: 151MB - 6.619ms
    PyTorch - Peak memory use: 393MB - 6.837ms

Note that the pattern here is not that sparse (half of the matrix is empty); the sparser it gets, the more the comparison will tilt in favor of BlockSparseAttention.
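
For instance, a band-diagonal (local attention) layout keeps only the tiles close to the diagonal and is much sparser than a causal lower triangle on long sequences. A sketch of such a layout, which could be swapped into the example above, might look like this (the ``window`` parameter is our own illustrative choice, not part of the tutorial):

.. code-block:: python

    import torch

    def local_layout(heads: int, blocks: int, window: int = 2) -> torch.Tensor:
        # Keep only the tiles within `window` blocks of the diagonal: a
        # band-diagonal pattern, much sparser than a full lower triangle.
        idx = torch.arange(blocks)
        band = ((idx[None, :] - idx[:, None]).abs() <= window).long()
        return band[None, :, :].expand(heads, blocks, blocks).contiguous()

    # With SEQ=2048 and BLOCK_SIZE=32 as above, the layout is [8, 64, 64]
    sparser_layout = local_layout(heads=8, blocks=2048 // 32)
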
The tutorial is also added to the tutorials toctree:

@@ -5,6 +5,7 @@ Tutorials
    :maxdepth: 1

    sparse_vit
+   blocksparse
    extend_attentions
    use_attention
    pytorch_encoder