Skip to content

[Bugfix] Fix topk_ids indices_type for CUTLASS w8a8 FP8 MoE #20166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

minosfuture
Copy link

@minosfuture minosfuture commented Jun 27, 2025

Purpose

This PR fixes the following error when starting EP on Maverick:

(VllmWorker rank=3 pid=1737537) ERROR 06-15 22:58:28 [multiproc_executor.py:527]     run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
(VllmWorker rank=3 pid=1737537) ERROR 06-15 22:58:28 [multiproc_executor.py:527]   File "/home/yeq/gitrepos/vllm/vllm/model_executor/layers/fused_moe/cutlass_moe.py", line 89, in run_cutlass_moe_fp8
(VllmWorker rank=3 pid=1737537) ERROR 06-15 22:58:28 [multiproc_executor.py:527]     local_topk_ids = torch.where(expert_map[topk_ids] != -1,
(VllmWorker rank=3 pid=1737537) ERROR 06-15 22:58:28 [multiproc_executor.py:527]                                  ~~~~~~~~~~^^^^^^^^^^
(VllmWorker rank=3 pid=1737537) ERROR 06-15 22:58:28 [multiproc_executor.py:527] IndexError: tensors used as indices must be long, int, byte or bool tensors

In the PPLX implementation #18762, the dtype got flipped to uint32, here.

Besides this fix, the workspace_shapes needed another fix here from #19168, which is already merged; otherwise, the torch.zeros is slow for processing much larger size of data here.

Test Plan

  1. benchmark for latency sanity
# serve
vllm serve meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 \
        --max_model_len 8192 \
        --kv_cache_dtype fp8 \
        --enable-expert-parallel \
        --tensor-parallel-size 8 \
        --trust-remote-code \
        --enforce_eager \
        --gpu-memory-utilization 0.8 \
        --disable-log-requests 2>&1 | tee ep_`date +%Y%m%d_%H%M%S`.log
# benchmark serve
python benchmarks/benchmark_serving.py  --model meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 \
        --port 8000  --dataset-name random  --ignore-eos  --num-prompts 500   --max-concurrency 128 \
        --random-input-len 2000 --random-output-len 150
  1. lm_eval

Test Result

runtime exception during init is fixed. Attaching benchmark results:

============ Serving Benchmark Result ============
Successful requests:                     500
Benchmark duration (s):                  42.34
Total input tokens:                      998815
Total generated tokens:                  75000
Request throughput (req/s):              11.81
Output token throughput (tok/s):         1771.43
Total Token throughput (tok/s):          25362.46
---------------Time to First Token----------------
Mean TTFT (ms):                          1119.22
Median TTFT (ms):                        384.95
P99 TTFT (ms):                           5939.92
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          63.31
Median TPOT (ms):                        66.14
P99 TPOT (ms):                           67.69
---------------Inter-token Latency----------------
Mean ITL (ms):                           63.31
Median ITL (ms):                         33.70
P99 ITL (ms):                            198.63
==================================================

lm_eval results:

local-chat-completions (model=meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8,base_url=http://127.0.0.1:8081/v1/chat/completions,num_concurrent=32), gen_kwargs: (None), limit: 200.0, num_fewshot: 5, batch_size: 1

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.935 ± 0.0175
strict-match 5 exact_match 0.920 ± 0.0192

with cuda graph:

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.940 ± 0.0168
strict-match 5 exact_match 0.925 ± 0.0187

(Optional) Documentation Update

Signed-off-by: Ming Yang <yming@meta.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @minosfuture, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request provides a crucial bugfix for the Mixture of Experts (MoE) implementation, specifically for FP8 quantization with CUTLASS. It resolves a runtime IndexError that prevented the successful initialization and execution of models utilizing this configuration, ensuring the stability and functionality of FP8 MoE operations.

Highlights

  • Bugfix: MoE FP8 Indexing: This pull request addresses a critical IndexError occurring during the execution of CUTLASS w8a8 FP8 MoE (Mixture of Experts) operations. The error stemmed from topk_ids tensors being incorrectly cast to torch.uint32, which is not a valid type for indexing in PyTorch.
  • Code Correction: The fix involves removing the explicit indices_type=torch.uint32 argument from the apply function call within the fused_experts initialization in compressed_tensors_moe.py. This allows the system to use the correct default or inferred integer type (e.g., torch.long) for indexing, resolving the runtime crash.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly fixes a runtime IndexError that occurred during MoE execution with CUTLASS. The error was caused by topk_ids having an unsupported uint32 dtype for indexing. The fix, which removes the indices_type=torch.uint32 argument from the routing function call, is direct and effective, allowing the topk_ids tensor to default to a valid type for indexing. The change is well-supported by the provided error log and test results.

@yeqcharlotte
Copy link
Collaborator

thanks for the fix! could you also share the eval result? has cudagraph worked it?

cc: @ElizaWszola @bnellnm to take a look!

@minosfuture
Copy link
Author

minosfuture commented Jun 27, 2025

thanks for the fix! could you also share the eval result? has cudagraph worked it?

cc: @ElizaWszola @bnellnm to take a look!

updated with lm-eval results. Note that it's tested with the correctness fix #20167. Yes, both eager and cuda graph work.

@ElizaWszola
Copy link
Contributor

Thanks for the fix! Can you please check if the kernels in csrc/quantization/cutlass_w8a8/moe/moe_data.cu that use uint32_t topk_ids will still work and compile without complaints if you change the types to int32_t and update your pr to use int32_t in these functions? If this breaks the kernels, it would be good to have an explicit conversions to uint32_t when we want to call them.

Signed-off-by: Ming Yang <yming@meta.com>
@minosfuture
Copy link
Author

Thanks for the fix! Can you please check if the kernels in csrc/quantization/cutlass_w8a8/moe/moe_data.cu that use uint32_t topk_ids will still work and compile without complaints if you change the types to int32_t and update your pr to use int32_t in these functions? If this breaks the kernels, it would be good to have an explicit conversions to uint32_t when we want to call them.

updated.

In PplxPrepareAndFinalize and DeepEPLLPrepareAndFinalize, topk_indices_dtype returns uint32 and int64, respectively. I suggest we change them to int32 for consistency? Keeping them as is shouldn't result in casting error though given that ids should be within a small range from zero.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants