Open to contribution: adding torch.nn.functional.scaled_dot_product_attention support for more architectures
#28005
Comments
Hi @fxmarty, I can take a look at this issue, and I can ask questions if necessary. Or has anyone taken it already? |
Does anyone know if LongT5 and all T5 models are blocked by the lack of bias support in flash attention? |
Hi @davidan5 are you working on the implementation? |
@ENate I was trying to understand the status and get an estimate of the code change to see if I can contribute. |
I see. |
I'm interested in taking a look at this for the Mistral model if that's still needed. Otherwise, please let me know if there are any other models that still need some work. Thanks |
Is LongT5 still open? |
Mistral is already covered! As for LongT5, if it is like T5 and has attention bias, that might not be supported. |
Oh yeah, looks like you added support for Mistral/Mixtral last month. It doesn't seem to be supported for BERT yet (I think someone else is working on FA2 but not SDPA), so I'll take a crack at it. It looks like there is a config for relative position embeddings for BERT, so I'll just have it fall back to the original attention for configs using relative position embeddings. @ArthurZucker - Please let me know if someone else is already working on SDPA for BERT, and I can look for something else to do. Thanks! |
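To illustrate the fallback described above, here is a minimal sketch. The function name `resolve_attention_implementation` is hypothetical and not part of Transformers; the only assumption taken from BERT is that its config exposes a `position_embedding_type` whose default value is `"absolute"`.

```python
def resolve_attention_implementation(requested: str, position_embedding_type: str) -> str:
    """Hypothetical helper: pick the attention backend for a BERT-like config.

    SDPA computes the attention scores inside a fused kernel, so score terms
    that depend on relative positions cannot be injected there. In that case
    this sketch falls back to the original ("eager") attention.
    """
    if requested == "sdpa" and position_embedding_type != "absolute":
        return "eager"
    return requested


# Absolute position embeddings keep the requested SDPA path.
absolute_choice = resolve_attention_implementation("sdpa", "absolute")
# Relative position embeddings fall back to the original attention.
relative_choice = resolve_attention_implementation("sdpa", "relative_key")
```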
Not sure anyone is working on that but bert is already so small that I doubt it will have a lot of impact on perf! |
@ArthurZucker for the T5 family of models, attention bias is required, so flash-attention won't work for now but torch SDPA can still use the memory efficient kernel from xformers, right? I did some benchmarking with Chronos models (based on T5 architecture) here (amazon-science/chronos-forecasting#33) and there's a clear speedup when using torch SDPA. |
@abdulfatir That's correct |
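As a rough illustration of the point about attention bias, the sketch below passes an additive bias to torch SDPA through `attn_mask` and compares the result with a plain eager computation. The shapes and the random `position_bias` tensor are made up for the example, and T5-specific details (such as how it scales the scores) are ignored.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, head_dim = 2, 4, 6, 8  # illustrative sizes only
q, k, v = (torch.randn(batch, heads, seq, head_dim) for _ in range(3))

# Hypothetical stand-in for a T5-style position bias: one value per (head, query, key),
# shared across the batch through broadcasting.
position_bias = torch.randn(1, heads, seq, seq)

# Eager reference: add the bias to the scaled scores before the softmax.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim) + position_bias
out_eager = torch.softmax(scores, dim=-1) @ v

# SDPA: a floating-point attn_mask is added to the attention scores.
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=position_bias)
```

Which kernel actually runs depends on the device, dtype, and PyTorch version, so any speedup should be measured rather than assumed.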
I can open a PR for T5 with SDPA then. Are there specific things that I should know of or a reference that I can look at? |
@abdulfatir For sure, some specific things that are good to know: pytorch/pytorch#108108 (
example of a PR: #29108 |
FYI going forward we should rather use
|
Hey @abdulfatir, just wanted to check in on whether you are still working on a PR to add SDPA support for T5. It would tremendously help accelerate diffusion models that use T5. |
@fxmarty @amyeroberts Could you recommend a model to work on? |
@OmarManzoor Feel free to add to any model which tickles your fancy and isn't listed here: https://github.com/huggingface/transformers/blob/main/docs/source/en/perf_infer_gpu_one.md#pytorch-scaled-dot-product-attention |
Okay, working on |
Is it possible to support moondream2? |
Hi @Bocchi-Chan2023, there is an implementation of moondream2 available on the hub: https://huggingface.co/vikhyatk/moondream2 |
I will start working on OPT |
Just a stupid question: are we adding native SDPA support in the modeling code because PyTorch now provides a fused SDPA? |
Hi, not sure if this is the best place to raise this, but I think it should be fairly straightforward to add support for DinoV2, since DinoV2 is using code copied from ViT, and ViT has since been updated (see here for comments about copied code, and #30555 for adding SDPA to ViT). I'm using DinoV2 so would be keen for any speedups! Should I just flag that here, or is it worth opening a separate issue? I'm also not sure whether I should be raising support for BetterTransformer for DinoV2 in Optimum instead/as well, but my impression is that native support here is preferable if straightforward, so it's used by default? |
@EFord36 Regarding adding SDPA, you can create a separate issue to track but we can regard this comment as a request for it to be added :) If you or anyone else in the community would like to open a PR to add this to the model we'd be very happy to review a PR! Yes, it'd be preferable to be added in transformers as this would mean SDPA is available outside of optimum usage |
I can handle dinoV2 |
I have some free time, so I will add SDPA to Speech2Text. I will use BART as inspiration. |
@allenyummy Have you implemented SDPA for DeBERTaV3? If not, is it complicated? |
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
It should not be super complicated, and a good refactoring is quite needed. Once #22105 is merged, we can easily add it! |
Feature request
In Transformers 4.36, we started adding native support of torch.nn.functional.scaled_dot_product_attention (SDPA), enabled by default in Transformers: https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention
SDPA can dispatch to memory-efficient attention and to flash attention on supported GPUs (currently NVIDIA-only), and even on Intel CPUs.
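For readers new to the function, here is a minimal sketch comparing a hand-written "eager" attention with the fused call. The tensor shapes are illustrative, and `eager_attention` is a helper written for this example, not the Transformers implementation.

```python
import math

import torch
import torch.nn.functional as F


def eager_attention(q, k, v, causal=False):
    # Reference attention: softmax(Q K^T / sqrt(d)) V, materializing the full score matrix.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if causal:
        seq_q, seq_k = q.size(-2), k.size(-2)
        keep = torch.ones(seq_q, seq_k, dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~keep, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
# Illustrative shapes: (batch, heads, sequence, head_dim)
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))

out_eager = eager_attention(q, k, v, causal=True)
# The fused call; PyTorch picks a backend based on device, dtype, and inputs.
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)
max_diff = (out_eager - out_sdpa).abs().max().item()
```

Both paths should agree up to floating-point tolerance; the difference is in speed and memory, since the fused kernels can avoid materializing the full score matrix.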
For the record, here's a benchmark on some currently supported models:
Training benchmark, run on A100-SXM4-80GB: [table values not preserved; columns compared "eager" and "sdpa" in time (s) and memory (MB)]
Inference benchmark, run on A100-SXM4-80GB: [table values not preserved; columns compared "eager" and "sdpa" in ms]

Previously, we had partial support of SDPA in Optimum BetterTransformer, but we are now looking to slowly deprecate it in favor of upstream support of SDPA directly in Transformers.
Here are the architectures for which support has been requested:
The integration could take inspiration from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/decoder_models.py & https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py
Motivation
Faster training & inference, lower memory requirement
Your contribution
I may work on some at some point, but contributions are most welcome.
You should refer to #26572 to add the support of SDPA for a model, roughly following these steps:
- Create a XxxSdpaAttention class inheriting from XxxAttention and implement the attention logic using SDPA
- Use _prepare_4d_causal_attention_mask_for_sdpa instead of _prepare_4d_causal_attention_mask for SDPA
- Use _prepare_4d_attention_mask_for_sdpa instead of _prepare_4d_attention_mask for SDPA
- Add _supports_sdpa = True to XxxPreTrainedModel
- Add "sdpa" key to XXX_ATTENTION_CLASSES in the model modeling file
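The steps above can be sketched roughly as follows. This is a simplified, self-contained illustration using only PyTorch: `XxxAttention` here is a toy layer, the `_attend` hook and `build_attention` helper are hypothetical, and real Transformers attention classes override `forward` with many more arguments. The mask-preparation helpers and the `_supports_sdpa` flag are not reproduced.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


class XxxAttention(nn.Module):
    """Toy eager attention layer standing in for a real model's attention class."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def _split_heads(self, x):
        b, s, _ = x.shape
        return x.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x):
        b, _, s, _ = x.shape
        return x.transpose(1, 2).reshape(b, s, self.num_heads * self.head_dim)

    def _attend(self, q, k, v, attention_mask):
        # Eager path: explicit scores, additive mask, softmax.
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            scores = scores + attention_mask
        return torch.softmax(scores, dim=-1) @ v

    def forward(self, hidden_states, attention_mask=None):
        q = self._split_heads(self.q_proj(hidden_states))
        k = self._split_heads(self.k_proj(hidden_states))
        v = self._split_heads(self.v_proj(hidden_states))
        return self.o_proj(self._merge_heads(self._attend(q, k, v, attention_mask)))


class XxxSdpaAttention(XxxAttention):
    """Same weights and interface; only the attention computation uses SDPA."""

    def _attend(self, q, k, v, attention_mask):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)


# Mapping from implementation name to class, mirroring the "sdpa" key step.
XXX_ATTENTION_CLASSES = {"eager": XxxAttention, "sdpa": XxxSdpaAttention}


def build_attention(attn_implementation, hidden_size=32, num_heads=4):
    # Hypothetical helper: look up the class by name and instantiate it.
    return XXX_ATTENTION_CLASSES[attn_implementation](hidden_size, num_heads)


torch.manual_seed(0)
eager_layer = build_attention("eager")
sdpa_layer = build_attention("sdpa")
sdpa_layer.load_state_dict(eager_layer.state_dict())  # share weights for comparison

x = torch.randn(2, 5, 32)
# Additive causal mask: 0 on and below the diagonal, -inf above it.
causal_mask = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)
out_eager = eager_layer(x, causal_mask)
out_sdpa = sdpa_layer(x, causal_mask)
```

Keeping the SDPA class as a subclass means checkpoints load unchanged into either implementation, which is why the two layers can share a state dict here.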