Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support #11844
Conversation
I see that you have
This pull request has merge conflicts that must be resolved before it can be merged.
All conflicts fixed, could you please take another look? Thanks!
st] = decode_metadata.block_tables[i, st:ed]
decode_metadata.block_tables_intra = block_tables_intra
seq_lens_succ = (chunk_num_curr -
When I try the needle-in-a-haystack test with Qwen-7B and Llama-8B (code modified to support Llama), there is a bug that produces a negative number once the context goes beyond roughly 13k-15k tokens.
I modified the code as below and confirmed that it works:
seq_lens_succ = ((chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len)
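For context, a toy sketch of what the proposed .clip(min=0) does (chunk length and counter values here are hypothetical, not taken from the test): it keeps the intermediate (chunk_num_curr - 1) term non-negative for sequences that have not yet filled a full chunk.

import torch

chunk_len = 8192                           # hypothetical chunk length
chunk_num_curr = torch.tensor([0, 1, 3])   # hypothetical per-sequence chunk counters

without_clip = chunk_num_curr - 1                # tensor([-1, 0, 2]) -> negative term
with_clip = (chunk_num_curr - 1).clip(min=0)     # tensor([ 0, 0, 2])

seq_lens_succ = (chunk_num_curr - with_clip) * chunk_len
print(seq_lens_succ)                             # tensor([   0, 8192, 8192])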
This pull request has merge conflicts that must be resolved before it can be merged.
I tested it because I thought it was fixed, but I still have the same problem as below.
Dual chunk attention doesn't support CUDA graph, and I have added an assertion in
It is indeed a bug introduced while preparing this PR; fixed. Thanks!
Rebased against main. Hi @youkaichao @simon-mo @WoosukKwon, do you folks think there are still things that need to be improved in this pull request? Thanks!
Spotted a few bits of commented-out code that look like debug cruft or are otherwise mysterious. Could you clean those up, and any other similar spots?
This pull request has merge conflicts that must be resolved before it can be merged.
qc_freqs = torch.einsum("i,j -> ij", qc_t, inv_freq)
k_freqs = torch.einsum("i,j -> ij", k_t, inv_freq)
qc_no_clamp_freqs = torch.einsum("i,j -> ij", qc_no_clamp_t, inv_freq)
q_inter_freqs = torch.einsum("i,j -> ij", q_inter_t, inv_freq)
nit: I think these einsums are still slower on CUDA than (a * b).sum(-1); not on the hot path though, so not critical.
Ran bench_einsum.py from that issue on an H100 and got:

python einsum_bench.py
                                  |  mul/sum  |  torch.einsum  |  numpy.einsum
1 threads: -------------------------------------------------------------------
  Nc,Nc->N cpu  (1048576, 2)      |    5000   |      3100      |     4000
  Nc,Nc->N cuda (1048576, 2)      |      20   |       747      |     3300

Times are in microseconds (us).
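For reference, the "i,j -> ij" pattern in the snippet above is just an outer product, so it can also be written as a broadcasted multiply; a minimal sketch (shapes are illustrative, not the rotary embedding's real sizes):

import torch

t = torch.arange(4096, dtype=torch.float32)   # illustrative positions
inv_freq = torch.rand(64)                     # illustrative inverse frequencies

freqs_einsum = torch.einsum("i,j -> ij", t, inv_freq)
freqs_broadcast = t[:, None] * inv_freq[None, :]   # same result, no einsum dispatch

assert torch.allclose(freqs_einsum, freqs_broadcast)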
vllm/attention/layer.py
Outdated
logits_soft_cap, attn_type, **{
    "dual_chunk_attention_config": dual_chunk_attention_config,
    "prefix": prefix,
} if dual_chunk_attention_config is not None else {})
I feel like this is messy; I think we should maybe do something like:
def __init__(..., **extra_attn_kwargs):
    self.impl = impl_cls(..., **extra_attn_kwargs)
The challenge here is that prefix would not be captured by extra_attn_kwargs but is only (currently) used by DualChunkFlashAttentionImpl. I do think it would be less messy to do this and make prefix a standard arg for attention impls, given that it is pretty generic. Thoughts @WoosukKwon?
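A self-contained toy of that shape (class and argument names here are illustrative, not the actual vLLM signatures):

from typing import Optional

class DummyImpl:
    """Stand-in for a backend impl such as DualChunkFlashAttentionImpl."""

    def __init__(self, num_heads: int, prefix: str = "",
                 dual_chunk_attention_config: Optional[dict] = None):
        self.num_heads = num_heads
        self.prefix = prefix
        self.dual_chunk_attention_config = dual_chunk_attention_config

class DummyAttention:
    """Treats prefix as a standard arg and forwards backend-specific options
    untouched via **extra_impl_args, instead of building a conditional kwargs
    dict at the call site."""

    def __init__(self, num_heads: int, prefix: str = "", **extra_impl_args):
        self.impl = DummyImpl(num_heads, prefix=prefix, **extra_impl_args)

attn = DummyAttention(
    8,
    prefix="model.layers.0.self_attn",
    dual_chunk_attention_config={"chunk_size": 8192},
)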
vllm/attention/layer.py
Outdated
if self.dual_chunk_attention_config:
    assert query_succ_and_inter is not None
    dca_kwargs = {
        "query_succ": query_succ_and_inter[0],
        "query_inter": query_succ_and_inter[1],
        "query_succ_critical": query_succ_and_inter[2],
        "query_inter_critical": query_succ_and_inter[3],
    } if query_succ_and_inter else {}
else:
    dca_kwargs = {}
I think we should try hard to see if there is a cleaner way of passing these; maybe they can be bundled into a single q tensor that gets reinterpreted into its components via a combination of slicing and .view calls in the attn impl?
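A sketch of that idea (shapes and names are illustrative): pack the four query variants into one tensor at the layer boundary and recover the components with a view/unbind inside the impl.

import torch

num_tokens, num_heads, head_dim = 16, 8, 128

# the four query variants currently passed via dca_kwargs
query_succ = torch.randn(num_tokens, num_heads, head_dim)
query_inter = torch.randn(num_tokens, num_heads, head_dim)
query_succ_critical = torch.randn(num_tokens, num_heads, head_dim)
query_inter_critical = torch.randn(num_tokens, num_heads, head_dim)

# pack along the last dim so a single tensor crosses the Attention interface
packed = torch.cat(
    [query_succ, query_inter, query_succ_critical, query_inter_critical],
    dim=-1)

# inside the attn impl: reinterpret the packed tensor back into its components
succ, inter, succ_crit, inter_crit = packed.view(
    num_tokens, num_heads, 4, head_dim).unbind(dim=2)
assert torch.equal(succ, query_succ)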
I will give it a try and see if it can be simplified.
So if I understand correctly, Qwen2.5-1M now actually uses the correct attention mechanism, so VRAM usage should be lower and prompt processing faster, right?
I tested Qwen/Qwen2.5-7B-Instruct-1M using the DualChunkFlashAttention backend.
ubuntu-vllm-openai-1 | INFO 05-31 19:13:07 [logger.py:42] Received request cmpl-77d91882816c4f748e2023c93449f62d-0: prompt: 'Once upon a time', params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.05, temperature=0.0, top_p=1.0, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=1000, min_tokens=0, logprobs=1, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None, extra_args=None), prompt_token_ids: [12522, 5193, 264, 882], prompt_embeds shape: None, lora_request: None, prompt_adapter_request: None.
Exact same issue as above.
PR #19084 fixes this issue. When working with contexts of 70k, with the model loaded plus the context it uses something like 30 GB of VRAM, but during inference it goes up to 35-37 GB and then back down to 30 GB. I'm guessing that's expected, but is there some kind of way to preallocate this memory? Because if you let vLLM allocate 80% of the VRAM and it then tries to "eat" more, it will obviously OOM. Edit:
The qk estimate softmax has high memory overhead: https://github.com/vllm-project/vllm/blob/main/vllm/attention/backends/dual_chunk_flash_attn.py#L834. During start-up profiling, DCA specifically routes to flash-attention instead of the DCA sparse prefill function. In principle, there's no reason to use flash-attention during profiling from what I can see, so having that branch call the sparse attention branch instead should at least surface the OOM during profiling.
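Until the profiling path accounts for that transient allocation, one blunt workaround (a sketch, not a fix; it only leaves headroom for the spike reported above) is to lower gpu_memory_utilization when creating the engine:

from vllm import LLM

# leave spare VRAM for the transient qk-estimate buffers instead of letting
# the KV cache claim it; 0.75 and max_model_len are arbitrary illustrative values
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct-1M",
    max_model_len=131072,
    gpu_memory_utilization=0.75,
)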
Not a "blog" but it can help people working with it, so far we got much better results with Qwen 2.5 7b 1m than with nemotron 4M from nvidia. However beside the issues states before with quantization and gpu splitting, we did not manage either to do batching/parallel processing |
Quantization support has been added in #19420. Could not test KV cache quantization because this attention mechanism is based on flash attention.
Thanks for reporting @ExtReMLapin @exceedzhang. Will investigate this week.
It’s already fixed and a PR has been merged.
@sighingnow @exceedzhang thank you for your contributions; it's mostly these PRs that need a review: #19084 (priority, crash fix) and #19420 (FP8 quantization support).
Thank you for your development work; I've tested it, and the feature functions correctly. However, I've noticed a performance drop after enabling FP8 quantization. Here are the performance test results using four RTX 4090 24GB GPUs.
I agree with you; we expect better performance with FP8 because of the lower memory bottleneck. I also have another update waiting under the hood on this branch which should improve performance (packed torch operations). Considering the slowdowns... isn't that the fault of the flash attention implementation, given the very small changes I made?
@ExtReMLapin [image attachment]
Got it, not merging this performance branch into the FP8 branch then; it's not worth the risk of breaking something! Again, at the office we really appreciate the effort spent on releasing those models. We ran a lot of tests, including other models claiming to have long context support, but this is the ONLY model that actually follows instructions on very long contexts and can be run easily (without insane resources). Looking forward to seeing more models like this in the future!
Well, I'm not sure exactly what happened, but reading my PR code again and again, it should only affect KV cache quantization and not model quantization. Now, checking again without my PR, quantization seems to work without my changes, which makes this comment sound like I'm insane: #19084 (comment)
Ran more tests:
This PR implements dual-chunk flash attention, a training-free method to extend model context length (see also #6139), with sparse attention support (https://github.com/microsoft/MInference).
This PR requires the sparse attention kernel from vllm-flash-attention. Qwen models with 1M context length support will be open-sourced in the next one or two weeks, and unit tests will be added later.
FIX #12452
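For readers new to the method: dual chunk attention splits the key sequence into fixed-length chunks and decomposes causal attention into intra-chunk, successive-chunk, and inter-chunk components (which is why the backend carries separate query_succ/query_inter variants). A toy illustration of that partition, not the actual kernel:

import torch

seq_len, chunk_len = 12, 4  # toy sizes
pos = torch.arange(seq_len)
chunk_id = pos // chunk_len

q_chunk, k_chunk = chunk_id[:, None], chunk_id[None, :]
causal = pos[None, :] <= pos[:, None]

intra = causal & (k_chunk == q_chunk)      # keys in the query's own chunk
succ = causal & (k_chunk == q_chunk - 1)   # keys in the immediately preceding chunk
inter = causal & (k_chunk < q_chunk - 1)   # keys in all earlier chunks

# the three components tile the causal mask exactly, with no overlap
assert torch.equal(intra | succ | inter, causal)
assert not (intra & succ).any() and not (succ & inter).any()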