
Conversation

@sighingnow (Collaborator) commented Sep 16, 2025

This PR fixes corner cases where the guided decoding backend rolls back draft tokens, causing unaligned verify batches.

Fixes #24730.
Fixes #24881.
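
Below is a minimal, illustrative sketch of the failure mode (hypothetical request names and counts, not vLLM's actual scheduler state): a rollback leaves the verify batch out of step with the fixed (num_spec + 1)-tokens-per-request layout.

```python
# Illustrative only: hypothetical numbers, not vLLM's actual scheduler state.
# With num_spec draft tokens, the verifier normally sees (num_spec + 1) tokens
# per decode request.
num_spec = 2
scheduled_tokens = {"req0": num_spec + 1, "req1": num_spec + 1, "req2": num_spec + 1}

# If the structured-output (guided decoding) backend rolls back a draft token
# for req1, that request now contributes fewer tokens than the others.
scheduled_tokens["req1"] -= 1

padded_tokens = len(scheduled_tokens) * (num_spec + 1)  # 9: what a uniform layout assumes
actual_tokens = sum(scheduled_tokens.values())          # 8: what the batch really contains
print(padded_tokens, actual_tokens)  # this mismatch is the "unaligned verify batch"
```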

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request aims to fix an issue with variable-length sequences in Qwen3-Next's multi-token prediction implementation, particularly for speculative decoding rollbacks. The changes span the causal convolution Triton kernel, the Qwen3-Next model file, and the GatedDeltaNet attention backend. While the core logic for handling varlen inputs looks correct, I've identified a critical issue in the attention backend related to CUDA graph batch size calculation and a high-severity issue in the Triton kernel concerning the use of constexpr for runtime variables, which could lead to severe performance degradation or compilation errors.

Comment on lines 76 to +78

critical

Changing self.decode_cudagraph_max_bs to be a token count (by multiplying with self.num_spec + 1) is incorrect, as this variable is used as a sequence count (batch size) for tensor allocations. For example, self.spec_state_indices_tensor is allocated with this as its first dimension (line 80), which is indexed by sequence, not by token. This change will lead to incorrect tensor allocations (either too large, wasting memory, or too small, causing out-of-bounds errors) and likely runtime failures.

To fix this correctly, decode_cudagraph_max_bs should remain a sequence count. A new variable should be introduced for the maximum token count if needed for the check at line 221.
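
A minimal sketch of the separation the reviewer suggests, using hypothetical attribute and method names rather than vLLM's actual builder code:

```python
import torch

class GDNAttentionMetadataBuilderSketch:
    """Hypothetical sketch: keep the sequence-count limit and the token-count
    limit as separate fields instead of overloading decode_cudagraph_max_bs."""

    def __init__(self, decode_cudagraph_max_bs: int, num_spec: int):
        # Sequence count: still used to size per-sequence buffers.
        self.decode_cudagraph_max_bs = decode_cudagraph_max_bs
        # Token count: used only for the "does this batch fit the captured graph?" check.
        self.decode_cudagraph_max_num_tokens = decode_cudagraph_max_bs * (num_spec + 1)
        # Indexed by sequence, so its first dimension stays a sequence count.
        self.spec_state_indices_tensor = torch.empty(
            (decode_cudagraph_max_bs, num_spec + 1), dtype=torch.int32
        )

    def fits_decode_cudagraph(self, num_decode_seqs: int, num_decode_tokens: int) -> bool:
        return (num_decode_seqs <= self.decode_cudagraph_max_bs
                and num_decode_tokens <= self.decode_cudagraph_max_num_tokens)
```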

Comment on lines 730 to 732

high

The kernel parameters state_len and seqlen are declared as tl.constexpr in the function signature (lines 635 and 634 respectively), but they are being reassigned here. constexpr values are meant to be compile-time constants and should not be modified at runtime. Passing runtime values to tl.constexpr parameters causes Triton to recompile the kernel for each unique value, which can lead to significant performance degradation and long compilation times. This reassignment is also confusing and can lead to unexpected behavior.

To fix this, you should change their type hints in the kernel signature to int. Additionally, for better code clarity and to avoid modifying input parameters, it's recommended to use new local variables for the updated values.
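
A hedged sketch of the suggested signature change on a simplified, hypothetical kernel (not the actual causal_conv1d kernel): runtime-varying lengths are passed as plain integers, a true block size stays constexpr, and adjusted lengths go into new locals instead of overwriting parameters.

```python
import triton
import triton.language as tl

@triton.jit
def _varlen_copy_kernel_sketch(
    x_ptr, out_ptr,
    seqlen,                 # runtime int: one compiled kernel serves all lengths
    state_len,              # runtime int
    BLOCK_N: tl.constexpr,  # genuine compile-time constant (needed by tl.arange)
):
    pid = tl.program_id(axis=0)
    # Derive the adjusted length into a new local instead of reassigning a parameter.
    total_len = seqlen + state_len
    offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offs < total_len
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, x, mask=mask)
```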

…mentation.

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
(cherry picked from commit 8b83d23259ac24ec1f3e5e012da0c997a90031d8)
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Sep 17, 2025
…mentation

Fixes CUDA illegal memory access errors during Qwen3-Next speculative decoding
by implementing proper varlen sequence handling and CUDA graph batch size fixes.

Key changes from upstream PR vllm-project#24957:
- Enhanced GDNAttentionMetadata with num_actual_tokens field
- Fixed CUDA graph batch size calculation for speculative decoding scenarios
- Added varlen sequence support to causal_conv1d operations
- Improved token accounting across MTP verification paths

Resolves issues with:
- Multi-token prediction verification with unaligned speculative tokens
- Variable-length sequence processing in continuous batching
- CUDA memory allocation errors in graph capture

Co-authored-by: upstream contributors from PR vllm-project#24957
Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com>
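
The num_actual_tokens idea from the summary above, as a hedged sketch (field and function names are assumptions, not the actual vLLM definitions): the attention metadata carries the real token count so consumers can trim padded, graph-captured buffers down to the valid rows.

```python
from dataclasses import dataclass
import torch

@dataclass
class GDNAttentionMetadataSketch:
    num_decodes: int
    num_spec: int
    num_actual_tokens: int         # tokens present after any draft rollback
    query_start_loc: torch.Tensor  # varlen offsets, shape (num_seqs + 1,)

def trim_to_actual(hidden_states: torch.Tensor, meta: GDNAttentionMetadataSketch) -> torch.Tensor:
    # Buffers sized for CUDA graph capture may be padded to num_decodes * (num_spec + 1);
    # only the first num_actual_tokens rows hold real data.
    return hidden_states[:meta.num_actual_tokens]
```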
@chaunceyjiang (Collaborator) commented:

100%|█████████████████████████████████████████████████████████████████████████████| 100/100 [01:11<00:00,  1.39it/s]
============ Serving Benchmark Result ============
Successful requests:                     100       
Maximum request concurrency:             10        
Benchmark duration (s):                  71.97     
Total input tokens:                      204357    
Total generated tokens:                  97082     
Request throughput (req/s):              1.39      
Output token throughput (tok/s):         1348.89   
Total Token throughput (tok/s):          4188.30   
---------------Time to First Token----------------
Mean TTFT (ms):                          540.71    
Median TTFT (ms):                        170.88    
P99 TTFT (ms):                           4819.26   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.88      
Median TPOT (ms):                        6.29      
P99 TPOT (ms):                           13.97     
---------------Inter-token Latency----------------
Mean ITL (ms):                           19.42     
Median ITL (ms):                         15.50     
P99 ITL (ms):                            127.42    
==================================================

Nice!!!! @sighingnow

@sighingnow (Collaborator, Author) commented:

> Nice!!!! @sighingnow

@chaunceyjiang Thanks for helping to verify, and for the feedback!

@youkaichao (Member) left a comment

thanks for the fix!

@sighingnow sighingnow enabled auto-merge (squash) September 17, 2025 09:52
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 17, 2025
@youkaichao youkaichao disabled auto-merge September 17, 2025 13:59
@youkaichao youkaichao merged commit dd6a910 into vllm-project:main Sep 17, 2025
58 of 62 checks passed
@david6666666 (Contributor) commented:

thanks for the fix!

FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…mentation. (vllm-project#24957)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…mentation. (vllm-project#24957)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Signed-off-by: charlifu <charlifu@amd.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…mentation. (vllm-project#24957)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
…mentation. (vllm-project#24957)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…mentation. (vllm-project#24957)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>

Labels

qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed), v1
