Conversation

@vadiklyutiy vadiklyutiy commented Oct 8, 2025

Purpose

Qwen3-next MTP suffered from device<->host (d2h) memory transfers during the prefill phase, which made MTP slower than STP (at least for some inputs).

This PR fixes the following:

  1. Remove boolean mask indexing from GDN and replace it with direct indexing. (tensor[bool_mask] causes a d2h transfer of the number of True elements, which is needed to allocate the resulting tensor.) A before/after sketch follows this list.
  2. Increase the cache size in @tensor_cache. There are 2 calls for every GDN attention and they are grouped by 3, so we need at least 2*3 cache entries (see the cache-size analogy below).
  3. Call .contiguous() early on q, k, v in GDN. This does not help right now, but it lets torch.compile produce better fusion.
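
A minimal standalone sketch of the boolean-indexing issue and the index_select/index_copy replacement (illustrative names and shapes only, not the actual GDN code; in the real code the True-count is already known on the host from attention metadata, whereas the sketch syncs once just to set up the example):

```
import torch

# Illustrative only: names and shapes are made up for this sketch.
hidden = torch.randn(1024, 256, device="cuda")
mask = torch.zeros(1024, dtype=torch.bool, device="cuda")
mask[::3] = True                      # pretend these rows belong to spec tokens
num_true = int(mask.count_nonzero())  # known on the host in the real code

# Before: boolean indexing. The output size depends on how many mask elements
# are True, so PyTorch must copy that count device->host before it can allocate
# the result -- a blocking d2h sync on every call.
selected_slow = hidden[mask]

# After: direct indexing. A stable argsort puts the True rows first while
# preserving their original order; slicing to the host-known count gives
# integer indices usable with index_select / index_copy_, with no d2h sync.
order = torch.argsort(mask.to(torch.int8), descending=True, stable=True)
idx = order[:num_true]
selected_fast = hidden.index_select(0, idx)

# Scattering results back also works without the boolean mask.
hidden.index_copy_(0, idx, selected_fast)

assert torch.equal(selected_slow, selected_fast)
```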

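For the @tensor_cache sizing in item 2, the decorator itself is a vLLM utility; as a loose analogy only (plain functools.lru_cache, not the vLLM implementation), the sketch below shows why a cache smaller than the six-entry working set thrashes while a size of at least 2*3 = 6 keeps hitting:

```
from functools import lru_cache

calls = {"n": 0}

def make_cached(maxsize):
    @lru_cache(maxsize=maxsize)
    def build(key):
        calls["n"] += 1      # count expensive (re)computations
        return key * 2       # stand-in for building a cached tensor
    return build

# 2 cached calls per GDN attention, grouped by 3 -> 6 distinct keys live at once.
keys = list(range(6))

for maxsize in (4, 6):
    calls["n"] = 0
    build = make_cached(maxsize)
    for _ in range(10):      # repeated steps reusing the same keys
        for k in keys:
            build(k)
    print(f"maxsize={maxsize}: {calls['n']} computations")
# maxsize=4 misses on every call (60 computations); maxsize=6 computes each key once (6).
```
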
Test Result

B200. MTP tested with 512 concurrency.

```
VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct -tp 4 --no-enable-prefix-caching --speculative-config '{"method": "qwen3_next_mtp", "num_speculative_tokens": 2}' --cuda-graph-sizes=2048
vllm bench serve --backend vllm --model Qwen/Qwen3-Next-80B-A3B-Instruct --endpoint /v1/completions --dataset-name random --random-input 1024 --random-output 1024 --max-concurrency 512 --num-prompt 512 --ignore-eos
```

Before

============ Serving Benchmark Result ============
Successful requests:                     512
Maximum request concurrency:             512
Benchmark duration (s):                  32.88
Total input tokens:                      524288
Total generated tokens:                  524288
Request throughput (req/s):              15.57
Output token throughput (tok/s):         15947.50
Peak output token throughput (tok/s):    9216.00
Peak concurrent requests:                512.00
Total Token throughput (tok/s):          31894.99
---------------Time to First Token----------------
Mean TTFT (ms):                          5594.84
Median TTFT (ms):                        5429.04
P99 TTFT (ms):                           11260.93
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          22.89
Median TPOT (ms):                        22.99
P99 TPOT (ms):                           29.52
---------------Inter-token Latency----------------
Mean ITL (ms):                           64.31
Median ITL (ms):                         56.32
P99 ITL (ms):                            165.80
==================================================

After

============ Serving Benchmark Result ============
Successful requests:                     512
Maximum request concurrency:             512
Benchmark duration (s):                  29.34
Total input tokens:                      524288
Total generated tokens:                  524288
Request throughput (req/s):              17.45
Output token throughput (tok/s):         17872.27
Peak output token throughput (tok/s):    9727.00
Peak concurrent requests:                512.00
Total Token throughput (tok/s):          35744.55
---------------Time to First Token----------------
Mean TTFT (ms):                          3953.48
Median TTFT (ms):                        3795.80
P99 TTFT (ms):                           7914.50
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          21.12
Median TPOT (ms):                        21.04
P99 TPOT (ms):                           26.59
---------------Inter-token Latency----------------
Mean ITL (ms):                           59.34
Median ITL (ms):                         56.31
P99 ITL (ms):                            117.85
==================================================

Speedup of output token throughput is 12% (17872 vs 15948 tok/s).

For comparison, STP mode (without speculative decoding):

```
VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct -tp 4 --no-enable-prefix-caching --cuda-graph-sizes=2048
vllm bench serve --backend vllm --model Qwen/Qwen3-Next-80B-A3B-Instruct --endpoint /v1/completions --dataset-name random --random-input 1024 --random-output 1024 --max-concurrency 512 --num-prompt 512 --ignore-eos
```

============ Serving Benchmark Result ============
Successful requests:                     512
Maximum request concurrency:             512
Benchmark duration (s):                  37.51
Total input tokens:                      524288
Total generated tokens:                  524288
Request throughput (req/s):              13.65
Output token throughput (tok/s):         13976.35
Peak output token throughput (tok/s):    17920.00
Peak concurrent requests:                512.00
Total Token throughput (tok/s):          27952.70
---------------Time to First Token----------------
Mean TTFT (ms):                          3698.59
Median TTFT (ms):                        3467.95
P99 TTFT (ms):                           7608.89
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          32.47
Median TPOT (ms):                        32.75
P99 TPOT (ms):                           34.83
---------------Inter-token Latency----------------
Mean ITL (ms):                           32.47
Median ITL (ms):                         29.54
P99 ITL (ms):                            131.08
==================================================

@mergify mergify bot added the qwen (Related to Qwen models) and v1 labels Oct 8, 2025

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces several performance optimizations for Qwen3-next MTP, primarily by replacing boolean mask indexing with direct indexing to avoid device-to-host synchronization. The changes are well-motivated and the logic to generate indices using torch.argsort is clever. I have one critical suggestion to ensure the correctness of this logic across different environments.

@vadiklyutiy vadiklyutiy force-pushed the vadim/remove-bool-indexing branch from 80ecbbf to afe63a7 Compare October 8, 2025 23:13
@xinli-sw

Hi @benchislett @LucasWilkinson could you assign reviewers for this change? Thanks!

@youkaichao

cc @sighingnow should be back

@mergify

mergify bot commented Oct 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vadiklyutiy.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 14, 2025
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@vadiklyutiy vadiklyutiy force-pushed the vadim/remove-bool-indexing branch from afe63a7 to 099d260 Compare October 14, 2025 11:55
@mergify mergify bot removed the needs-rebase label Oct 14, 2025

@sighingnow sighingnow left a comment


LGTM. As spec_token_masks is not used anymore, it could be deleted.

@youkaichao youkaichao changed the title [PERF] Qwen3-next MTP speedup [PERF] Qwen3-next MTP speedup (change bool mask indexing to index_select / index_copy to reduce d2h) Oct 15, 2025
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@youkaichao youkaichao enabled auto-merge (squash) October 16, 2025 01:48
@github-actions github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Oct 16, 2025
@youkaichao youkaichao disabled auto-merge October 16, 2025 04:18
@youkaichao youkaichao merged commit 785d8b6 into vllm-project:main Oct 16, 2025
55 of 57 checks passed
mandy-li pushed a commit to mandy-li/vllm that referenced this pull request Oct 16, 2025
[PERF] Qwen3-next MTP speedup (change bool mask indexing to index_select / index_copy to reduce d2h) (vllm-project#26437)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@vadiklyutiy vadiklyutiy deleted the vadim/remove-bool-indexing branch October 16, 2025 10:35
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 16, 2025
[PERF] Qwen3-next MTP speedup (change bool mask indexing to index_select / index_copy to reduce d2h) (vllm-project#26437)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
[PERF] Qwen3-next MTP speedup (change bool mask indexing to index_select / index_copy to reduce d2h) (vllm-project#26437)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
[PERF] Qwen3-next MTP speedup (change bool mask indexing to index_select / index_copy to reduce d2h) (vllm-project#26437)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
[PERF] Qwen3-next MTP speedup (change bool mask indexing to index_select / index_copy to reduce d2h) (vllm-project#26437)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
MengqingCao pushed a commit to vllm-project/vllm-ascend that referenced this pull request Oct 25, 2025
### What this PR does / why we need it?
Fix Qwen3NextGatedDeltaNet, broken by
vllm-project/vllm#26437

### How was this patch tested?
```
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "窗前明月光,",
        "The president of the United States is Mr.",
        "The capital of France is",
        "The future of AI is",
        "感时花溅泪,",
        "家书抵万金啥意思?",
        "plz tell me a story: ",
    ]

    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95)

    # Create an LLM.
    llm = LLM(
        model="/root/.cache/modelscope/hub/models/Qwen/Qwen3-Next-80B-A3B-Instruct",
        tensor_parallel_size=4,
        enforce_eager=True,
        trust_remote_code=True,
        max_model_len=256,
        gpu_memory_utilization=0.7,
        block_size=64,
    )

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Icey <1790571317@qq.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
[PERF] Qwen3-next MTP speedup (change bool mask indexing to index_select / index_copy to reduce d2h) (vllm-project#26437)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>