[Performance] Support MQA/GQA in decode phase by using FlashAttention #2744

Open · wants to merge 11 commits into main
Conversation

@zhaoyang-star (Contributor) commented Feb 4, 2024

As shown in issue #1880, vLLM's current paged attention kernel does not leverage the benefits of MQA/GQA. FlashAttention supports MQA/GQA at the kernel level and has supported a paged KV cache since v2.5.0.

To take advantage of MQA/GQA, this PR replaces _paged_attention with flash_attn_with_kvcache.
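For context, a minimal sketch of what a decode-phase call to flash_attn_with_kvcache with a paged KV cache might look like (shapes and variable names are illustrative, not the exact integration in this PR):

```python
import torch
from flash_attn import flash_attn_with_kvcache  # requires flash-attn >= 2.5.0

batch, num_heads, num_kv_heads, head_dim = 4, 64, 8, 128
block_size, num_blocks, max_blocks_per_seq = 256, 512, 16

# Decode phase: one new query token per sequence.
q = torch.randn(batch, 1, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")

# Paged KV cache: (num_blocks, block_size, num_kv_heads, head_dim).
k_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_dim,
                      dtype=torch.bfloat16, device="cuda")
v_cache = torch.randn_like(k_cache)

# block_table maps each sequence to its physical cache blocks.
block_table = torch.randint(0, num_blocks, (batch, max_blocks_per_seq),
                            dtype=torch.int32, device="cuda")
# Current context length of each sequence.
cache_seqlens = torch.full((batch,), 1024, dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(q, k_cache, v_cache,
                              cache_seqlens=cache_seqlens,
                              block_table=block_table,
                              causal=True)
print(out.shape)  # (batch, 1, num_heads, head_dim)
```

Because num_kv_heads (8) is smaller than num_heads (64), the GQA-aware kernel reads far less KV data per query than an MHA-style kernel would.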

Notes:

Kernel latency

Env:

  • A100-40GB-PCIE
  • num query heads=64, num kv heads=8
  • kernel data type=bfloat16
  • use_alibi=False

Below are the results from running benchmark_paged_attention.py.

context_len=1024:

| Batch size | Paged attention V1 (us) | Paged attention V2 (us) | FA (us) | Speedup (min(V1, V2) / FA) |
|---|---|---|---|---|
| 1 | 37.541 | 29.602 | 27.675 | 1.07 |
| 4 | 68.585 | 71.447 | 27.433 | 2.50 |
| 16 | 162.237 | 168.997 | 67.681 | 2.40 |
| 64 | 687.938 | 733.687 | 214.041 | 3.21 |
| 256 | 2215.532 | 2378.606 | 680.123 | 3.26 |

context_len=4096:

| Batch size | Paged attention V1 (us) | Paged attention V2 (us) | FA (us) | Speedup (min(V1, V2) / FA) |
|---|---|---|---|---|
| 1 | 137.858 | 70.783 | 28.557 | 2.48 |
| 4 | 223.438 | 164.293 | 62.104 | 2.65 |
| 16 | 716.780 | 724.601 | 238.730 | 3.00 |
| 64 | 2324.863 | 2792.657 | 772.266 | 3.01 |
| 256 | 8771.168 | 9723.521 | 2583.970 | 3.39 |

E2E latency

E2E throughput

Env:

  • 4xA100-40GB-PCIE
  • Baseline: python benchmarks/benchmark_throughput.py --input-len 512 --output-len 512 --model /CodeLlama-34b-hf/ --tokenizer /CodeLlama-34b-hf/ --trust-remote-code --tensor-parallel-size 4 --enforce-eager
  • FA: python benchmarks/benchmark_throughput.py --input-len 512 --output-len 512 --model /CodeLlama-34b-hf/ --tokenizer /CodeLlama-34b-hf/ --trust-remote-code --tensor-parallel-size 4 --enforce-eager --use-flash-attn

Below are the results from running benchmark_throughput.py.

| Original (tokens/sec) | FA (tokens/sec) | Speedup (FA / Original) |
|---|---|---|
| 2114.28 | 2311.75 | 1.09 |

@zhaoyang-star zhaoyang-star marked this pull request as ready for review February 5, 2024 06:10
@ttbachyinsda (Contributor)

It appears that the throughput test for Qwen-7B shows a regression, and the number of GPU blocks has decreased significantly compared to before. Maybe there is a bug?

No Flash Attention:
Namespace(backend='vllm', dataset=None, input_len=128, output_len=128, model='./Qwen-7B-Chat', tokenizer='./Qwen-7B-Chat', quantization=None, tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=True, max_model_len=None, dtype='auto', enforce_eager=False, kv_cache_dtype='auto', use_flash_attn=False, device='cuda')
INFO 02-05 19:24:49 llm_engine.py:73] Initializing an LLM engine with config: model='./Qwen-7B-Chat', tokenizer='./Qwen-7B-Chat', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, use_flash_attn=False, device_config=cuda, seed=0)
WARNING 02-05 19:24:50 tokenizer.py:64] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
INFO 02-05 19:25:21 llm_engine.py:331] # GPU blocks: 643, # CPU blocks: 512
INFO 02-05 19:25:22 model_runner.py:653] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 02-05 19:25:22 model_runner.py:657] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing gpu_memory_utilization or enforcing eager mode. You can also reduce the max_num_seqs as needed to decrease memory usage.
INFO 02-05 19:25:26 model_runner.py:720] Graph capturing finished in 4 secs.
Processed prompts: 100%|██████████| 1000/1000 [02:21<00:00, 7.07it/s]
Throughput: 7.07 requests/s, 1810.59 tokens/s

With Flash Attention:
Namespace(backend='vllm', dataset=None, input_len=128, output_len=128, model='./Qwen-7B-Chat', tokenizer='./Qwen-7B-Chat', quantization=None, tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=True, max_model_len=None, dtype='auto', enforce_eager=False, kv_cache_dtype='auto', use_flash_attn=True, device='cuda')
INFO 02-05 19:31:59 llm_engine.py:73] Initializing an LLM engine with config: model='./Qwen-7B-Chat', tokenizer='./Qwen-7B-Chat', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, use_flash_attn=True, device_config=cuda, seed=0)
WARNING 02-05 19:32:00 tokenizer.py:64] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
INFO 02-05 19:32:14 llm_engine.py:331] # GPU blocks: 40, # CPU blocks: 32
INFO 02-05 19:32:15 model_runner.py:653] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 02-05 19:32:15 model_runner.py:657] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing gpu_memory_utilization or enforcing eager mode. You can also reduce the max_num_seqs as needed to decrease memory usage.
INFO 02-05 19:32:19 model_runner.py:720] Graph capturing finished in 5 secs.
Processed prompts: 100%|██████████| 1000/1000 [05:09<00:00, 3.23it/s]
Throughput: 3.23 requests/s, 826.13 tokens/s

@casper-hansen (Contributor)

@zhaoyang-star you may be interested in FlashInfer's PagedAttention, which is 3x faster than vLLM's version. It supports GQA.

https://github.com/flashinfer-ai/flashinfer/
https://flashinfer.ai/2024/02/02/introduce-flashinfer.html


@zhaoyang-star (Contributor, Author) commented Feb 6, 2024

@ttbachyinsda Thanks for your feedback.

Since the paged KV cache block size in FlashAttention must be divisible by 256, the block size is set to 256 by default in https://github.com/vllm-project/vllm/pull/2744/files#diff-ea8b8ff63961713ccb62d78e53e96404b587b7828cb9fee08a9e5576bf563673R54, so the number of GPU blocks decreases significantly compared to before. It is not a bug.
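As a rough sanity check (assuming the non-FA run used vLLM's default block_size of 16; the numbers below are only illustrative), the total KV-cache capacity in tokens is roughly the same in both runs, and only the block granularity changes:

```python
# Non-FA run reports 643 GPU blocks; with the assumed default block_size of 16:
print(643 * 16)   # 10288 tokens of KV-cache capacity
# FA run reports 40 GPU blocks of 256 tokens each:
print(40 * 256)   # 10240 tokens of KV-cache capacity
```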

I don't have a Qwen model at hand. Below is the throughput benchmark for CodeLlama-7B on an A100-40GB-PCIE; the speedup is ~1.07x. Could you please use this model to benchmark throughput?

Note that this PR mainly targets MQA/GQA models. Qwen and CodeLlama-7B/13B both use MHA, so they will not gain much speedup from this PR.
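One quick way to check whether a checkpoint uses MQA/GQA is to compare its query and KV head counts. A hypothetical snippet using the Hugging Face config (not part of this PR; the model id is just an example):

```python
from transformers import AutoConfig

# Any local path or hub id works here; CodeLlama-34B is a GQA example.
cfg = AutoConfig.from_pretrained("codellama/CodeLlama-34b-hf")
num_q = cfg.num_attention_heads
num_kv = getattr(cfg, "num_key_value_heads", num_q)

# MHA: num_kv == num_q; MQA: num_kv == 1; GQA: 1 < num_kv < num_q.
print(f"query heads = {num_q}, kv heads = {num_kv}")
```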

root@532469214f5e:/bigdata/zhaoyang/github/remote/vllm# python benchmarks/benchmark_throughput.py --input-len 128 --output-len 128 --model /bigdata/shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23/ --tokenizer /bigdata/shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23/ --trust-remote-code
Namespace(backend='vllm', dataset=None, input_len=128, output_len=128, model='/bigdata/shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23/', tokenizer='/bigdata/shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23/', quantization=None, tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=True, max_model_len=None, dtype='auto', enforce_eager=False, kv_cache_dtype='auto', use_flash_attn=False, device='cuda')
INFO 02-06 01:20:27 llm_engine.py:73] Initializing an LLM engine with config: model='/bigdata/shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23/', tokenizer='/bigdata/shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23/', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=16384, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, use_flash_attn=False, device_config=cuda, seed=0)
INFO 02-06 01:20:45 llm_engine.py:331] # GPU blocks: 2624, # CPU blocks: 512
INFO 02-06 01:20:50 model_runner.py:653] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 02-06 01:20:50 model_runner.py:657] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 02-06 01:21:04 model_runner.py:720] Graph capturing finished in 15 secs.
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:51<00:00, 19.45it/s]
Throughput: 19.45 requests/s, 4978.70 tokens/s
root@532469214f5e:/bigdata/zhaoyang/github/remote/vllm# python benchmarks/benchmark_throughput.py --input-len 128 --output-len 128 --model /bigdata/shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23/ --tokenizer /bigdata/shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23/ --trust-remote-code --use-flash-attn
Namespace(backend='vllm', dataset=None, input_len=128, output_len=128, model='/bigdata/shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23/', tokenizer='/bigdata/shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23/', quantization=None, tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=True, max_model_len=None, dtype='auto', enforce_eager=False, kv_cache_dtype='auto', use_flash_attn=True, device='cuda')
INFO 02-06 01:22:33 llm_engine.py:73] Initializing an LLM engine with config: model='/bigdata/shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23/', tokenizer='/bigdata/shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23/', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=16384, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, use_flash_attn=True, device_config=cuda, seed=0)
INFO 02-06 01:22:47 llm_engine.py:331] # GPU blocks: 164, # CPU blocks: 32
INFO 02-06 01:22:52 model_runner.py:653] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 02-06 01:22:52 model_runner.py:657] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 02-06 01:23:15 model_runner.py:720] Graph capturing finished in 23 secs.
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:48<00:00, 20.73it/s]
Throughput: 20.73 requests/s, 5306.34 tokens/s

@zhaoyang-star (Contributor, Author)

> @zhaoyang-star you may be interested in FlashInfer's PagedAttention, which is 3x faster than vLLM's version. It supports GQA.

Thanks for your valuable feedback. Yes, I am very interested in kernel optimization and I will dive into FlashInfer soon.

@ttbachyinsda (Contributor)


Thank you for the guidance. I was testing with an RTX 3090, which might not be suitable for the changes in this PR. I will try to run the throughput benchmark with CodeLlama-7B next.

@zhaoyang-star zhaoyang-star changed the title Use flash attention in decode phase [Performance] Support MQA/GQA in decode phase by using FlashAttention Feb 6, 2024
@zhaoyang-star (Contributor, Author) commented Feb 7, 2024

FlashAttention and FlashInfer are both SOTA solutions for speeding up the decode phase in vLLM. We may need to do more research to decide which one to use. @WoosukKwon @zhuohan123 @casper-hansen I would be glad to hear your opinions.

@casper-hansen (Contributor)


In the current stage of development, it looks like #2772 will be 44.5% faster than the main branch, probably due to FlashInfer being better than FlashAttention for PagedAttention.

@zhaoyang-star (Contributor, Author)

@skrider is working on supporting small page sizes in FlashAttention; the block_size will be able to be 16 once FlashAttention #824 is merged. @zhuohan123

@skrider (Contributor) commented Feb 21, 2024

Hello! Yes, I have been working on that in FlashAttention; it is almost ready to be merged, with just one small issue left to deal with (fused RoPE). It could be vendored right now. However, FlashInfer is still slightly faster than FlashAttention 2, and there is also the issue of the FP8 KV cache and Turing support. I have talked to @simon-mo and the plan is to use FlashInfer.

@zhaoyang-star (Contributor, Author) commented Feb 21, 2024

@skrider @simon-mo Thanks for the information. Glad to see FlashInfer has better performance than FA. I noticed #2772 is working on this but still has some issues. I think it is still helpful to merge this PR as an optional high-performance CUDA kernel; once FlashInfer is ready, we could make it the default.

@sighingnow (Contributor)

Some comments (from PR #3010):

  • flash-attn's flash_attn_with_kvcache can be used for the prefill phase as well, both non-prefixed and prefixed (the current context_attention_fwd requires fixes for GQA).
  • the cache() op kernel is not needed, as the KV-cache update can be handled by flash-attn's flash_attn_with_kvcache (see the sketch below).
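
To illustrate the second point, here is a rough sketch (shapes and values are illustrative, assuming flash-attn >= 2.5.0; this is not the code from PR #3010 or this PR) of how passing the new k/v tensors lets flash_attn_with_kvcache write them into the paged cache in place while computing attention, so no separate cache-write kernel would be needed:

```python
import torch
from flash_attn import flash_attn_with_kvcache  # flash-attn >= 2.5.0

batch, seqlen_new = 2, 128          # 128 new tokens per sequence (prefill chunk)
num_heads, num_kv_heads, head_dim = 64, 8, 128
block_size, num_blocks, max_blocks = 256, 256, 8

q = torch.randn(batch, seqlen_new, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k_new = torch.randn(batch, seqlen_new, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v_new = torch.randn_like(k_new)

# Paged KV cache and per-sequence block table.
k_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim,
                      dtype=torch.bfloat16, device="cuda")
v_cache = torch.zeros_like(k_cache)
block_table = torch.arange(batch * max_blocks, dtype=torch.int32,
                           device="cuda").reshape(batch, max_blocks)
cache_seqlens = torch.zeros(batch, dtype=torch.int32, device="cuda")  # cache empty so far

# Passing k/v makes the kernel append the new keys/values into the paged cache
# in place (no separate cache-write op) and attend over cache + new tokens.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              cache_seqlens=cache_seqlens,
                              block_table=block_table, causal=True)
print(out.shape)  # (batch, seqlen_new, num_heads, head_dim)
```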


This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale label Oct 30, 2024
@mergify mergify bot added the ci/build label Oct 30, 2024

mergify bot commented Oct 30, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @zhaoyang-star please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 30, 2024