
Improve flash attn monkey patch #2212

Merged: 1 commit into main on Aug 13, 2023

Conversation

@merrymercy (Member) commented Aug 13, 2023

Simplify the code and remove the dependency on einops.
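
For context, removing einops means rewriting its rearrange calls with native tensor ops. A minimal sketch of the kind of substitution involved (shapes follow the [bsz, q_len, nh, hd] comments in the patched code; variable names are illustrative, not the exact diff):

    import torch

    bsz, q_len, num_heads, head_dim = 2, 16, 32, 128
    qkv = torch.randn(bsz, q_len, 3, num_heads, head_dim)

    # einops version (roughly what gets removed):
    #   from einops import rearrange
    #   qkv_packed = rearrange(qkv, "b s three h d -> (b s) three h d")
    # plain PyTorch equivalent:
    qkv_packed = qkv.reshape(bsz * q_len, 3, num_heads, head_dim)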

@merrymercy merged commit 5a6b920 into main on Aug 13, 2023; 2 checks passed.
@merrymercy deleted the improve-flash-attn branch on Aug 13, 2023 at 01:52.

@tmm1 (Contributor) left a review comment:

just curious: is there an issue with einops, or is this just about reducing dependencies?

if past_key_value is not None:
# reuse k, v, self_attention
# reuse k, v
@tmm1 (Contributor):

fyi this comment is in the upstream transformers version too

) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
@tmm1 (Contributor):

this note about how _prepare_decoder_attention_mask has been neutered seems relevant; i think that's important to keep in place
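
For readers of the thread, "neutered" refers to the monkey patch overriding LlamaModel._prepare_decoder_attention_mask so the 2D padding mask passes through untouched (flash-attn applies the causal mask internally). A minimal sketch of what such an override looks like, assuming the standard transformers hook signature:

    # Sketch of the override the patch installs; the body simply returns the
    # [bsz, seq_len] padding mask unchanged, since flash-attn handles causal
    # masking internally and the mask is later used as the key_padding_mask.
    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        return attention_mask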

@merrymercy (Member, Author):

Okay, we can add it back.

@merrymercy (Member, Author) commented Aug 13, 2023

I ran into some compatibility issues between einops and torch.compile.

@tmm1 (Contributor) commented Aug 13, 2023

ah i understand now

i think it would be better to figure out those issues in your environment, because flash-attn itself also uses einops, and it will cause you more problems later:

https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py#L116

@tmm1 (Contributor) commented Aug 13, 2023

okay, it seems adding 70B support is fairly straightforward: the 70B-specific code needs to be copied into the beginning and end of LlamaAttention.forward from the latest transformers

i'm not able to test 70B myself, but i can make a PR with the required changes if that's helpful
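
For reference, the 70B-specific part of LlamaAttention.forward in recent transformers is the grouped-query attention handling (num_key_value_heads < num_heads), where key/value heads are repeated before the attention call. A sketch in the style of the upstream helper, assuming that is the code being referred to:

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Expand [bsz, num_kv_heads, seq_len, head_dim] to
        # [bsz, num_kv_heads * n_rep, seq_len, head_dim] by repeating each kv head n_rep times.
        batch, num_kv_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_kv_heads, n_rep, slen, head_dim
        )
        return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

    # Inside the attention forward, k/v would then be expanded before attention:
    #   key_states = repeat_kv(key_states, self.num_heads // self.num_key_value_heads)
    #   value_states = repeat_kv(value_states, self.num_heads // self.num_key_value_heads)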

i'm also trying to figure out inference and could use some help there.

i observe what is described here: huggingface/transformers@ee81bf5#diff-e889ae4211091d62c7135af08ba6a828017e3e3791494281916d8acf6c934e45R100-R114

specifically, the first few forward passes during inference have query_states.shape == key_states.shape, but from then on there is only one token in query_states while the key/value shapes keep growing

in other implementations i found, they switch back to normal softmax calculation when the shapes don't match

and in the link above, you can see flash-attn v1 had separate cu_len for q/k

but in v2 i'm not sure if that's still possible; it appears you have to pass qkv packed, and that requires they be the same size

so i wonder: can we resize q to match and fill with zeros for the previous tokens? i'm not sure what the correct equivalent operation is here.
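
For what it's worth, flash-attn v2 still has variable-length entry points that take separate cumulative sequence lengths for q and k/v, so packing q together with k/v is not strictly required. A hedged sketch of a decoding-shaped call, assuming the flash_attn_varlen_kvpacked_func interface (requires a CUDA device and fp16/bf16; the tensor construction here is illustrative only):

    import torch
    from flash_attn import flash_attn_varlen_kvpacked_func

    bsz, nheads, head_dim, kv_len = 2, 32, 128, 512
    device, dtype = "cuda", torch.float16

    # One new query token per sequence, a full key/value cache per sequence (unpadded layout).
    q_unpad = torch.randn(bsz * 1, nheads, head_dim, device=device, dtype=dtype)
    kv_unpad = torch.randn(bsz * kv_len, 2, nheads, head_dim, device=device, dtype=dtype)

    cu_seqlens_q = torch.arange(0, bsz + 1, dtype=torch.int32, device=device)  # [0, 1, 2]
    cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, kv_len, dtype=torch.int32, device=device)

    out = flash_attn_varlen_kvpacked_func(
        q_unpad, kv_unpad,
        cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q=1, max_seqlen_k=kv_len,
        dropout_p=0.0,
        causal=False,  # the single new token sits at the end of the sequence, so nothing needs masking
    )  # [bsz * 1, nheads, head_dim]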

@merrymercy (Member, Author):

I think v2 supports all the APIs of v1; FlashAttention 2 just renamed some of them, per https://github.com/Dao-AILab/flash-attention#upgrading-from-flashattention-1x-to-flashattention-2. Could you try to port the v1 code?

If that does not work, I think we can switch back to non-flash-attn for the decoding steps. The most important advantage of flash-attn is reducing the peak memory of processing prompts (the first forward pass). For the autoregressive decoding steps, we do not need to create a seq_len x seq_len attn_weight matrix, so even the default implementation without flash-attn almost achieves the lowest peak memory. Although I believe flash-attn reduces not only peak memory but also latency, I think it is okay to switch back to non-flash-attn during the decoding steps.
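
A minimal sketch of that fallback idea (not the code in this PR): keep flash-attn for the prompt pass and route single-token decoding steps to torch's scaled_dot_product_attention. flash_prompt_fn is a hypothetical placeholder for whatever wrapper handles the first forward pass.

    import torch.nn.functional as F

    def attention_dispatch(query_states, key_states, value_states, flash_prompt_fn):
        # query/key/value: [bsz, num_heads, q_len, head_dim]
        q_len, kv_len = query_states.shape[2], key_states.shape[2]
        if q_len == kv_len:
            # Prompt processing: this is where the seq_len x seq_len attention matrix
            # would appear, so flash-attn's memory savings matter most here.
            return flash_prompt_fn(query_states, key_states, value_states)
        # Autoregressive decoding (q_len == 1): a plain attention call is already close
        # to the memory optimum, so fall back to torch's SDPA.
        # Note: padding masks are not handled in this sketch.
        return F.scaled_dot_product_attention(query_states, key_states, value_states)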

@tmm1 (Contributor) commented Aug 13, 2023

thanks, i realized the same after i explained where i was stuck last night

i agree we can switch to another method for decoding, either the regular upstream calculation or torch's sdp. however, one wrinkle in that plan is that since we neutered the attention mask, it's no longer compatible with those other methods and will have to be recreated first.

while digging around the flash-attn codebase, i found a helper in the tests that can assemble qkvpacked and kvpacked outputs to use with their various apis, so i pulled that out and was able to simplify the patch a bit.

you can track my work here: axolotl-ai-cloud/axolotl#381

i'm using that currently with my benchmarking suite in axolotl-ai-cloud/axolotl#357, however i'm still seeing that the flash-attn patch has terrible performance compared to the other implementations...

test_id                                                     vram_generate_cache    generate_time    generate_tps
--------------------------------------------------------  ---------------------  ---------------  --------------
test_inference[base-bf16-llama2_7b-prompt-size=4096]                    9.71613          5.99026         21.368
test_inference[xformers-bf16-llama2_7b-prompt-size=4096]                5.33137          6.44044         19.8744
test_inference[sdp-bf16-llama2_7b-prompt-size=4096]                     4.62629          5.57443         22.962
test_inference[flash-bf16-llama2_7b-prompt-size=4096]                   5.51691         10.8526          11.7944

@tmm1 (Contributor) commented Aug 13, 2023

i will start investigating some llama-based models which have integrated flash-attn to see how they're handling inference

https://huggingface.co/emozilla/LLongMA-2-13b-storysummarizer/blob/main/modeling_llama.py

https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py

Also, the work to integrate flash-attn v2 into pytorch may be a helpful reference: pytorch/pytorch#105602

cc @philschmid

@tmm1 (Contributor) commented Aug 13, 2023

the llongma2 impl shows a simple way to avoid butchering _prepare_decoder_attention_mask:

if attention_mask is not None:
    attention_mask = attention_mask[:, 0, -1]

that would undo the "preparation", i.e. [bsz, 1, tgt_seq_len, src_seq_len] -> [bsz, seq_len]
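
If what the flash-attn path ultimately needs is a boolean [bsz, seq_len] key_padding_mask (e.g., for unpad_input), the sliced row could be mapped back as well. A sketch assuming the usual additive convention of 0.0 for visible positions and a large negative value for masked ones; the helper name is hypothetical:

    import torch

    def recover_key_padding_mask(prepared_mask: torch.Tensor) -> torch.Tensor:
        # prepared_mask: [bsz, 1, tgt_seq_len, src_seq_len], the additive float mask
        # built by _prepare_decoder_attention_mask.
        last_row = prepared_mask[:, 0, -1]  # [bsz, src_seq_len]; in the last row only padding is masked
        return last_row == 0                # True where the key position is attendable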

@tmm1 (Contributor) commented Aug 13, 2023

i ran my impl through line-profiler:

Total time: 6.77868 s                                                                                                      
File: /mnt/ml/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py                                                   
Function: flashattn_forward at line 51                                                                                     
                                                                                                                           
Line #      Hits         Time  Per Hit   % Time  Line Contents                                                             
==============================================================                     
   188                                               else:                                                                 
   189      4064   15908203.0   3914.4      0.2          query_states = query_states.transpose(1, 2)                       
   190      4064    7522953.0   1851.1      0.1          key_states = key_states.transpose(1, 2)                           
   191      4064    6167117.0   1517.5      0.1          value_states = value_states.transpose(1, 2)                       
   192      4064    1265025.0    311.3      0.0          (  # pylint: disable=unbalanced-tuple-unpacking                   
   193      4064     703975.0    173.2      0.0              q_unpad,                                                      
   194      4064     510697.0    125.7      0.0              kv_unpad,                                                     
   195      4064     568327.0    139.8      0.0              cu_seqlens_q,                                                 
   196      4064     468365.0    115.2      0.0              cu_seqlens_k,                                                 
   197      4064     495261.0    121.9      0.0              max_seqlen_q,                                                 
   198      4064     442606.0    108.9      0.0              max_seqlen_k,                                                 
   199      4064     686618.0    169.0      0.0              _,                                                            
   200      4064     588445.0    144.8      0.0              _,                                                            
   201      4064     548991.0    135.1      0.0              output_pad_fn,                                                
   202      4064 3589082837.0 883140.5     52.9          ) = generate_qkv(                                                 
   203      4064     531844.0    130.9      0.0              query_states,                                                 
   204      4064     469537.0    115.5      0.0              key_states,                                                   
   205      4064     470396.0    115.7      0.0              value_states,                                                 
   206      4064     554246.0    136.4      0.0              kvpacked=True,                                                
   207      4064     527402.0    129.8      0.0              key_padding_mask=attention_mask,                              
   208      4064   29964776.0   7373.2      0.4              query_padding_mask=attention_mask[:, -query_states.size(1):] if attention_mask is not None else None,                                                                                    
   209                                                   ) 
Total time: 4.1738 s                                                                                                       
File: /mnt/ml/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py                                                   
Function: generate_qkv at line 251                                                                                         
                                                                                                                           
Line #      Hits         Time  Per Hit   % Time  Line Contents                                                             
==============================================================                                                             
   274      4096     749352.0    182.9      0.0      if query_padding_mask is not None:                                    
   275      4096 2066823654.0 504595.6     49.5          q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(     
   276      4096     567930.0    138.7      0.0              q, query_padding_mask                                         
   277                                                   )                                                                 
   ...
   298      4096     906386.0    221.3      0.0      if key_padding_mask is not None:
   299      4096  870280681.0 212470.9     20.9          k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
   300      4096 1092189013.0 266647.7     26.2          v_unpad, _, _, _ = unpad_input(v, key_padding_mask)

seems like there's a fast path that can be implemented during decoding when attention_mask is set but is fully-enabled, i.e. attention_mask.all().item()
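
A sketch of that fast path (illustrative, not the exact axolotl code): when the mask is absent or all ones, skip unpad_input / generate_qkv (the hot spots in the profile above) and call the padded kernel directly.

    from flash_attn import flash_attn_func

    def flash_forward(q, k, v, attention_mask):
        # q, k, v: [bsz, seq_len, num_heads, head_dim], fp16/bf16 on CUDA.
        if attention_mask is None or attention_mask.all().item():
            # Fast path: no padding anywhere, so the per-step unpad/re-pad overhead can be skipped.
            # The causal mask only matters when more than one query token is present.
            return flash_attn_func(q, k, v, dropout_p=0.0, causal=q.shape[1] != 1)
        # Otherwise fall through to the existing unpad-based path (omitted here).
        raise NotImplementedError("padded path omitted in this sketch")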

@tmm1 (Contributor) commented Aug 13, 2023

> seems like there's a fast path that can be implemented during decoding when attention_mask is set but is fully-enabled, i.e. attention_mask.all().item()

that did it!

test_id                                                     vram_model    time_model    prompt_tokens    vram_generate_cache    generate_time    generate_tokens    generate_tps
--------------------------------------------------------  ------------  ------------  ---------------  ---------------------  ---------------  -----------------  --------------
test_inference[base-bf16-llama2_7b-prompt-size=2048]           12.6138       9.49888             2049                2.70052          4.60331                128         27.8061
test_inference[xformers-bf16-llama2_7b-prompt-size=2048]       12.6138       9.13354             2049                2.83919          4.99762                128         25.6122
test_inference[sdp-bf16-llama2_7b-prompt-size=2048]            12.6138       9.06869             2049                2.47005          4.19624                128         30.5035
test_inference[flash-bf16-llama2_7b-prompt-size=2048]          12.6138       9.20882             2049                2.83724          4.65223                128         27.5137

now flash is much more comparable, and very close to xformers (which i built from git and which is also using flash v2). you can see the vram usage between flash/xformers is nearly identical, and flash is slightly faster when used directly

@merrymercy (Member, Author):

Cool! When you are done, please submit a PR. I will give it a try.

@tmm1 (Contributor) commented Aug 14, 2023

it might be simplest if you pull the patch from my tree; it should drop in here. i will send a PR, but i need to flesh out some tests first to get a better handle on correctness

also, i found Dao-AILab/flash-attention#436, which is attempting to solve the same issue, from what i understand
