[Cherry-Pick][CI] Fix attn_mask_offset for multi-step MTP in mixed and PD-split modes (#5738) #5792
…ed and PD-split modes (PaddlePaddle#5738)
* fix attn_mask_offset in mtp with multi-step and pd-split-mode
* fix xpu operator register
* update pmtp multi-step mtp strategy in d-split -mode
* add note
* fix xpu register
Thanks for your contribution!
Pull request overview
This is a cherry-pick of PR #5738 that fixes the attn_mask_offset calculation for multi-step MTP (Multi-Token Prediction) speculative decoding in mixed and PD-split (Prefill-Decode split) modes. The fix introduces two new parameters, mask_rollback and recompute_token_num, to track and manage attention mask offsets when tokens are recomputed during speculative decoding.
Key Changes:
- Adds `mask_rollback` and `recompute_token_num` tracking for proper attention mask offset calculation in multi-step scenarios
- Updates CUDA kernels and Python implementation to handle dynamic recomputation strategies, particularly for PD-split mode
- Improves test structure by adding `unittest.TestCase` inheritance and simplifying assertions
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu | Adds mask_rollback and recompute_token_num parameters to CUDA kernel, implements logic to track recomputed tokens and update mask offsets; includes code formatting improvements |
| custom_ops/gpu_ops/cpp_extensions.cc | Updates function signature to include new mask_rollback and recompute_token_num parameters |
| fastdeploy/spec_decode/mtp.py | Initializes new model input tensors for mask rollback and recompute tracking; sets special values for PD-split mode where prefill-decode coordination requires one-step rollback |
| tests/operators/test_draft_model_preprocess.py | Adds new parameters to test functions, improves test class with unittest.TestCase inheritance, and simplifies assertion logic to loop through all results |
    - np.testing.assert_allclose(results1[12], results2[12])  # accept_num
    - np.testing.assert_allclose(results1[7], results2[7])  # not_need_stop
    + for i in range(12):
    +     np.testing.assert_equal(results1[i].numpy(), results2[i].numpy())
Using np.testing.assert_equal for all 12 results may be too strict. The original code used assert_allclose for most comparisons, which allows for small floating-point differences. If any of the results contain floating-point values (like logits or probabilities), this change could cause tests to fail due to numerical precision issues. Consider using assert_allclose for floating-point tensors and assert_equal for integer/boolean tensors.
    - np.testing.assert_equal(results1[i].numpy(), results2[i].numpy())
    + arr1 = results1[i].numpy()
    + arr2 = results2[i].numpy()
    + if np.issubdtype(arr1.dtype, np.floating):
    +     np.testing.assert_allclose(arr1, arr2)
    + else:
    +     np.testing.assert_equal(arr1, arr2)
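The dtype-based suggestion above could be factored into a small helper. The following is only a minimal sketch: `assert_results_match` is a hypothetical name that does not appear in the PR, and it assumes the results have already been converted to NumPy arrays (the actual test would call `.numpy()` on each paddle tensor first).

```python
import numpy as np


def assert_results_match(results1, results2):
    """Compare two result sequences, choosing the assertion by dtype."""
    assert len(results1) == len(results2)
    for arr1, arr2 in zip(results1, results2):
        arr1 = np.asarray(arr1)
        arr2 = np.asarray(arr2)
        if np.issubdtype(arr1.dtype, np.floating):
            # Floating-point outputs: tolerate small numerical differences.
            np.testing.assert_allclose(arr1, arr2)
        else:
            # Integer and boolean outputs must match exactly.
            np.testing.assert_equal(arr1, arr2)


# Example: an exact integer match plus a float array with a tiny relative error.
ints = np.array([1, 2, 3], dtype=np.int32)
floats = np.array([0.1, 0.2], dtype=np.float64)
assert_results_match([ints, floats], [ints.copy(), floats * (1 + 1e-9)])
```

If every output of the op is an integer or boolean tensor, the helper reduces to `assert_equal` everywhere, so it would not loosen the check introduced by the PR.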
    - class TestDraftModelPreprocess:
    + class TestDraftModelPreprocess(unittest.TestCase):
Adding unittest.TestCase as the base class is a good improvement. It ensures proper test discovery and integration with the unittest framework. However, this change doesn't appear to be related to the PR's main purpose (fixing attn_mask_offset). Consider mentioning this improvement in the PR description or making it a separate cleanup commit.
    mask_rollback = paddle.zeros([bsz], dtype="int32")
    recompute_token_num = paddle.zeros([bsz], dtype="int32")
The test initializes recompute_token_num to zeros, but in the actual implementation (mtp.py), it's initialized to self.num_model_steps - 1. This discrepancy means the test isn't fully exercising the behavior when recompute_token_num has a non-zero initial value. Consider testing both scenarios: initial value of 0 (for PD-split mode) and initial value of num_model_step - 1 (for normal multi-step MTP).
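One way to cover both scenarios is to loop over the initial values with `subTest`. The sketch below is hypothetical and uses NumPy arrays as stand-ins for the paddle tensors; `make_rollback_inputs` and `TestRecomputeInit` are illustrative names, and the point where the reference implementation and the custom op would be invoked is left as a comment because their call signatures are not shown here.

```python
import unittest

import numpy as np


def make_rollback_inputs(bsz, recompute_init):
    """Build the two new per-request tensors (NumPy stand-ins for paddle tensors)."""
    mask_rollback = np.zeros([bsz], dtype="int32")
    recompute_token_num = np.full([bsz], recompute_init, dtype="int32")
    return mask_rollback, recompute_token_num


class TestRecomputeInit(unittest.TestCase):
    def test_initial_values(self):
        bsz, num_model_step = 4, 3
        # 0 mirrors the current test; num_model_step - 1 mirrors the value
        # that the review comment attributes to mtp.py.
        for recompute_init in (0, num_model_step - 1):
            with self.subTest(recompute_init=recompute_init):
                mask_rollback, recompute_token_num = make_rollback_inputs(
                    bsz, recompute_init
                )
                # In the real test, both tensors would be passed to the
                # reference implementation and to the custom op here, and
                # the two sets of outputs would then be compared.
                self.assertEqual(recompute_token_num.shape, (bsz,))
                self.assertTrue((recompute_token_num == recompute_init).all())
                self.assertTrue((mask_rollback == 0).all())
```

Using `subTest` keeps both cases in one test method while still reporting which initial value failed.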
    process_splitwise_prefill<BlockSize,
                              TRUNCATE_FIRST_TOKEN,
                              KVCACHE_SCHEDULER_V1>
        <<<1, BlockSize, 0, stream>>>(draft_tokens,
                                      input_ids,
                                      stop_flags,
                                      seq_lens_this_time,
                                      seq_lens_encoder,
                                      seq_lens_decoder,
                                      step_idx,
                                      not_need_stop,
                                      is_block_step,
                                      batch_drop,
                                      pre_ids,
                                      accept_tokens,
                                      accept_num,
                                      base_model_seq_lens_this_time,
                                      base_model_seq_lens_encoder,
                                      base_model_seq_lens_decoder,
                                      base_model_step_idx,
                                      base_model_stop_flags,
                                      base_model_is_block_step,
                                      base_model_draft_tokens,
                                      bsz,
                                      num_model_step,
                                      accept_tokens_len,
                                      draft_tokens_len,
                                      input_ids_len,
                                      base_model_draft_tokens_len,
                                      pre_ids_len);
The process_splitwise_prefill kernel call is missing the mask_rollback and recompute_token_num parameters that were added to the DispatchRunner function signature. These parameters are passed to draft_model_preprocess_kernel (lines 348-349) but not to process_splitwise_prefill.
If these parameters are not needed for the splitwise_prefill path, they should still be passed to maintain API consistency and avoid potential issues when the logic changes. If they are needed, this is a critical bug that could cause incorrect behavior in splitwise_prefill mode.
    # NOTE(liuzichang):
    # extra 1 : P-D split need rollback one step
    self.model_inputs["mask_rollback"][idx : idx + 1] = 1
The comment mentions "extra 1: P-D split need rollback one step", but its placement is misleading. Looking at line 529, mask_rollback is set to 1, which represents the extra rollback step. The comment should be placed above line 529 where mask_rollback is set, not above line 528 where it appears now. This would make it clearer which line the comment describes.
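For illustration, the initialization and the PD-split override could look roughly like the sketch below, with the note sitting directly above the assignment it describes. This is not the code from mtp.py: it uses NumPy arrays in place of paddle tensors, the function names and the 1-D shapes are assumptions, and resetting `recompute_token_num` to 0 for a PD-split request is inferred from the review comments rather than confirmed by the source.

```python
import numpy as np


def init_rollback_inputs(max_num_seqs, num_model_steps):
    """Create the two tracking tensors with their default (mixed-mode) values."""
    return {
        "mask_rollback": np.zeros([max_num_seqs], dtype="int32"),
        # Default for normal multi-step MTP, per the review comment.
        "recompute_token_num": np.full(
            [max_num_seqs], num_model_steps - 1, dtype="int32"
        ),
    }


def mark_pd_split_request(model_inputs, idx):
    """Apply the PD-split special values to the request in slot idx."""
    # extra 1: P-D split needs to roll back one step
    model_inputs["mask_rollback"][idx : idx + 1] = 1
    # Assumption: no tokens are recomputed for this request in PD-split mode.
    model_inputs["recompute_token_num"][idx : idx + 1] = 0


inputs = init_rollback_inputs(max_num_seqs=4, num_model_steps=3)
mark_pd_split_request(inputs, idx=1)
```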
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## release/online/20251131 #5792 +/- ##
==========================================================
Coverage ? 59.00%
==========================================================
Files ? 319
Lines ? 39108
Branches ? 5893
==========================================================
Hits ? 23074
Misses ? 14179
Partials ? 1855
Merged commit 19a625a into PaddlePaddle:release/online/20251131