[Speculative Decoding] Fix attn_mask_offset for multi-step MTP in mixed and PD-split modes#5738
Pull request overview
This PR fixes the attention mask offset calculation for multi-step MTP (Multi-Token Prediction) in mixed and PD-split (Prefill-Decode split) modes of speculative decoding. The fix addresses incorrect mask rollback behavior when the draft model operates in these specific configurations.
Key Changes:
- Modified mask_rollback calculation in PD-split mode to account for num_model_steps
- Added mask_rollback parameter propagation through the CUDA kernel pipeline
- Added debug logging for troubleshooting attn_mask_offsets and sequence lengths
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| fastdeploy/spec_decode/mtp.py | Updated mask_rollback calculation formula for PD-split mode, added mask_rollback to kernel inputs, and included debug logging |
| custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu | Added mask_rollback parameter to kernel signature and implemented mask_rollback accumulation logic in decode generation path; includes code formatting improvements |
| custom_ops/gpu_ops/cpp_extensions.cc | Added mask_rollback parameter to DraftModelPreprocess function signature for consistency |
Comments suppressed due to low confidence (1)
fastdeploy/spec_decode/mtp.py:883
- These debug logging statements should be removed before merging to production. Debug logs at the info level in performance-critical code paths can significantly impact performance, especially when they execute on every substep iteration.
# Initialize forward meta data
self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
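One common way to address the logging concern above is to emit such messages at DEBUG level and guard them so that no formatting work happens when debug output is disabled. The sketch below uses only Python's standard `logging` module; the logger name and the `log_substep_state` helper are hypothetical and are not part of FastDeploy.

```python
import logging

# Hypothetical logger name, used only for this illustration.
logger = logging.getLogger("mtp_debug_example")


def log_substep_state(substep, attn_mask_offsets, seq_lens):
    """Log per-substep state only when DEBUG logging is enabled.

    Returns True if a debug record was issued, False if it was skipped.
    """
    # Cheap check first: skip all work when DEBUG is off.
    if not logger.isEnabledFor(logging.DEBUG):
        return False
    # Lazy %-style arguments: the message is formatted only if a handler emits it.
    logger.debug(
        "substep=%d attn_mask_offsets=%s seq_lens=%s",
        substep,
        attn_mask_offsets,
        seq_lens,
    )
    return True
```

With this pattern the per-substep cost in production is a single level check, and the messages can still be enabled when troubleshooting.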
# NOTE(liuzichang):
# extra 1 : P-D split need rollback one step
# -(self.speculative_config.num_model_steps - 1) :
#   1. draft_model_preprocess will rollback (num_model_steps - 1) in each Step. But In P-D splitewise,
#   2. P only generate one token, so we need to minus it
self.model_inputs["mask_rollback"][idx : idx + 1] = 1 - (
    self.speculative_config.num_model_steps - 1
)
The calculation logic for mask_rollback is complex and would benefit from clearer documentation. The comment mentions two points but the explanation is somewhat unclear. Consider adding:
- A more detailed explanation of why we subtract (num_model_steps - 1)
- What the relationship is between P-D splitwise mode and num_model_steps
- An example calculation to illustrate the expected values
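As a rough illustration of the requested example calculation, the sketch below evaluates the quoted formula `1 - (num_model_steps - 1)` for a few step counts. The function names are hypothetical. The second helper rests on an assumption taken from the code comment, namely that `draft_model_preprocess` adds a rollback of `num_model_steps - 1` on the first decode step; it is not confirmed against the kernel source.

```python
def pd_split_initial_mask_rollback(num_model_steps: int) -> int:
    """Initial mask_rollback per the quoted formula: 1 - (num_model_steps - 1)."""
    if num_model_steps < 1:
        raise ValueError("num_model_steps must be >= 1")
    extra_pd_rollback = 1                      # "extra 1": P-D split rolls back one step
    per_step_rollback = num_model_steps - 1    # rollback the comment attributes to preprocess
    return extra_pd_rollback - per_step_rollback


def net_rollback_after_first_step(num_model_steps: int) -> int:
    """ASSUMPTION: preprocess adds (num_model_steps - 1) to mask_rollback once."""
    return pd_split_initial_mask_rollback(num_model_steps) + (num_model_steps - 1)


for steps in (1, 2, 3, 4):
    print(steps, pd_split_initial_mask_rollback(steps), net_rollback_after_first_step(steps))
```

Under that assumption the initial value cancels the per-step rollback, leaving a net rollback of 1 regardless of `num_model_steps`, which matches the comment's remark that the prefill node generates only one token.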
The PR description is incomplete. According to the PR template and custom guidelines, the description should include:
- Motivation: A clear explanation of what problem this PR solves
- Modifications: Detailed description of the changes made
- Usage or Command: How to test or use these changes
- Accuracy Tests: Results demonstrating the fix works correctly
Currently, all these sections are empty. Please provide:
- Why the attn_mask_offset calculation was incorrect
- What scenarios were failing before this fix
- Test results showing the issue is resolved
process_splitwise_prefill<BlockSize,
                          TRUNCATE_FIRST_TOKEN,
                          KVCACHE_SCHEDULER_V1>
    <<<1, BlockSize, 0, stream>>>(draft_tokens,
                                  input_ids,
                                  stop_flags,
                                  seq_lens_this_time,
                                  seq_lens_encoder,
                                  seq_lens_decoder,
                                  step_idx,
                                  not_need_stop,
                                  is_block_step,
                                  batch_drop,
                                  pre_ids,
                                  accept_tokens,
                                  accept_num,
                                  base_model_seq_lens_this_time,
                                  base_model_seq_lens_encoder,
                                  base_model_seq_lens_decoder,
                                  base_model_step_idx,
                                  base_model_stop_flags,
                                  base_model_is_block_step,
                                  base_model_draft_tokens,
                                  bsz,
                                  num_model_step,
                                  accept_tokens_len,
                                  draft_tokens_len,
                                  input_ids_len,
                                  base_model_draft_tokens_len,
                                  pre_ids_len);
The process_splitwise_prefill kernel doesn't receive or use the mask_rollback parameter, while the draft_model_preprocess_kernel does. This asymmetry could be problematic if mask_rollback adjustments are needed in splitwise_prefill mode. Please verify whether:
- The mask_rollback logic is intentionally not needed for splitwise_prefill scenarios
- If it is needed, the parameter should be added to this kernel as well
If it's intentional that splitwise_prefill doesn't need mask_rollback, consider adding a comment explaining why.
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## develop #5738 +/- ##
==========================================
Coverage ? 65.91%
==========================================
Files ? 330
Lines ? 41819
Branches ? 6406
==========================================
Hits ? 27567
Misses ? 12210
Partials ? 2042
…ed and PD-split modes (PaddlePaddle#5738) * fix attn_mask_offset in mtp with multi-step and pd-split-mode * fix xpu operater register * update pmtp multi-step mtp strategy in d-split -mode * add note * fix xpu register
…P in mixed and PD-split modes (PaddlePaddle#5738)" This reverts commit ba0d35a.
* [Speculative Decoding] Fix attn_mask_offset for multi-step MTP in mixed and PD-split modes (#5738) * fix attn_mask_offset in mtp with multi-step and pd-split-mode * fix xpu operater register * update pmtp multi-step mtp strategy in d-split -mode * add note * fix xpu register * fix entropy bugs * Revert "[Speculative Decoding] Fix attn_mask_offset for multi-step MTP in mixed and PD-split modes (#5738)" This reverts commit ba0d35a. * fix ut * fix --------- Co-authored-by: freeliuzc <lzc842650834@gmail.com>