Skip to content

Commit d95c864

Browse files
authored
πŸ”΄πŸ”΄πŸ”΄ [Attention] Refactor Attention Interface for Bart-based Models (#38108)
* starting attn refactor for encoder decoder models via bart (eager + sdpa) * flash attention works, remove unnecessary code * flex attention support for bart!, gotta check if the renaming is not too aggressive * some comments * skip flex grad test for standalone as done with the other test * revert flex attn rename (for now), sdpa simplify, and todos * more todos * refactor mask creation for reuse * modular attempt at biogpt * first batch of other models * fix attn dropout * fix autoformer copies * hubert * another batch of models * copies/style + last round of bart models --> whisper next? * remove unnecessary _reshape function and remove copy to whisper * add skip for decoder-only models out of enc-dec (same as in bart) * bring back licences * remove comment, added to pr read instead * mostly docs * disable sew flex attn as it's unclear attn mask for now * oops * test fixes for enc-dec * torch fx fixes + try at flex attn * skip on mbart * some more fixes * musicgen skip / delete old attn class logic + sdpa compose compile skip * disable flex attn for musicgen, not worth the effort * more fixes and style * flex attention test for dropout and encoder decoder that dont have main input names * informer fixes * the weirdest thing I've encountered yet... * style * remove empty tensor attempt, found core root in previous commits * disable time series due to tests being very text centric on inputs * add speech to text to be ignoring the other attns, also due to tests * update docs * remaining issues resolved ? * update docs for current state --> nllb moe and pegasus x sdpa is questionable :D * some models have not set the is_causal flag... * change dtype in softmax tol old behaviour + some modular fixes * I hate it but it is what it is * fixes from main for bart * forgot this one * some model fixes * style * current status * marian works now * fixing some copies * some copy fixes + time series x informer * last models possibly and fixes on style/copies * some post merge fixes * more fixes * make attention interface callable and move warnings there * style lol * add comment to "unsupported" * remove callable interface and change interface warnings + some copies * fix * ternary is ugly af, make it simpler * how did that happen * fix flex attn test * failing the test * no more fallback! fixing copies next * style + attn fixed * fixing copies and mask creation * wrong copy * fixup tests and disable flex attn for now * fixup last tests?
1 parent 9895819 commit d95c864

File tree

75 files changed

+8035
-5932
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+8035
-5932
lines changed

β€Ždocs/source/en/model_doc/biogpt.mdβ€Ž

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
1818

1919
<div class="flex flex-wrap space-x-1">
2020
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
21+
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
2122
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
2223
</div>
2324

@@ -40,13 +41,13 @@ This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The
4041

4142
### Using Scaled Dot Product Attention (SDPA)
4243

43-
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
44-
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
45-
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
44+
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
45+
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
46+
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
4647
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
4748
page for more information.
4849

49-
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
50+
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
5051
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
5152

5253
```
@@ -109,7 +110,7 @@ we saw the following speedups during inference.
109110
[[autodoc]] BioGptForCausalLM
110111
- forward
111112

112-
113+
113114
## BioGptForTokenClassification
114115

115116
[[autodoc]] BioGptForTokenClassification

β€Ždocs/source/en/model_doc/blenderbot-small.mdβ€Ž

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
2121
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
2222
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=
2323
">
24+
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
25+
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
2426
</div>
2527

2628
Note that [`BlenderbotSmallModel`] and
@@ -52,7 +54,7 @@ found [here](https://github.com/facebookresearch/ParlAI).
5254

5355
## Usage tips
5456

55-
Blenderbot Small is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
57+
Blenderbot Small is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
5658
the left.
5759

5860

β€Ždocs/source/en/model_doc/blenderbot.mdβ€Ž

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
2121
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
2222
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=
2323
">
24+
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
25+
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
2426
</div>
2527

2628
## Overview
@@ -45,7 +47,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The
4547

4648
## Usage tips and example
4749

48-
Blenderbot is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
50+
Blenderbot is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
4951
rather than the left.
5052

5153
An example:
@@ -71,7 +73,7 @@ An example:
7173
`facebook/blenderbot_small_90M`, have a different architecture and consequently should be used with
7274
[BlenderbotSmall](blenderbot-small).
7375

74-
76+
7577
## Resources
7678

7779
- [Causal language modeling task guide](../tasks/language_modeling)

β€Ždocs/source/en/model_doc/marian.mdβ€Ž

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
2121
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
2222
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=
2323
">
24+
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
25+
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
2426
</div>
2527

2628
## Overview
@@ -155,7 +157,7 @@ Example of translating english to many romance languages, using old-style 2 char
155157
>>> model = MarianMTModel.from_pretrained(model_name)
156158
>>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
157159
>>> tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
158-
["c'est une phrase en anglais que nous voulons traduire en franΓ§ais",
160+
["c'est une phrase en anglais que nous voulons traduire en franΓ§ais",
159161
'Isto deve ir para o portuguΓͺs.',
160162
'Y esto al espaΓ±ol']
161163
```

β€Ždocs/source/en/model_doc/nllb-moe.mdβ€Ž

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ The original code can be found [here](https://github.com/facebookresearch/fairse
5151

5252
## Implementation differences with SwitchTransformers
5353

54-
The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the
55-
highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed,
56-
which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden
57-
states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism.
54+
The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the
55+
highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed,
56+
which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden
57+
states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism.
5858

5959
## Generating with NLLB-MoE
6060

β€Ždocs/source/en/model_doc/pegasus.mdβ€Ž

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
2121
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
2222
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=
2323
">
24+
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
25+
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
2426
</div>
2527

2628
## Overview

β€Ždocs/source/en/model_doc/pegasus_x.mdβ€Ž

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
1818

1919
<div class="flex flex-wrap space-x-1">
2020
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
21+
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
2122
</div>
2223

2324
## Overview

0 commit comments

Comments
Β (0)