[Performance] Support MQA/GQA in prefill stage by using FlashAttention #2401
Conversation
Hey, I cannot replicate the benchmark results w.r.t. the current implementation. Here is what I am getting:

As you can see, the flash attention result is close to what you have reported, but the current implementation's result is much closer to flash attention, especially after the benchmark code has been modified to match what's actually in the vLLM code. Here's the modification:

```python
import xformers.ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

key_expanded = key
value_expanded = value
query_expanded = query
num_queries_per_kv = num_query_heads // num_kv_heads
if num_queries_per_kv > 1:
    # Handle MQA and GQA: reshape the query to (tokens, kv_heads, q_per_kv, head_dim)
    # and broadcast key/value along the new axis (views, no copy).
    query_expanded = query_expanded.view(query_expanded.shape[0], num_kv_heads,
                                         num_queries_per_kv, query_expanded.shape[-1])
    key_expanded = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                             num_queries_per_kv, key.shape[-1])
    value_expanded = value[:, :, None, :].expand(value.shape[0], num_kv_heads,
                                                 num_queries_per_kv, value.shape[-1])
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
output = xops.memory_efficient_attention_forward(
    query_expanded.unsqueeze(0),
    key_expanded.unsqueeze(0),
    value_expanded.unsqueeze(0),
    attn_bias=attn_bias,
    p=0.0,
    scale=scale,
)
```

A100-40GB.
@Yard1 Thanks for your info. You are right. I should use `expand`. Updated benchmark results are below; xformers is close to FA.
xformers==0.0.22, flash-attn==2.4.2.
For MQA/GQA, the speedup should mainly show up during decoding, although this looks good for a start. Can we measure the throughput difference?
I used starcoder, which is an MQA model, and the throughput is mostly close to the original version.
Hi, regarding the llama structure, does it significantly improve throughput?
https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html
I guess it explains the difference.
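For readers following along: `Tensor.expand` returns a broadcasted view that shares the original storage, while `torch.repeat_interleave` materializes a full copy, so the two baselines do different amounts of memory work. A minimal, self-contained illustration (the shapes are made up for the example):

```python
import torch

num_tokens, num_kv_heads, num_queries_per_kv, head_dim = 1024, 8, 4, 128
key = torch.randn(num_tokens, num_kv_heads, head_dim)

# expand: insert a size-1 axis and broadcast it -- a view, no extra memory.
key_expanded = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                         num_queries_per_kv, head_dim)

# repeat_interleave: physically copies each KV head num_queries_per_kv times.
key_repeated = torch.repeat_interleave(key, num_queries_per_kv, dim=1)

print(key_expanded.shape)                          # torch.Size([1024, 8, 4, 128])
print(key_repeated.shape)                          # torch.Size([1024, 32, 128])
print(key.data_ptr() == key_expanded.data_ptr())   # True: shares storage
print(key.data_ptr() == key_repeated.data_ptr())   # False: new allocation
```

The copy in the `repeat_interleave` version is presumably what made that baseline look slower than the view-based path the vLLM code actually uses.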
I guess the result will be similar to starcoder, since starcoder is MQA and llama2-34B is GQA. Besides, llama2-7B and 13B are MHA, so they will not gain any speedup.
When I tested this PR with llama2 70B, the throughput did not improve. I used the AsyncLLMEngine server along with 4*A800 80G PCIe.
Because the attention calculation only achieves a 1.1~1.2x speedup, the end-to-end speedup is hard to benchmark.
The largest speedup will probably be seen during decoding.
See also #3010, which introduces the
Closing as it's already implemented. Thanks for submitting the PR. I learned a lot from it!
As shown in #1880, xformers does not support MQA/GQA yet, so `key` and `value` need to be expanded before computing `softmax(Q @ K^T * softmax_scale) @ V`. FlashAttention, on the other hand, supports MQA/GQA on Turing, Ampere, Ada, or Hopper GPUs, but note that FA has some limits (1. head size up to 256; 2. dtype fp16 or bf16). So for prefill, I replaced xformers with FlashAttention whenever FA can handle the case; it falls back to xformers when the head size is > 256 or the dtype is float32.
The benchmark (benchmark_multi_query_kv_attention.py) shows the speedup is 3x.
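A hedged sketch of the dispatch described above (not the PR's actual code: `prefill_attention` and `xformers_prefill_attention` are made-up names, and the `flash_attn_varlen_func` call reflects the flash-attn 2.x API as I understand it). FlashAttention accepts K/V with fewer heads than Q, so the MQA/GQA case needs no expansion; the unsupported cases fall back to the expanded-K/V xformers path shown earlier.

```python
import torch
from flash_attn import flash_attn_varlen_func  # flash-attn >= 2.x


def prefill_attention(query, key, value, cu_seqlens, max_seqlen, scale):
    """Hypothetical dispatcher: FlashAttention when supported, xformers otherwise.

    query:      (total_tokens, num_query_heads, head_size)
    key, value: (total_tokens, num_kv_heads, head_size)
    cu_seqlens: (batch_size + 1,) int32 cumulative sequence lengths
    """
    head_size = query.shape[-1]
    fa_supported = (
        head_size <= 256
        and query.dtype in (torch.float16, torch.bfloat16)
    )
    if fa_supported:
        # FlashAttention handles num_kv_heads < num_query_heads natively,
        # so key/value do not need to be expanded for MQA/GQA.
        return flash_attn_varlen_func(
            query, key, value,
            cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
            softmax_scale=scale,
            causal=True,
        )
    # Fallback: expand key/value and run the xformers path shown earlier.
    # (xformers_prefill_attention is a hypothetical helper, not a real API.)
    return xformers_prefill_attention(query, key, value, cu_seqlens, scale)
```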
The benchmark below is invalid, as I misused `torch.repeat_interleave` when benchmarking the original case.