
[Performance] Support MQA/GQA in prefill stage by using FlashAttention #2401

Closed
wants to merge 4 commits

Conversation

@zhaoyang-star (Contributor) commented Jan 10, 2024

As shown in #1880, xformers does not support MQA/GQA yet, so key and value need to be expanded before computing softmax(Q @ K^T * softmax_scale) @ V. FlashAttention, on the other hand, supports MQA/GQA natively and runs on Turing, Ampere, Ada, and Hopper GPUs. Note, however, that FA has some limits: head size up to 256, and only fp16 and bf16 dtypes.

So for prefill, I replaced xformers with FlashAttention whenever FA can handle the configuration; it falls back to xformers when the head size is larger than 256 or the dtype is float32. The benchmark shows a speedup of about 3x.
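A minimal sketch of this dispatch logic, assuming flash-attn's `flash_attn_varlen_func`; the function name `prefill_attention` and the `xformers_prefill_attention` fallback are illustrative, not the PR's actual code:

```python
import torch
from flash_attn import flash_attn_varlen_func

def prefill_attention(query, key, value, cu_seqlens, max_seqlen, scale):
    """Dispatch prefill attention to FlashAttention when supported, else xformers.

    query: (total_tokens, num_query_heads, head_size)
    key, value: (total_tokens, num_kv_heads, head_size); FA handles MQA/GQA
    natively, so no head expansion is needed on this path.
    """
    head_size = query.shape[-1]
    if head_size <= 256 and query.dtype in (torch.float16, torch.bfloat16):
        return flash_attn_varlen_func(
            query, key, value,
            cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
            softmax_scale=scale, causal=True,
        )
    # Fallback: the xformers path, which expands key/value across query heads.
    return xformers_prefill_attention(query, key, value, cu_seqlens, scale)
```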

  • Using CodeLLaMA-34B config (num_query_heads=64, num_key_value_heads=8, head_size=128)
  • Tested on A100-40GB
  • The latency is the time for calculating softmax(Q @ K^T * softmax_scale) @ V
  • The benchmark below can be reproduced by running benchmark_multi_query_kv_attention.py

Note: the benchmark below is invalid, as I misused torch.repeat_interleave when benchmarking the original case (see the updated results further down).

| Test id | Batch size | Prompt length | Original xformers (us) | FA (us) | Speedup (Original / FA) |
|---|---|---|---|---|---|
| 1 | 1 | 1024 | 6.875 | 1.625 | 4.2 |
| 2 | 10 | 1024 | 36.965 | 11.518 | 3.2 |
| 3 | 100 | 1024 | 364.861 | 126.109 | 2.9 |

@Yard1 (Collaborator) commented Jan 10, 2024

Hey, I cannot replicate the benchmark results with respect to the current implementation. Here is what I am getting:

  • --num-query-heads 64 --num-kv-heads 8 --head-size 128 --seq-len 1024 --batch-size 1: 2.715 us

  • --num-query-heads 64 --num-kv-heads 8 --head-size 128 --seq-len 1024 --batch-size 1, modified to use non-blocking repeat_interleave (see vllm attention.py): 1.876 us

  • --num-query-heads 64 --num-kv-heads 8 --head-size 128 --seq-len 1024 --batch-size 1 --use-flash-attn: 1.722 us

  • --num-query-heads 64 --num-kv-heads 8 --head-size 128 --seq-len 1024 --batch-size 10: 20.311 us

  • --num-query-heads 64 --num-kv-heads 8 --head-size 128 --seq-len 1024 --batch-size 10, modified to use non-blocking repeat_interleave (see vllm attention.py): 12.439 us

  • --num-query-heads 64 --num-kv-heads 8 --head-size 128 --seq-len 1024 --batch-size 10 --use-flash-attn: 12.166 us

As you can see, the FlashAttention result is close to what you have reported, but the current implementation's result is much closer to FlashAttention, especially after the benchmark code has been modified to match what's actually in the vllm code.

Here's the modification:

    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

    # query: (num_tokens, num_query_heads, head_size)
    # key, value: (num_tokens, num_kv_heads, head_size)
    query_expanded = query
    key_expanded = key
    value_expanded = value
    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA: group the query heads per KV head and expand
        # key/value as zero-stride views (no copy), matching vllm attention.py.
        query_expanded = query_expanded.view(query_expanded.shape[0], num_kv_heads,
                                             num_queries_per_kv, query_expanded.shape[-1])
        key_expanded = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                                 num_queries_per_kv, key.shape[-1])
        value_expanded = value[:, :, None, :].expand(value.shape[0], num_kv_heads,
                                                     num_queries_per_kv, value.shape[-1])
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query_expanded.unsqueeze(0),
        key_expanded.unsqueeze(0),
        value_expanded.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )

A100-40GB, xformers==0.0.23.post1, flash_attn==2.4.2.

@zhaoyang-star zhaoyang-star marked this pull request as draft January 11, 2024 00:58
@zhaoyang-star (Contributor, Author) commented Jan 11, 2024

@Yard1 Thanks for the info. You are right: I should use expand, the same as in attention.py, rather than torch.repeat_interleave.

Updated benchmark results are below. Xformers is close to FA.

| Test id | Batch size | Prompt length | xformers (same expand as attention.py) (us) | FA (us) | Speedup (xformers / FA) |
|---|---|---|---|---|---|
| 1 | 1 | 1024 | 2.010 | 1.625 | 1.24 |
| 2 | 10 | 1024 | 13.188 | 11.518 | 1.14 |
| 3 | 100 | 1024 | 140.809 | 126.109 | 1.12 |

xformers==0.0.22, flash-attn==2.4.2.

cc @casper-hansen @beginlner

@casper-hansen (Contributor)

For MQA/GQA, it should mainly see a speedup during decoding, although this looks good for a start. Can we measure throughput difference?

@zhaoyang-star (Contributor, Author)

> For MQA/GQA, it should mainly see a speedup during decoding, although this looks good for a start. Can we measure throughput difference?

I used StarCoder, which is an MQA model, and the throughput is close to the original version.

@Lvjinhong

> For MQA/GQA, it should mainly see a speedup during decoding, although this looks good for a start. Can we measure throughput difference?

> I used StarCoder, which is an MQA model, and the throughput is close to the original version.

Hi, regarding the Llama architecture, does this significantly improve throughput?

@sh1ng (Contributor) commented Jan 11, 2024

From https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html:

> Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor where a dimension of size one is expanded to a larger size by setting the stride to 0. Any dimension of size 1 can be expanded to an arbitrary value without allocating new memory.

I guess that explains the difference.
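A quick way to see this (a small illustrative check, not part of the PR):

```python
import torch

# (tokens, num_kv_heads, head_size), as in the prefill path above
key = torch.randn(4096, 8, 128)

# expand: a zero-stride view over the new head dimension, no extra memory
expanded = key[:, :, None, :].expand(4096, 8, 8, 128)
print(expanded.stride())                      # (1024, 128, 0, 1): stride 0 on the expanded dim
print(expanded.data_ptr() == key.data_ptr())  # True: same underlying storage

# repeat_interleave: materializes a full copy, 8x the memory here
repeated = key.repeat_interleave(8, dim=1)
print(repeated.shape)                         # torch.Size([4096, 64, 128])
print(repeated.data_ptr() == key.data_ptr())  # False: a new allocation
```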

@zhaoyang-star (Contributor, Author)

> For MQA/GQA, it should mainly see a speedup during decoding, although this looks good for a start. Can we measure throughput difference?

> I used StarCoder, which is an MQA model, and the throughput is close to the original version.

> Hi, regarding the Llama architecture, does this significantly improve throughput?

I guess the result will be similar to StarCoder, since StarCoder uses MQA and Llama-2 34B uses GQA. Besides, Llama-2 7B and 13B use MHA, so they will not gain any speedup.

@Lvjinhong

When I tested this PR with Llama-2 70B, the throughput did not improve. I used the AsyncLLMEngine server with 4x A800 80GB PCIe GPUs.

@zhaoyang-star (Contributor, Author) commented Jan 13, 2024

> When I tested this PR with Llama-2 70B, the throughput did not improve. I used the AsyncLLMEngine server with 4x A800 80GB PCIe GPUs.

Because the attention calculation only achieves a 1.1x to 1.2x speedup, the end-to-end speedup is hard to measure.

@zhaoyang-star zhaoyang-star marked this pull request as ready for review January 15, 2024 00:56
@casper-hansen (Contributor)

> When I tested this PR with Llama-2 70B, the throughput did not improve. I used the AsyncLLMEngine server with 4x A800 80GB PCIe GPUs.

> Because the attention calculation only achieves a 1.1x to 1.2x speedup, the end-to-end speedup is hard to measure.

The largest speedup will probably be seen during decoding.

@sighingnow (Contributor)

See also #3010, which introduces flash_attn_with_kvcache (available for paged KV cache since flash-attn >= 2.5.0).
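For context, a minimal sketch of a decode-step call to `flash_attn_with_kvcache` with a paged KV cache; shapes and values are illustrative, not taken from #3010, and it assumes flash-attn >= 2.5.0 on a CUDA GPU:

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, num_q_heads, num_kv_heads, head_size = 2, 64, 8, 128  # GQA config as in this PR
num_blocks, block_size = 64, 256                             # paged KV-cache layout

# One new query token per sequence (decode step).
q = torch.randn(batch, 1, num_q_heads, head_size, dtype=torch.float16, device="cuda")

# Paged KV cache: (num_blocks, block_size, num_kv_heads, head_size).
k_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size,
                      dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# Per-sequence mapping from logical to physical blocks, plus current cache lengths.
block_table = torch.randint(0, num_blocks, (batch, 8), dtype=torch.int32, device="cuda")
cache_seqlens = torch.tensor([100, 37], dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)
print(out.shape)  # (batch, 1, num_q_heads, head_size)
```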

@WoosukKwon (Collaborator)

Closing as it's already implemented. Thanks for submitting the PR. Learned a lot from it!

@WoosukKwon WoosukKwon closed this Aug 1, 2024