[WIP] [GPT-OSS] customized symm_mem based EP comm kernel integration #27495
Conversation
Code Review
This pull request integrates Triton kernels for MoE communication, which is a significant performance enhancement. The changes are extensive, touching profiling, environment variable handling, and the core MoE logic. My review focuses on correctness and potential issues. I've identified a critical import typo, a bug in profiling metric calculation, a hardcoded backend selection that should be revisited, and a couple of areas where the code could be made more robust to prevent future errors. Overall, this is a promising direction for improving MoE performance.
from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.distributed import get_dp_group, get_ep_group

from triton_kenerls.distributed import symm_mem_pool
try:
    payload = await raw_request.json()
except json.JSONDecodeError:
    payload = None
except Exception:
    payload = None
else:
    if isinstance(payload, dict):
        profile_options = payload
The exception handling for parsing the JSON payload is too broad and contains a redundant except block: both handlers do the same thing (set payload = None), so the broad except Exception: adds nothing for JSON errors. More importantly, catching Exception and silently setting payload = None can hide unexpected errors during request processing. It's better to handle only the expected json.JSONDecodeError.
Suggested change:

try:
    payload = await raw_request.json()
    if isinstance(payload, dict):
        profile_options = payload
except json.JSONDecodeError:
    # It's okay if the request has no body or is not valid JSON.
    pass
if get_dp_group().world_size > 1:
    hidden_states, routing_data, gather_idx, scatter_idx, rs_metadata = ep_routing(
        hidden_states, gating_output, topk, sm_first=not renormalize, expt_assignment=expt_assignment,
        group_name = get_dp_group().device_group,
    )
else:
    routing_data, gather_idx, scatter_idx = routing(
        gating_output, topk, sm_first=not renormalize
    )
The variable rs_metadata is only defined within the if get_dp_group().world_size > 1: block, but it is used later in a separate if block with the same condition. While this is logically correct in the current structure, it's a fragile pattern. If the code is refactored, this could easily lead to an UnboundLocalError. It's better practice to initialize rs_metadata to None before the conditional blocks.
Suggested change:

rs_metadata = None
if get_dp_group().world_size > 1:
    hidden_states, routing_data, gather_idx, scatter_idx, rs_metadata = ep_routing(
        hidden_states, gating_output, topk, sm_first=not renormalize, expt_assignment=expt_assignment,
        group_name = get_dp_group().device_group,
    )
else:
    routing_data, gather_idx, scatter_idx = routing(
        gating_output, topk, sm_first=not renormalize
    )
name="matmul_ogs-w2",
metrics={
    "flops": flops_matmul_ogs,
    "bytes": num_bytes(w1) + token_num_bytes
The calculation for the bytes metric in the matmul_ogs-w2 profiling scope appears to be a copy-paste error from the matmul_ogs-w1 scope. It incorrectly uses num_bytes(w1) instead of num_bytes(w2). Additionally, token_num_bytes is calculated based on hidden_states (the input to the first matmul), but the input to the second matmul is intermediate_cache1, which has a different shape and size. This will lead to incorrect profiling data for the second matmul operation.
| "bytes": num_bytes(w1) + token_num_bytes | |
| "bytes": num_bytes(w2) + num_bytes(intermediate_cache1) |
def get_mxfp4_backend():
    # Backend Selection
    if current_platform.is_cuda():
        return Mxfp4Backend.TRITON
The function get_mxfp4_backend is hardcoded to return Mxfp4Backend.TRITON. This makes all the subsequent backend selection logic unreachable. While this might be intentional for development, it's a significant change that bypasses platform-specific optimizations and should likely be removed or made conditional before merging.
💡 Codex Review
Here are some automated review suggestions for this pull request.
from vllm.distributed import get_dp_group, get_ep_group

from triton_kenerls.distributed import symm_mem_pool
Fix typo in symm_mem_pool import
The new Triton MoE path imports symm_mem_pool from triton_kenerls.distributed, but the package name is misspelled ("kenerls"). As soon as this module is imported, Python will raise ModuleNotFoundError, preventing the MXFP4 Triton backend from loading at all. The typo is new in this commit and should be corrected to triton_kernels.
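The corrected import, as the comment suggests (the package providing the symmetric-memory helpers is spelled triton_kernels):

```python
from triton_kernels.distributed import symm_mem_pool
```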
def get_mxfp4_backend():
    # Backend Selection
    if current_platform.is_cuda():
        return Mxfp4Backend.TRITON
    if (
Remove unconditional TRITON backend override
The early return Mxfp4Backend.TRITON now executes immediately for all CUDA platforms, making every subsequent condition that selects FlashInfer or Marlin backends unreachable. As a result, Hopper/Blackwell GPUs will no longer pick the tuned FlashInfer kernels even when available, and environments without Triton kernels will still try to use the Triton backend. This regression was introduced in the new commit and will change behaviour for all users.
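If the Triton path is still needed during development, one option is to make it opt-in rather than unconditional, so the existing FlashInfer/Marlin selection stays reachable. A minimal sketch, where VLLM_USE_TRITON_MXFP4_MOE_COMM is a hypothetical environment variable (not an existing vLLM setting) and current_platform / Mxfp4Backend are the names already used in the quoted function:

```python
import os

def get_mxfp4_backend():
    # Backend Selection
    # Only take the experimental Triton symm_mem path when explicitly requested,
    # instead of returning unconditionally for every CUDA platform.
    if current_platform.is_cuda() and os.environ.get(
        "VLLM_USE_TRITON_MXFP4_MOE_COMM", "0"
    ) == "1":
        return Mxfp4Backend.TRITON
    # ... the original FlashInfer / Marlin selection logic continues below ...
```

Alternatively, the early return can simply be dropped before merging, as the comment suggests.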
if self.mxfp4_backend == Mxfp4Backend.TRITON and moe.dp_size != 0:
    self.expt_assignment = create_expt_assignment(EP=moe.ep_size, n_expts_tot=moe.num_experts, device=torch.cuda.current_device())
    self.symm_mem_pool = symm_mem_pool.initialize(
        1024 * 1024,
        dtype=torch.uint8,
        device=torch.cuda.current_device(),
        group=get_dp_group().device_group
    )
    # self.symm_handle = torch_symm_mem.rendezvous(self.symm_mem_pool, get_dp_group().device_group)
Guard Triton MoE attributes when dp_size is zero
In the constructor Mxfp4MoEMethod.__init__, the new expt_assignment and symm_mem_pool attributes are only created when moe.dp_size != 0. However, the Triton path in fused_moe always passes self.expt_assignment and self.symm_mem_pool to triton_kernel_moe_forward regardless of dp_size. When the model runs with a single data-parallel rank (dp_size == 0), these attributes are undefined and the first request will raise AttributeError. Either initialize them for the single-rank case or skip passing them.
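A minimal sketch of the first option, assuming downstream code such as triton_kernel_moe_forward can treat None as "no expert-parallel communication" on single-rank runs; the attribute names and the conditional body come from the quoted diff:

```python
# In Mxfp4MoEMethod.__init__: define the attributes unconditionally so that
# later accesses cannot raise AttributeError when dp_size == 0.
self.expt_assignment = None
self.symm_mem_pool = None
if self.mxfp4_backend == Mxfp4Backend.TRITON and moe.dp_size != 0:
    self.expt_assignment = create_expt_assignment(
        EP=moe.ep_size,
        n_expts_tot=moe.num_experts,
        device=torch.cuda.current_device(),
    )
    self.symm_mem_pool = symm_mem_pool.initialize(
        1024 * 1024,
        dtype=torch.uint8,
        device=torch.cuda.current_device(),
        group=get_dp_group().device_group,
    )
```

The comment's second option, not passing the arguments at all when dp_size == 0, would avoid the extra None checks in the forward path.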
👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI, which runs only a small and essential subset of CI tests to quickly catch errors. You can ask your reviewers to trigger select CI tests on top of `fastcheck` CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the `ready` label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Integrates Triton MoE communication kernels.
Achieved a 60% throughput increase compared to #24588.