Improve flash attn monkey patch #2212
Conversation
Just curious: is there an issue with einops, or is this just about reducing dependencies?
if past_key_value is not None:
    # reuse k, v, self_attention
    # reuse k, v
FYI, this comment is in the upstream transformers version too.
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
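For context, here is a minimal sketch of how this kind of flash-attn monkey patch typically disables that helper, assuming the contemporary transformers Llama module still exposes LlamaModel._prepare_decoder_attention_mask (an illustration, not the exact code from this PR):

```python
# Hedged sketch: neuter _prepare_decoder_attention_mask so the 2D padding mask
# reaches the patched attention forward unexpanded (flash-attn handles
# causality itself via causal=True). The attribute path is an assumption about
# the transformers Llama implementation of that era.
import transformers


def _identity_prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # Return the [bsz, kv_len] mask as-is instead of expanding it into a
    # [bsz, 1, q_len, kv_len] additive causal mask.
    return attention_mask


transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
    _identity_prepare_decoder_attention_mask
)
```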
This note about how _prepare_decoder_attention_mask has been neutered seems relevant; I think that's important to keep in place.
Okay, we can add it back.
I ran into some compatibility issues between einops and torch.compile.
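For reference, the transpose in the diff above is equivalent to the einops rearrange it replaces. A quick sketch (shapes taken from the patch comments; einops is imported here only to demonstrate the equivalence):

```python
# The einops-free rewrite: a plain transpose replaces the rearrange call
# that was tripping up torch.compile.
import torch
from einops import rearrange

bsz, nh, q_len, hd = 2, 8, 16, 64
qkv = torch.randn(bsz, nh, 3, q_len, hd)  # [bsz, nh, 3, q_len, hd]

with_einops = rearrange(qkv, "b h three s d -> b s three h d")
without_einops = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]

assert torch.equal(with_einops, without_einops)
```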
Ah, I understand now. I think it would be better to figure out these issues in your environment, because flash-attn itself also uses einops, and it will cause you more problems later: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py#L116
Okay, it seems adding 70B support is fairly straightforward: the 70B-specific code needs to be copied into the beginning and end of LlamaAttention.forward from the latest transformers. I'm not able to test 70B myself, but I can make a PR with the required changes if that's helpful.

I'm also trying to figure out inference and could use some help there. I observe what is described here: huggingface/transformers@ee81bf5#diff-e889ae4211091d62c7135af08ba6a828017e3e3791494281916d8acf6c934e45R100-R114. Specifically, the first few forward passes during inference have mismatched q/k shapes. In other implementations I found, they switch back to the normal softmax calculation when the shapes don't match, and in the link above you can see flash-attn v1 had separate cu_len for q/k. In v2 I'm not sure if that's still possible; it appears you have to pass qkv packed, which requires they be the same size. So I wonder, can we resize q to match and fill with zeros for the previous tokens? I'm not sure what the correct equivalent operation is here.
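As a rough sketch of the v1-style call being asked about: flash-attn v2's variable-length entry points keep separate cumulative lengths for q and k (names per the FlashAttention upgrade notes; treat the exact import path and signature as assumptions), so q does not have to be padded to match kv:

```python
# Hedged sketch: kvpacked variable-length attention with q_len != kv_len.
import torch
from flash_attn import flash_attn_varlen_kvpacked_func

bsz, nh, hd = 2, 8, 64
q_len, kv_len = 1, 32  # a decoding step: one new query token, 32 cached keys/values

q = torch.randn(bsz * q_len, nh, hd, dtype=torch.float16, device="cuda")
kv = torch.randn(bsz * kv_len, 2, nh, hd, dtype=torch.float16, device="cuda")

# Cumulative sequence lengths: one entry per batch element plus a leading zero.
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, q_len, dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, kv_len, dtype=torch.int32, device="cuda")

out = flash_attn_varlen_kvpacked_func(
    q, kv, cu_seqlens_q, cu_seqlens_k, q_len, kv_len, causal=True
)
# out: [bsz * q_len, nh, hd]. Note that causal alignment semantics for
# q_len != kv_len changed across flash-attn releases, so verify on your build.
```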
I think v2 supports all the APIs of v1; FlashAttention v2 just renamed some of them, according to https://github.com/Dao-AILab/flash-attention#upgrading-from-flashattention-1x-to-flashattention-2. Could you try to port the v1 code? If it does not work, I think we can switch back to non-flash-attn for the decoding steps. The most important advantage of flash-attn is reducing the peak memory of processing prompts (the first forward pass). For the autoregressive decoding steps, we do not need to create a large attention matrix anyway.
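A minimal sketch of that split, with torch's scaled_dot_product_attention standing in for both branches so the snippet stays self-contained (function and argument names are illustrative, not the actual patch):

```python
import torch.nn.functional as F


def attn_forward(q, k, v, q_len: int):
    # q, k, v: [bsz, nheads, seq_len, head_dim]
    if q_len == 1:
        # Autoregressive decoding step: the [bsz, nheads, 1, kv_len] score
        # matrix is tiny, so plain SDPA costs little extra memory and no
        # unpad/pad bookkeeping is needed.
        return F.scaled_dot_product_attention(q, k, v)
    # Prefill (the first forward over the prompt) is where flash-attn actually
    # saves peak memory; in the real patch this branch would call the flash
    # kernel, SDPA is used here only to keep the sketch runnable.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```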
Thanks, I realized the same after I explained where I was stuck last night. I agree we can switch to another method for decoding, either the regular upstream calculation or torch's SDP. However, one wrinkle in that plan is that since we neutered the attention mask, it's no longer compatible with those other methods and will have to be recreated first.

Digging around the flash-attn codebase, I found a helper in the tests that can assemble qkvpacked and kvpacked outputs to use with their various APIs, so I pulled that out and was able to simplify the patch a bit. You can track my work here: axolotl-ai-cloud/axolotl#381. I'm using that currently with my benchmarking suite in axolotl-ai-cloud/axolotl#357; however, I'm still seeing that the flash-attn patch has terrible performance compared to the other implementations...
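A sketch of the "recreate the mask first" step mentioned above: rebuild the additive [bsz, 1, q_len, kv_len] mask that _prepare_decoder_attention_mask would normally produce from the 2D padding mask the patch passes through (helper name and exact layout are my assumptions, not the author's code):

```python
import torch


def rebuild_causal_mask(key_padding_mask: torch.Tensor, q_len: int, dtype: torch.dtype):
    # key_padding_mask: [bsz, kv_len], 1/True for real tokens, 0/False for padding.
    bsz, kv_len = key_padding_mask.shape
    past_len = kv_len - q_len
    # Query i sits at absolute position past_len + i, so it may attend to keys j <= past_len + i.
    causal = torch.tril(
        torch.ones(q_len, kv_len, dtype=torch.bool, device=key_padding_mask.device),
        diagonal=past_len,
    )
    keep = causal[None, None] & key_padding_mask[:, None, None].bool()
    mask = torch.zeros(bsz, 1, q_len, kv_len, dtype=dtype, device=key_padding_mask.device)
    return mask.masked_fill(~keep, torch.finfo(dtype).min)
```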
I will start investigating some llama-based models that have integrated flash-attn to see how they're handling inference: https://huggingface.co/emozilla/LLongMA-2-13b-storysummarizer/blob/main/modeling_llama.py and https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py. Also, the work to integrate flash-attn v2 into PyTorch may be a helpful reference: pytorch/pytorch#105602. cc @philschmid
The llongma2 impl shows a simple way to avoid butchering _prepare_decoder_attention_mask:

if attention_mask is not None:
    attention_mask = attention_mask[:, 0, -1]

That would undo the "preparation", i.e. recover the per-key padding row from the expanded 4D mask.
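A quick self-contained check of that slice (values are illustrative): the last query row of the prepared mask has no causal masking left, so it carries exactly the per-key padding information.

```python
import torch

bsz, q_len, kv_len = 1, 4, 4
padding = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)  # last key is padding

# Build a prepared [bsz, 1, q_len, kv_len] additive mask the usual way.
causal = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool))
keep = causal[None, None] & padding[:, None, None]
prepared = torch.zeros(bsz, 1, q_len, kv_len).masked_fill(~keep, torch.finfo(torch.float32).min)

recovered = prepared[:, 0, -1]  # [bsz, kv_len] additive row, 0 = keep
assert torch.equal(recovered == 0, padding)
```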
I ran my impl through line-profiler.
It seems like there's a fast path that can be implemented during decoding when attention_mask is set but fully enabled, i.e. every position is unmasked, so the unpad/pad step can be skipped.
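A sketch of that fast path, assuming flash-attn v2's dense qkv-packed entry point (the exact function name is an assumption, and the padded fallback is elided here):

```python
from flash_attn import flash_attn_qkvpacked_func


def flash_forward(qkv, attention_mask):
    # qkv: [bsz, seq_len, 3, nheads, head_dim]; attention_mask: [bsz, seq_len] or None
    if attention_mask is None or bool(attention_mask.all()):
        # Fast path: no padding anywhere, so the unpad_input/pad_input
        # round-trip is pure overhead and the dense kernel can be called directly.
        return flash_attn_qkvpacked_func(qkv, causal=True)
    # Padded batches still need the existing unpad -> varlen kernel -> repad path.
    raise NotImplementedError("padded path omitted from this sketch")
```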
That did it!
Now flash is much more comparable, and very close to xformers (which I built from git and is also using flash v2). You can see the VRAM usage between flash and xformers is identical, and flash is slightly faster when used directly.
Cool! When you are done, please submit a PR. I will give it a try.
It might be simplest if you pull the patch from my tree; it should drop in here. I will send a PR, but I need to flesh out some tests first to get a better handle on correctness. Also, I found Dao-AILab/flash-attention#436, which is attempting to solve the same issue, FWIU.
Simplify code and remove the dependency on einops