[Model] Implement DualChunkAttention for Qwen2 Models #6139
Conversation
Can you discuss why this is the case? If possible, I would really appreciate it if we could get the first iteration working with CUDA graph. @WoosukKwon: Also, IMO the best place to put this is actually in `FlashAttentionBackend`, accepting extra arguments for the chunked config, but I would like to hear your thoughts.
@simon-mo The brute-force implementation of DCA (https://github.com/HKUNLP/ChunkLlama/blob/main/chunkllama_attn_replace.py#L169-L199) splits the entire sequence into several chunks and invokes the flash attention kernel `3*chunk_num` times.
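For readers unfamiliar with the scheme, here is a minimal, illustrative sketch (not code from this PR or from ChunkLlama) of how partial attention outputs computed over disjoint key ranges (the intra-, succ-, and inter-chunk calls) can be merged using their softmax log-sum-exp (LSE) values; the function name and shapes are assumptions for the example:

```python
import torch


def merge_partial_attn(outs, lses):
    """Merge partial attention outputs computed over disjoint key ranges.

    outs: list of [num_tokens, num_heads, head_dim] partial outputs
    lses: list of [num_heads, num_tokens] per-range softmax log-sum-exp values
    """
    lse = torch.logsumexp(torch.stack(lses), dim=0)           # combined normalizer
    merged = torch.zeros_like(outs[0])
    for out, partial_lse in zip(outs, lses):
        weight = torch.exp(partial_lse - lse)                  # each range's share
        merged += out * weight.transpose(0, 1).unsqueeze(-1)   # broadcast over head_dim
    return merged
```

Each per-chunk attention call yields such an (output, LSE) pair, and merging them this way is mathematically equivalent to attending over the union of the key ranges.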
@WoosukKwon The primary issue here lies in the argument list of the attention `forward` interface.
Is this specific to the Qwen2 models? If not, I would strongly prefer that the logic for this feature be implemented inside of the attention backend rather than in each model.
Looking forward to the implementation of the CUDA kernel 😊
@robertgshaw2-neuralmagic No, it can be applied to almost all models using RoPE.
@nanmi I'm actually seeking help from the community to implement the CUDA kernel. I'm not an expert. 😅
For instance, we have the following code:

```python
if dual_chunk_attention_config is not None:
    self.attn = DualChunkAttention(
        self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        dual_chunk_attention_config=dual_chunk_attention_config)
else:
    self.attn = Attention(self.num_heads,
                          self.head_dim,
                          self.scaling,
                          num_kv_heads=self.num_kv_heads,
                          cache_config=cache_config,
                          quant_config=quant_config)
```

And:

```python
if self.dual_chunk_attention_config is None:
    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
else:
    # DCA: the rotary embedding returns three queries (intra/succ/inter)
    q, q_succ, q_inter, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, q_succ, q_inter, k, v, kv_cache, attn_metadata)
```

If we could push this branching into the `Attention` layer or the attention backend, the models themselves would not need any special-casing.
@robertgshaw2-neuralmagic DCA requires three queries (`q`, `q_succ`, and `q_inter`), so it does not fit the current single-query attention interface.
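To make the interface question concrete, here is a hypothetical sketch of the kind of change being discussed; none of these names come from the PR, and `forward_dual_chunk` in particular is invented for illustration:

```python
from typing import Optional

import torch


class Attention(torch.nn.Module):
    """Hypothetical variant that accepts the two extra DCA queries as optional args."""

    def __init__(self, impl):
        super().__init__()
        self.impl = impl  # an attention backend implementation (placeholder)

    def forward(self, q, k, v, kv_cache, attn_metadata,
                q_succ: Optional[torch.Tensor] = None,
                q_inter: Optional[torch.Tensor] = None) -> torch.Tensor:
        if q_succ is None and q_inter is None:
            # Standard single-query path, unchanged for existing models.
            return self.impl.forward(q, k, v, kv_cache, attn_metadata)
        # DCA path: the backend receives all three query tensors.
        return self.impl.forward_dual_chunk(q, q_succ, q_inter, k, v,
                                            kv_cache, attn_metadata)
```

With something along these lines, the per-model branching shown earlier could collapse into a single `self.attn(...)` call.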
Can you elaborate on your ideas? I implemented DCA based on the Python code logic and in CUDA C++, but I think it can be made more elegant and efficient. I am mainly worried about how to optimize the per-chunk processing and the final merging.
@nanmi The functions … Taking …

I'll follow your idea and think about how to use CUDA to implement it.
@nanmi Really appreciate your help! If you encounter any questions, feel free to ask here.
```python
slot_mapping_tensor = torch.tensor(slot_mapping,
                                   dtype=torch.long,
                                   device=self.device)
prefill_original_seq_lens_tensor = torch.tensor(
```
Is it possible to push the creation of this tensor into `attn_backend.make_metadata`?
```python
    self.kv_cache_dtype,
    self.block_size,
) if num_attn_heads else None
if getattr(self.model_config.hf_config, "dual_chunk_attention_config",
```
We should add `dual_chunk_attention` to `ModelConfig` rather than accessing the `hf_config` from here. We ran into problems with accessing the `hf_config` for sliding window in the past.
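A minimal sketch of what that could look like (illustrative only; the property name and the stripped-down class are assumptions for this example, not vLLM's actual `ModelConfig`):

```python
from typing import Any, Optional


class ModelConfig:
    """Stripped-down stand-in for vLLM's ModelConfig, for illustration only."""

    def __init__(self, hf_config: Any):
        self.hf_config = hf_config

    @property
    def dual_chunk_attention_config(self) -> Optional[dict]:
        # Centralize the hf_config lookup so callers (e.g. the model runner)
        # never reach into hf_config directly.
        return getattr(self.hf_config, "dual_chunk_attention_config", None)
```

Call sites would then check `self.model_config.dual_chunk_attention_config`, similar in spirit to how sliding-window information is exposed through `ModelConfig` rather than read from `hf_config` at each call site.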
```python
) if num_attn_heads else None
if getattr(self.model_config.hf_config, "dual_chunk_attention_config",
           None):
    self.attn_backend = get_dual_chunk_attn_backend(
```
Is there a reason this has to be a different function, versus passing a `self.model_config.get_dual_chunk_attention` to `get_attn_backend`?
Apologies for the delay in reviewing here. @simon-mo, @WoosukKwon and I discussed, and while we do not love implementing features in specific models, I think that we do not have the right abstractions in place yet to enable this without simultaneously modifying … I left a few comments on … I think that this PR also needs some tests.
```python
seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len
max_seq_len_inter = seq_lens_inter.max().item()
if max_seq_len_inter:
    inter_output, succ_softmax_lse = (
```
Should be `inter_softmax_lse` instead of `succ_softmax_lse`.
Thanks for pointing out the typo. Will revise in the next commit.
```python
chunk_len = chunk_size - local_size
if chunk_len % block_size != 0:
    raise ValueError("chunk_len must be divisible by block_size.")
chunk_num_curr = (cache_seqlens - 1) // chunk_len
```
This line makes the actual length of the intra-chunk smaller than `chunk_len` in the variable-length situation. It differs from the algorithm in the prefill stage, which uses a fixed length of `chunk_len`. Could you explain the reason for this difference?
@rejoicesyc The implementation is consistent with the original code (https://github.com/HKUNLP/ChunkLlama/blob/8a28f1464b2a5def03eb07e8d91f5bf4d00f667d/chunkllama_attn_replace.py#L167). See L167, L171, L173, L197.
Since the sequence length may not be divisible by `chunk_len`, there has to be a chunk with a smaller length. In the DCA algorithm, the last chunk (i.e., the intra chunk) may have a length <= `chunk_len`, while the previous chunks have a fixed length of `chunk_len`.
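To make this concrete, a small worked example of that line (the numbers are made up for illustration; `chunk_len = 16` is assumed):

```python
import torch

chunk_len = 16
cache_seqlens = torch.tensor([16, 17, 33, 40])

# Index of the chunk that each sequence's last token falls into.
chunk_num_curr = (cache_seqlens - 1) // chunk_len        # tensor([0, 1, 2, 2])

# Length of the (possibly shorter) intra chunk for each sequence.
intra_len = cache_seqlens - chunk_num_curr * chunk_len   # tensor([16, 1, 1, 8])
```

Every earlier chunk keeps exactly `chunk_len` tokens; only the final (intra) chunk can be shorter, which is the behavior described above.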
```python
    1.0).clip(min=1)
query = (query * mscale.view(-1, 1, 1, 1)).to(
    query.dtype
)  # possible for numerical issue, need to fused in the kernel
```
hi @hzhwcmhf, could you explain this numerical issue and why kernel fusion can solve it?
@rejoicesyc We want to compute `(query @ key) * softmax_scale * mscale` as the attention weight, where `query` is represented in float16 or bfloat16, and `mscale` is represented in float32. Better precision can be achieved if the scaling is computed in float32, rather than multiplying `query` by `mscale` before the attention operation. The flash attention kernel provides an argument `softmax_scale`, which is close to our requirement. However, `softmax_scale` is a constant for the entire query, whereas `mscale` is a vector that specifies different `softmax_scale` values for each query vector at different positions.
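As a small, self-contained illustration of the precision point (toy shapes only; this is not the PR's kernel code, and the variable names are just for the example):

```python
import torch

torch.manual_seed(0)
num_tokens, head_dim = 4, 64
q = torch.randn(num_tokens, head_dim, dtype=torch.bfloat16)
k = torch.randn(num_tokens, head_dim, dtype=torch.bfloat16)
softmax_scale = head_dim ** -0.5
mscale = torch.linspace(1.0, 1.3, num_tokens)  # float32, one value per query position

# Current workaround: fold mscale into the bfloat16 query before the matmul.
scores_workaround = ((q * mscale.view(-1, 1)).to(q.dtype) @ k.T).float() * softmax_scale

# What a fused kernel could do: apply both scales in float32 on the scores.
scores_reference = (q.float() @ k.float().T) * softmax_scale * mscale.view(-1, 1)

print((scores_workaround - scores_reference).abs().max())  # small but nonzero rounding error
```

Because the flash attention `softmax_scale` argument is a single scalar, the per-position `mscale` cannot simply be folded into it, hence the suggestion to fuse the scaling into the kernel.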
This pull request has merge conflicts that must be resolved before it can be merged.
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!
Overview
Dual Chunk Attention is a training-free method to extend model context length. It splits the entire sequence into chunks and uses three distinct query vectors to capture relative information within the same chunk, between successive chunks, and between distant chunks. The original implementation (using Hugging Face's Transformers) can be found here.
Qwen2 models now integrate this method, supporting 1M context length processing on 8x H100-80G GPUs. Qwen2-72B-Instruct achieves ~75% accuracy in the Needle in A Haystack test with 1M tokens input.
Features:
- Brute-force implementation based on the flash attention kernel (which is invoked `3*chunk_num` times).

Limitations:
- `enforce_eager` must be set to `True` due to the dynamic graph created by the brute-force implementation.

Changes
- Add `DualChunkRotaryEmbedding` in `vllm/model_executor/layers/rotary_embedding.py`. The function `DualChunkRotaryEmbedding.forward` returns `query, query_succ, query_inter` instead of a single `query`. These three query vectors are used for computing intra-/succ-/inter-attention in Dual Chunk Attention.
- Add `DualChunkAttentionBackend` in `vllm/attention/backends/abstract.py` and implement `DualChunkFlashAttentionBackend` in `vllm/attention/backends/dual_chunk_flash_attn.py`. Note that we add an extra variable `prefill_original_seq_lens_tensor` in `DualChunkFlashAttentionMetadata`, which stores the whole prefill sequences' lengths. To obtain the value, we insert a few lines in `vllm/model_runner.py`.
- Add `DualChunkAttention` in `vllm/attention/layer.py`, which simply calls `DualChunkAttentionBackend`.
- Support `dual_chunk_attention_config` in `Qwen2Model` and `Qwen2MoeModel`.

How to use
Add the `dual_chunk_attention_config` to the model's `config.json` file following:

…

You will see outputs like:

…
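Since the concrete configuration block and launch command are not shown above, here is a hypothetical end-to-end sketch; all field names and values in `dual_chunk_attention_config` are placeholders rather than the PR's documented schema, so please check the PR diff for the real keys:

```python
import json

# Step 1 (hypothetical): add a dual_chunk_attention_config entry to the
# model's config.json. The field names and values below are placeholders.
config_path = "/path/to/Qwen2-72B-Instruct/config.json"
with open(config_path) as f:
    config = json.load(f)
config["dual_chunk_attention_config"] = {
    "chunk_size": 32768,   # placeholder
    "local_size": 1024,    # placeholder
}
with open(config_path, "w") as f:
    json.dump(config, f, indent=2)

# Step 2: load the model with vLLM. enforce_eager=True is required because the
# brute-force implementation builds a dynamic graph (see Limitations above).
from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/Qwen2-72B-Instruct",
          tensor_parallel_size=8,
          enforce_eager=True)
outputs = llm.generate(["<a very long prompt ...>"],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```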