[PERF] Qwen3-next MTP speedup (change bool mask indexing to index_select / index_copy to reduce d2h) #26437
Conversation
Code Review
This pull request introduces several performance optimizations for Qwen3-next MTP, primarily by replacing boolean mask indexing with direct indexing to avoid device-to-host synchronization. The changes are well-motivated and the logic to generate indices using torch.argsort is clever. I have one critical suggestion to ensure the correctness of this logic across different environments.
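For context on the argsort-based index generation mentioned above, here is a minimal sketch (illustrative names and values, not the PR's actual code) of how sorting a boolean mask yields selection indices on the device; `stable=True` is what keeps the selected positions in their original order:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative mask of "selected" token slots; the count is assumed to be
# known from CPU-side metadata, so mask.sum() never has to be read back.
mask = torch.tensor([0, 1, 0, 1, 1, 0], dtype=torch.bool, device=device)
num_true = 3

# A descending sort moves True positions to the front; stable=True preserves
# their original relative order (an unstable sort could permute them).
idx = torch.argsort(mask.to(torch.int8), descending=True, stable=True)[:num_true]
print(idx.tolist())  # [1, 3, 4]
```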
Hi @benchislett @LucasWilkinson, could you assign reviewers for this change? Thanks!
cc @sighingnow, should be back
This pull request has merge conflicts that must be resolved before it can be merged.
LGTM. Since `spec_token_masks` is not used anymore, it could be deleted.
### What this PR does / why we need it?

Fix Qwen3NextGatedDeltaNet, which was broken by vllm-project/vllm#26437.

### How was this patch tested?

```python
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "窗前明月光,",
        "The president of the United States is Mr.",
        "The capital of France is",
        "The future of AI is",
        "感时花溅泪,",
        "家书抵万金啥意思?",
        "plz tell me a story: ",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="/root/.cache/modelscope/hub/models/Qwen/Qwen3-Next-80B-A3B-Instruct",
        tensor_parallel_size=4,
        enforce_eager=True,
        trust_remote_code=True,
        max_model_len=256,
        gpu_memory_utilization=0.7,
        block_size=64,
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Icey <1790571317@qq.com>
Purpose
Qwen3-next MTP suffered from device-to-host (d2h) memory transfers during the prefill phase, which made MTP slower than STP mode (at least for some inputs).
This PR fixes the following:

- `tensor[bool_mask]` indexing causes a d2h transfer (the number of `True` elements is needed to allocate the resulting tensor). Boolean-mask indexing is replaced with `index_select` / `index_copy` (see the sketch after this list).
- `@tensor_cache`: there are 2 calls for every GDN attn, and they are grouped by 3, so we need at least 2*3 cache entries.
- `.contiguous()` is added for q, k, v in GDN. This does not help right now, but it helps `torch.compile` make better fusions.
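As an illustration of the boolean-mask item above, here is a minimal sketch (hypothetical shapes and names, not the PR's actual code) of swapping `tensor[bool_mask]` for device-side index generation plus `index_select` / `index_copy_`, assuming the number of selected rows is already known on the host so nothing has to be read back from the device:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

hidden = torch.randn(16, 8, device=device)              # per-token hidden states (illustrative)
mask = torch.zeros(16, dtype=torch.bool, device=device)
mask[:4] = True                                          # e.g. speculative-token positions
num_selected = 4                                         # assumed known from CPU-side metadata

# Boolean indexing: hidden[mask] must copy the count of True elements to the
# host to size the output tensor, forcing a d2h sync.
# selected = hidden[mask]

# Index-based alternative: build the indices on the device and slice by a
# host-known length, so no value is read back from the GPU.
indices = torch.argsort(mask.to(torch.int8), descending=True, stable=True)[:num_selected]
selected = hidden.index_select(0, indices)

# Scatter results back without boolean-mask assignment.
out = torch.zeros_like(hidden)
out.index_copy_(0, indices, selected)
```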
Test Result

B200. Tested MTP with 512 concurrency.
Before
After
The speedup in Output token throughput is 12%.
For comparison, STP mode: