
Conversation


@Luosuu commented on Oct 25, 2025

Integrates Triton MoE communication kernels

Achieved a 60% throughput increase compared to #24588.

36117


@gemini-code-assist bot left a comment


Code Review

This pull request integrates Triton kernels for MoE communication, which is a significant performance enhancement. The changes are extensive, touching profiling, environment variable handling, and the core MoE logic. My review focuses on correctness and potential issues. I've identified a critical import typo, a bug in profiling metric calculation, a hardcoded backend selection that should be revisited, and a couple of areas where the code could be made more robust to prevent future errors. Overall, this is a promising direction for improving MoE performance.

from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.distributed import get_dp_group, get_ep_group

from triton_kenerls.distributed import symm_mem_pool


critical

There is a typo in the import statement. triton_kenerls should be triton_kernels. This will cause a runtime ImportError and prevent the module from loading.

Suggested change
from triton_kenerls.distributed import symm_mem_pool
from triton_kernels.distributed import symm_mem_pool

Comment on lines +1220 to +1228
try:
    payload = await raw_request.json()
except json.JSONDecodeError:
    payload = None
except Exception:
    payload = None
else:
    if isinstance(payload, dict):
        profile_options = payload


high

The exception handling for parsing the JSON payload is too broad and contains a redundant except block. Since both handlers set payload = None, the dedicated except json.JSONDecodeError: clause adds nothing beyond the blanket except Exception:. More importantly, catching Exception and silently setting payload = None can hide unexpected errors during request processing. It's better to handle only the expected json.JSONDecodeError.

Suggested change
try:
    payload = await raw_request.json()
except json.JSONDecodeError:
    payload = None
except Exception:
    payload = None
else:
    if isinstance(payload, dict):
        profile_options = payload

try:
    payload = await raw_request.json()
    if isinstance(payload, dict):
        profile_options = payload
except json.JSONDecodeError:
    # It's okay if the request has no body or is not valid JSON.
    pass

Comment on lines +234 to +242
if get_dp_group().world_size > 1:
    hidden_states, routing_data, gather_idx, scatter_idx, rs_metadata = ep_routing(
        hidden_states, gating_output, topk, sm_first=not renormalize, expt_assignment=expt_assignment,
        group_name=get_dp_group().device_group,
    )
else:
    routing_data, gather_idx, scatter_idx = routing(
        gating_output, topk, sm_first=not renormalize
    )


high

The variable rs_metadata is only defined within the if get_dp_group().world_size > 1: block, but it is used later in a separate if block with the same condition. While this is logically correct in the current structure, it's a fragile pattern. If the code is refactored, this could easily lead to an UnboundLocalError. It's better practice to initialize rs_metadata to None before the conditional blocks.

Suggested change
if get_dp_group().world_size > 1:
    hidden_states, routing_data, gather_idx, scatter_idx, rs_metadata = ep_routing(
        hidden_states, gating_output, topk, sm_first=not renormalize, expt_assignment=expt_assignment,
        group_name=get_dp_group().device_group,
    )
else:
    routing_data, gather_idx, scatter_idx = routing(
        gating_output, topk, sm_first=not renormalize
    )

rs_metadata = None
if get_dp_group().world_size > 1:
    hidden_states, routing_data, gather_idx, scatter_idx, rs_metadata = ep_routing(
        hidden_states, gating_output, topk, sm_first=not renormalize, expt_assignment=expt_assignment,
        group_name=get_dp_group().device_group,
    )
else:
    routing_data, gather_idx, scatter_idx = routing(
        gating_output, topk, sm_first=not renormalize
    )

name="matmul_ogs-w2",
metrics={
    "flops": flops_matmul_ogs,
    "bytes": num_bytes(w1) + token_num_bytes


high

The calculation for the bytes metric in the matmul_ogs-w2 profiling scope appears to be a copy-paste error from the matmul_ogs-w1 scope. It incorrectly uses num_bytes(w1) instead of num_bytes(w2). Additionally, token_num_bytes is calculated based on hidden_states (the input to the first matmul), but the input to the second matmul is intermediate_cache1, which has a different shape and size. This will lead to incorrect profiling data for the second matmul operation.

Suggested change
"bytes": num_bytes(w1) + token_num_bytes
"bytes": num_bytes(w2) + num_bytes(intermediate_cache1)

def get_mxfp4_backend():
    # Backend Selection
    if current_platform.is_cuda():
        return Mxfp4Backend.TRITON


high

The function get_mxfp4_backend is hardcoded to return Mxfp4Backend.TRITON. This makes all the subsequent backend selection logic unreachable. While this might be intentional for development, it's a significant change that bypasses platform-specific optimizations and should likely be removed or made conditional before merging.
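If the intent is to make the Triton path easy to exercise during bring-up without discarding the existing selection logic, one option is to gate the early return behind an opt-in flag. A minimal sketch, where VLLM_FORCE_MXFP4_TRITON_MOE is a hypothetical environment variable (not an existing vLLM setting) and the rest of the function is the current selection logic:

import os

def get_mxfp4_backend():
    # Backend Selection
    # Hypothetical opt-in override for development; when the flag is unset,
    # fall through to the existing platform-specific selection below.
    if os.environ.get("VLLM_FORCE_MXFP4_TRITON_MOE", "0") == "1":
        return Mxfp4Backend.TRITON
    if current_platform.is_cuda():
        ...  # existing FlashInfer / Marlin / Triton selection logic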


@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +56 to +58
from vllm.distributed import get_dp_group, get_ep_group

from triton_kenerls.distributed import symm_mem_pool


P0: Fix typo in symm_mem_pool import

The new Triton MoE path imports symm_mem_pool from triton_kenerls.distributed, but the package name is misspelled ("kenerls"). As soon as this module is imported, Python will raise ModuleNotFoundError, preventing the MXFP4 Triton backend from loading at all. The typo is new in this commit and should be corrected to triton_kernels.


Comment on lines 80 to 84
def get_mxfp4_backend():
    # Backend Selection
    if current_platform.is_cuda():
        return Mxfp4Backend.TRITON
    if (


P1: Remove unconditional TRITON backend override

The early return Mxfp4Backend.TRITON now executes immediately for all CUDA platforms, making every subsequent condition that selects FlashInfer or Marlin backends unreachable. As a result, Hopper/Blackwell GPUs will no longer pick the tuned FlashInfer kernels even when available, and environments without Triton kernels will still try to use the Triton backend. This regression was introduced in the new commit and will change behaviour for all users.


Comment on lines +202 to 213
if self.mxfp4_backend == Mxfp4Backend.TRITON and moe.dp_size != 0:

    self.expt_assignment = create_expt_assignment(EP=moe.ep_size, n_expts_tot=moe.num_experts, device=torch.cuda.current_device())
    self.symm_mem_pool = symm_mem_pool.initialize(
        1024 * 1024,
        dtype=torch.uint8,
        device=torch.cuda.current_device(),
        group=get_dp_group().device_group
    )
    # self.symm_handle = torch_symm_mem.rendezvous(self.symm_mem_pool, get_dp_group().device_group)



P1: Guard Triton MoE attributes when dp_size is zero

In the constructor Mxfp4MoEMethod.__init__, the new expt_assignment and symm_mem_pool attributes are only created when moe.dp_size != 0. However, the Triton path in fused_moe always passes self.expt_assignment and self.symm_mem_pool to triton_kernel_moe_forward regardless of dp_size. When the model runs with a single data-parallel rank (dp_size == 0), these attributes are undefined and the first request will raise AttributeError. Either initialize them for the single-rank case or skip passing them.
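One straightforward fix, sketched against the constructor quoted above, is to define both attributes unconditionally and only populate them on the multi-rank path. This assumes the Triton forward tolerates None for both; it is an illustration, not the PR's actual code:

# Sketch for Mxfp4MoEMethod.__init__: always define the attributes so the
# single-rank path cannot hit AttributeError; populate them only when needed.
self.expt_assignment = None
self.symm_mem_pool = None
if self.mxfp4_backend == Mxfp4Backend.TRITON and moe.dp_size != 0:
    self.expt_assignment = create_expt_assignment(
        EP=moe.ep_size,
        n_expts_tot=moe.num_experts,
        device=torch.cuda.current_device(),
    )
    self.symm_mem_pool = symm_mem_pool.initialize(
        1024 * 1024,
        dtype=torch.uint8,
        device=torch.cuda.current_device(),
        group=get_dp_group().device_group,
    )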


@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀


Labels

frontend, gpt-oss, v1

Projects

Status: To Triage


2 participants