
Introduce flash-attn (>= 2.5.0). #3010

Closed
wants to merge 1 commit

Conversation

@sighingnow (Contributor) commented Feb 23, 2024

flash-attention has supported a paged KV cache since v2.5.0 (commit Dao-AILab/flash-attention@54e80a3) via the flash_attn_with_kvcache interface.

This PR enables the use of flash-attn kernels for prefill, contexted prefill (currently handled by the context_attention_fwd kernel for prefix caching), and decoding, for both MHA and MQA/GQA.

(This PR includes the GQA fixes for prefix caching from #3007, which should be accepted/merged first.)
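For reference, here is a minimal sketch of what a paged decode step through this interface can look like (illustrative only, not the code in this PR; all shapes and sizes below are made up, and the page block size of 256 reflects the flash-attn v2.5.0 constraint discussed later in this thread):

```python
# Hedged sketch of flash_attn_with_kvcache with a paged KV cache (flash-attn >= 2.5.0).
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, nheads_k, headdim = 4, 32, 8, 128        # GQA: 32 query heads, 8 KV heads
page_block_size, num_blocks, blocks_per_seq = 256, 64, 8

# One new query token per sequence (decode step).
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k_new = torch.randn(batch, 1, nheads_k, headdim, dtype=torch.float16, device="cuda")
v_new = torch.randn(batch, 1, nheads_k, headdim, dtype=torch.float16, device="cuda")

# Paged KV cache: (num_blocks, page_block_size, nheads_k, headdim).
k_cache = torch.zeros(num_blocks, page_block_size, nheads_k, headdim,
                      dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)

# Per-sequence mapping from logical block index to physical block number.
block_table = torch.arange(batch * blocks_per_seq, dtype=torch.int32,
                           device="cuda").reshape(batch, blocks_per_seq)
# Number of tokens already stored in the cache for each sequence.
cache_seqlens = torch.full((batch,), 511, dtype=torch.int32, device="cuda")

# Appends k_new/v_new into the paged cache and attends over cache + new token.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k_new, v=v_new,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)
print(out.shape)  # torch.Size([4, 1, 32, 128])
```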

Takeaways

(Correct me if I got something wrong in the following evaluation.)

  • For prompting, flash-attention's kernel performance is as good as xformers' memory_efficient_attention_forward.
  • For decoding, flash-attention's kernel outperforms vLLM's paged attention kernels by more than 10%.
  • The larger block size (256) decreases the performance of saving key/value tensors into the KV cache when the context length is 512 (but not when it is 1024). For context length 1024, the effect of the larger block size is roughly equivalent for flash-attn's and xformers' kernels.

Benchmarks

Kernel performance

  • Performance of benchmarks/kernels/benchmark_paged_attention.py (on A800-SXM4-80GB):
| Context Length | Batch Size | v1 latency (us) | v2 latency (us) | FA latency (us) | min(v1, v2) / FA |
| --- | --- | --- | --- | --- | --- |
| 512 | 1 | 20.285 | 24.053 | 21.732 | 0.93341616 |
| 512 | 2 | 22.507 | 27.161 | 21.557 | 1.04406921 |
| 512 | 4 | 27.766 | 32.225 | 21.303 | 1.3033845 |
| 512 | 8 | 51.812 | 54.791 | 26.772 | 1.93530554 |
| 512 | 16 | 85.764 | 94.209 | 34.092 | 2.5156635 |
| 512 | 32 | 154.25 | 165.756 | 65.76 | 2.34565085 |
| 512 | 64 | 296.102 | 314.561 | 99.969 | 2.9619382 |
| 512 | 128 | 575.615 | 601.286 | 167.221 | 3.44224111 |
| 512 | 256 | 1135.105 | 1183.508 | 300.406 | 3.77856967 |
| 1024 | 1 | 34.968 | 28.472 | 34.563 | 0.82377108 |
| 1024 | 2 | 38.488 | 33.959 | 34.814 | 0.97544091 |
| 1024 | 4 | 48.401 | 53.541 | 24.094 | 2.00884038 |
| 1024 | 8 | 95.94 | 93.901 | 34.651 | 2.70990736 |
| 1024 | 16 | 159.446 | 165.031 | 61.633 | 2.58702319 |
| 1024 | 32 | 281.751 | 312.197 | 119.318 | 2.36134531 |
| 1024 | 64 | 545.191 | 596.719 | 180.948 | 3.01297058 |
| 1024 | 128 | 1054.5 | 1167.859 | 308.539 | 3.41772029 |
| 1024 | 256 | 1973.139 | 2263.873 | 571.34 | 3.45352855 |
| 2048 | 1 | 63.153 | 33.674 | 21.453 | 1.56966392 |
| 2048 | 2 | 72.214 | 53.621 | 24.961 | 2.14819118 |
| 2048 | 4 | 96.665 | 93.176 | 36.007 | 2.58771905 |
| 2048 | 8 | 181.835 | 163.072 | 59.931 | 2.72099581 |
| 2048 | 16 | 292.313 | 306.513 | 114.099 | 2.56192429 |
| 2048 | 32 | 522.957 | 586.998 | 221.445 | 2.36156608 |
| 2048 | 64 | 1017.5 | 1147.996 | 348.459 | 2.9199992 |
| 2048 | 128 | 1856.959 | 2263.743 | 587.604 | 3.16022185 |
| 2048 | 256 | 3434.592 | 3927.69 | 1091.562 | 3.14649282 |
| 4096 | 1 | 119.062 | 55.451 | 26.47 | 2.09486211 |
| 4096 | 2 | 158.624 | 93.562 | 36.841 | 2.53961619 |
| 4096 | 4 | 184.772 | 162.506 | 56.607 | 2.8707757 |
| 4096 | 8 | 345.027 | 307.934 | 98.774 | 3.1175613 |
| 4096 | 16 | 559.029 | 584.602 | 219.395 | 2.54804804 |
| 4096 | 32 | 990.7 | 1141.4 | 428.872 | 2.31001324 |
| 4096 | 64 | 1942.043 | 2172.605 | 674.434 | 2.87951527 |
| 4096 | 128 | 3372.405 | 3997.455 | 1143.417 | 2.94940953 |
| 4096 | 256 | 6675.193 | 7643.167 | 2119.215 | 3.14984228 |
  • Performance of benchmarks/kernels/benchmark_attention.py (on A800-SXM4-80GB):

In the no-kv-cache variant, the cache_ops.reshape_and_cache call in attention.py is disabled. FA is the variant that uses flash_attn_func for prefill and tensor indexing to update the KV cache; FA-kvcache uses the flash_attn_with_kvcache kernel itself to update the KV cache (see the sketch after the table below).

| Context Length | Batch Size | xformers latency (us) | FA latency (us) | FA-kvcache latency (us) | xformers / FA |
| --- | --- | --- | --- | --- | --- |
| 512 | 1 | 110.338 | 91.07 | 139.999 | 1.21157351 |
| 512 | 2 | 144.344 | 147.993 | 278.349 | 0.97534343 |
| 512 | 4 | 250.35 | 245.323 | 503.129 | 1.02049135 |
| 512 | 8 | 467.584 | 431.974 | 932.702 | 1.08243552 |
| 512 | 16 | 883.95 | 811.29 | 1712.075 | 1.08956107 |
| 512 | 32 | 1743.296 | 1565.759 | 3999.321 | 1.11338718 |
| 512 | 64 | 3174.326 | 2901.265 | 8777.063 | 1.09411791 |
| 512 | 128 | 6066.597 | 5493.965 | 17084.151 | 1.10422928 |
| 512 | 256 | 11833.118 | 10863.233 | 33945.686 | 1.08928143 |
| 1024 | 1 | 206.524 | 211.106 | 455.081 | 0.97829526 |
| 1024 | 2 | 351.918 | 350.311 | 825.614 | 1.00458735 |
| 1024 | 4 | 649.32 | 625.219 | 1553.106 | 1.03854809 |
| 1024 | 8 | 1231.075 | 1178.457 | 2834.009 | 1.04464991 |
| 1024 | 16 | 2413.602 | 2194.367 | 5874.732 | 1.09990808 |
| 1024 | 32 | 4509.488 | 4265.735 | 13223.946 | 1.05714209 |
| 1024 | 64 | 8383.135 | 8080.716 | 26377.322 | 1.03742478 |
| 1024 | 128 | 16872.201 | 15938.01 | 52312.309 | 1.05861403 |
| 1024 | 256 | 33296.577 | 31935.821 | 104567.139 | 1.04260908 |
| 2048 | 1 | 565.678 | 562.886 | 1475.639 | 1.00496015 |
| 2048 | 2 | 1026.456 | 1012.15 | 2451.654 | 1.01413427 |
| 2048 | 4 | 1843.185 | 1908.501 | 4756.062 | 0.96577628 |
| 2048 | 8 | 3501.278 | 3410.372 | 9763.956 | 1.02665574 |
| 2048 | 16 | 6791.142 | 6651.876 | 22373.683 | 1.02093635 |
| 2048 | 32 | 13349.664 | 13104.102 | 44491.954 | 1.01873932 |
| 2048 | 64 | 26580.194 | 25880.751 | 88575.947 | 1.02702561 |
| 2048 | 128 | 53249.013 | 51936.57 | 176763.652 | 1.02527011 |
| 2048 | 256 | 106779.17 | 104066.506 | 358598.515 | 1.02606664 |
| 4096 | 1 | 1797.347 | 1793.38 | 4714.081 | 1.00221202 |
| 4096 | 2 | 3173.016 | 3222.586 | 8502.306 | 0.98461794 |
| 4096 | 4 | 6234.825 | 5852.162 | 17453.473 | 1.06538831 |
| 4096 | 8 | 11774.696 | 11533.758 | 40768.599 | 1.02088981 |
| 4096 | 16 | 23339.23 | 23209.621 | 82110.277 | 1.00558428 |
| 4096 | 32 | 46793.742 | 46198.331 | 164628.018 | 1.01288815 |
| 4096 | 64 | 93572.857 | 92678.064 | 328483.712 | 1.00965485 |
| 4096 | 128 | 187961.277 | 186534.328 | 666575.131 | 1.00764979 |
| 4096 | 256 | OOM | 373899.889 | 1331579.82 | |
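As a rough illustration of the two variants above (not the PR's code; the helper names prefill_fa and prefill_fa_kvcache and the cache layout are assumptions), the FA path computes prefill attention with flash_attn_func and then scatters the new keys/values into the paged cache with plain tensor indexing, while the FA-kvcache path lets flash_attn_with_kvcache perform the cache write itself:

```python
# Illustrative sketch of the "FA" vs "FA-kvcache" prefill paths benchmarked above.
import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache

def prefill_fa(q, k, v, k_cache, v_cache, block_table, block_size):
    """FA: attention via flash_attn_func, KV-cache update via tensor indexing."""
    out = flash_attn_func(q, k, v, causal=True)              # (batch, seqlen, nheads, headdim)
    seqlen = k.shape[1]
    positions = torch.arange(seqlen, device=k.device)
    blk = block_table.long()[:, positions // block_size]     # (batch, seqlen) physical block ids
    off = positions % block_size                              # (seqlen,) offsets within a block
    k_cache[blk, off] = k                                     # scatter new keys/values into the paged cache
    v_cache[blk, off] = v
    return out

def prefill_fa_kvcache(q, k, v, k_cache, v_cache, block_table):
    """FA-kvcache: flash_attn_with_kvcache appends k/v to the cache and attends in one call."""
    cache_seqlens = torch.zeros(q.shape[0], dtype=torch.int32, device=q.device)
    return flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k, v=v,
        cache_seqlens=cache_seqlens, block_table=block_table, causal=True,
    )
```

In the numbers above, the indexing-based write (FA) is clearly cheaper during prefill than letting flash_attn_with_kvcache update the cache (FA-kvcache).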

Throughput

  • Performance of benchmark_throughput.py on Llama-70B-GPTQ (on A800-SXM4-80GB):
| Input Len | Output Len | w/o FA (tokens/s) | FA (tokens/s) | FA / (w/o FA) |
| --- | --- | --- | --- | --- |
| 512 | 512 | 593.29 | 619.68 | 1.04448078 |
| 1024 | 1024 | 404.06 | 419.42 | 1.03801416 |

Here the "Speed (seconds/round)" is the averaged duration for per 10 prompt/decodings runs. We can see that

  • In the prompting stage, flash-attn's performance starts to degrade after some sequences finish and the next batch of prompts starts.
  • In the decoding stage, flash-attn consistently outperforms the paged attention kernel.

[figure: Speed (seconds/round) across prompting/decoding rounds, comparing FA and the paged attention kernel]

  • Throughput of prompting
| Input Len | Output Len | w/o FA (tokens/s) | w/o FA, 256 (tokens/s) | FA (tokens/s) | w/o FA, no-cache (tokens/s) | FA, no-cache (tokens/s) |
| --- | --- | --- | --- | --- | --- | --- |
| 512 | 1 | 1713.7 | 1688.68 | 1696.82 | 1719.04 | 1720.64 |
| 1024 | 1 | 1675.02 | 1692.14 | 1688.2 | 1702.63 | 1707.86 |

Real-world cases throughput

  • Benchmark on a real-world dataset: running speculative decoding (from Introduce speculative decoding with draft models to vLLM #3029) with selected questions from the OpenOrca-1M-GPT4 dataset (input_length <= 256 && output_length <= 512), using Llama-2-70B-GPTQ as the target model and TinyLlama-1.1B-Chat-v1.0-GPTQ as the draft model, with speculative lookahead = 5. We can see about 1.3x~1.4x improvement when flash attention is enabled (this also reflects the performance difference between flash_attn_with_kvcache and context_attention_fwd):
| Batch Size | w/o FA (tokens/s) | FA (tokens/s) | FA / (w/o FA) |
| --- | --- | --- | --- |
| 32 | 237.8 | 322.64 | 1.3567704 |
| 40 | 265.74 | 367.76 | 1.38390908 |
| 48 | 268.61 | 372.09 | 1.38524254 |
| 64 | 314.25 | 440.39 | 1.40140016 |
| 96 | 340.25 | 475.93 | 1.39876561 |
| 128 | 365.1 | 530.13 | 1.45201315 |
| 144 | 367.02 | 533.93 | 1.45477086 |
| 160 | 377.61 | 548.73 | 1.45316596 |
| 192 | 384.37 | 560.5 | 1.45823035 |

@sighingnow (Contributor, Author)

Polite ping @WoosukKwon: I only noticed your working branch refactor-attn after preparing this PR. I hope the changes/implementation in this PR can be helpful.

Thanks!

@sighingnow (Contributor, Author)

The entrypoint test failure shouldn't be caused by this PR.

@zhaoyang-star (Contributor)

Glad to see FA being ported to vLLM. FYI, there is already a similar PR: #2744. As FlashInfer is faster than FlashAttention, vLLMers may prefer FlashInfer over FA.

@sighingnow (Contributor, Author)

> Glad to see FA being ported to vLLM. FYI, there is already a similar PR: #2744. As FlashInfer is faster than FlashAttention, vLLMers may prefer FlashInfer over FA.

Thanks for the information. I have left some comments on #2744.

@zhaoyang-star (Contributor) commented Feb 26, 2024

> Glad to see FA being ported to vLLM. FYI, there is already a similar PR: #2744. As FlashInfer is faster than FlashAttention, vLLMers may prefer FlashInfer over FA.
>
> Thanks for the information. I have left some comments on #2744.

Thanks for your work! The paged KV cache block size in Flash Attention must be divisible by 256, which is a big difference from block_size=16 in vLLM. Could you share the latency and throughput benchmark data? Does it cause any side effects?

@sighingnow (Contributor, Author) commented Feb 26, 2024

> Thanks for your work! The paged KV cache block size in Flash Attention must be divisible by 256, which is a big difference from block_size=16 in vLLM. Could you share the latency and throughput benchmark data? Does it cause any side effects?

Hi @zhaoyang-star, I have updated the description with numbers (the same setting as #2744, thank you!).

Note that the current throughput benchmark may not reflect real-world situations (as EOS is ignored and batches are aligned). More benchmark data on unaligned, continuous batches will be added soon.

@zhaoyang-star (Contributor)

@sighingnow The throughput speedup is ~3% in your benchmark. How about the e2e latency?

@sighingnow (Contributor, Author)

> @sighingnow The throughput speedup is ~3% in your benchmark. How about the e2e latency?

Still running. I have added new data for speculative decoding, which involves unaligned input sequences within a batch and a parallel decoding requirement for the main model.

@zhaoyang-star (Contributor)

@sighingnow I tested the e2e latency using Flash-Attention in vLLM. A performance boost can only be seen at large batch sizes.

@sighingnow (Contributor, Author)

> @sighingnow The throughput speedup is ~3% in your benchmark. How about the e2e latency?

For input=512 and output=512, the improvement is smaller than the numbers reported in #2744, likely because I didn't use tensor parallelism to accelerate the MLPs (which are more costly at shorter context sizes, and GPTQ may amplify this).

@skrider (Contributor) commented Feb 26, 2024

Hello, just wanted to provide a quick update on Dao-AILab/flash-attention#824: the PR is complete, and flash attention can now support page sizes as low as 16.

@sighingnow (Contributor, Author)

> Hello, just wanted to provide a quick update on Dao-AILab/flash-attention#824: the PR is complete, and flash attention can now support page sizes as low as 16.

Thank you!

@sighingnow (Contributor, Author) commented Feb 26, 2024

@zhaoyang-star I have uploaded more data and analysis to the pull request description. I hope it is helpful.

@sighingnow (Contributor, Author)

> Thanks for your work! The paged KV cache block size in Flash Attention must be divisible by 256, which is a big difference from block_size=16 in vLLM. Could you share the latency and throughput benchmark data? Does it cause any side effects?

Hi @zhaoyang-star, I have observed a side effect of the larger block size when saving key/value vectors into the KV cache (it only affects prefill, and exists for both the xformers kernel and flash-attn's kernel). I have included the data in the pull request description.

After @skrider's PR is merged, I will take another try with the new version of flash-attention.

@sighingnow (Contributor, Author)

Hi @WoosukKwon @zhuohan123, I would like to know if you folks have any comments on this pull request. Integrating the changes from this PR into #3005 is also fine with me.

Signed-off-by: Tao He <sighingnow@gmail.com>
@DarkLight1337 (Member)

Closing as it is superseded by #3005.
