[BugFix] : Fix Batched DeepGemm Experts #19515

Merged (3 commits, Jun 13, 2025)
19 changes: 11 additions & 8 deletions vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -47,15 +47,21 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
-        num_dp = self.world_size // self.dp_size
+        # FIXME (varun): We should be able to dispatch only from the leader
+        # DP ranks in the case of TP > 1. At the moment, all the Ranks
+        # end up sending their tokens. This needs to be fixed.
+        num_dispatchers = self.world_size
+        num_experts = local_num_experts
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
-        workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
-        workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
-        output = (num_experts, max_num_tokens * num_dp, K)
+        workspace13 = (num_experts, max_num_tokens * num_dispatchers,
+                       max(K, N))
+        workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
+        output = (num_experts, max_num_tokens * num_dispatchers, K)
         return (workspace13, workspace2, output, a.dtype)

def apply(
@@ -84,9 +90,6 @@ def apply(
         a1q = hidden_states
         _, N, K = w1.size()
 
-        if global_num_experts == -1:
-            global_num_experts = w1.size(0)
-
         assert w2.size(1) == K
 
         E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
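For readers tracing the shape math: the hunks above switch the expert dimension to the local expert count and scale the token dimension by the number of dispatching ranks (currently every rank, per the FIXME). The following is a minimal standalone sketch of that sizing, not the vLLM class itself; `world_size` and `max_num_tokens` stand in for the attributes referenced in the diff.

```python
# Illustrative-only sketch of the batched DeepGemm workspace sizing above.
from typing import Optional

import torch


def batched_deepgemm_workspace_shapes(
    a: torch.Tensor,
    N: int,
    K: int,
    world_size: int,
    max_num_tokens: Optional[int],
    local_num_experts: int,
):
    assert a.dim() == 2
    # Every rank currently dispatches its tokens (see the FIXME in the hunk),
    # so each local expert may receive up to max_num_tokens rows per rank.
    num_dispatchers = world_size
    num_experts = local_num_experts
    M = a.size(0) if max_num_tokens is None else max_num_tokens
    workspace13 = (num_experts, M * num_dispatchers, max(K, N))
    workspace2 = (num_experts, M * num_dispatchers, N // 2)
    output = (num_experts, M * num_dispatchers, K)
    return workspace13, workspace2, output, a.dtype
```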
@@ -81,18 +81,19 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
             return self.batched_deep_gemm_experts.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:
             assert self.batched_triton_experts is not None
             return self.batched_triton_experts.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)

def apply(
self,
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -230,7 +230,8 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1: tuple[int, ...] = ()
         workspace2: tuple[int, ...] = ()
13 changes: 5 additions & 8 deletions vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -74,15 +74,12 @@ def supports_chunking(self) -> bool:
         return True
 
     def workspace_shapes(
-        self,
-        a: torch.Tensor,
-        aq: torch.Tensor,
-        M: int,
-        N: int,
-        K: int,
-        topk: int,
-        num_experts: int,
+        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+        topk: int, global_num_experts: int, local_num_experts: int
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # We use global_num_experts due to how moe_align_block_size handles
+        # expert_maps.
+        num_experts = global_num_experts
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
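The new comment is the crux of the fix for the non-batched DeepGemm path: moe_align_block_size pads every expert's token segment to a multiple of block_m and indexes experts globally, so the padded row count must be derived from global_num_experts. Below is a small sketch of that arithmetic, assuming a locally defined round_up (in vLLM it comes from a utility module).

```python
# Sketch of the padded-M arithmetic from the hunk above (illustrative only).
def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple


def deep_gemm_m_sum(M: int, topk: int, global_num_experts: int,
                    block_m: int) -> int:
    # moe_align_block_size pads each (global) expert's token segment up to a
    # multiple of block_m, so the worst case adds block_m - 1 rows per expert.
    M_sum = (M * topk) + global_num_experts * (block_m - 1)
    return round_up(M_sum, block_m)


# Example: 128 tokens, top-2 routing, 64 global experts, block_m = 128
# -> 128 * 2 + 64 * 127 = 8384, rounded up to 8448 rows.
print(deep_gemm_m_sum(M=128, topk=2, global_num_experts=64, block_m=128))
```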
10 changes: 7 additions & 3 deletions vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -521,10 +521,12 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
-        num_dp = self.world_size // self.dp_size
+        num_dp = self.dp_size
+        num_experts = local_num_experts
         workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
         workspace2 = (self.max_num_tokens * num_dp, N)
         return (workspace13, workspace2, workspace13, a.dtype)
@@ -624,10 +626,12 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
         num_dp = self.world_size // self.dp_size
+        num_experts = local_num_experts
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
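The batched experts above size their buffers per local expert, with the token dimension scaled by the number of data-parallel peers whose tokens can land in each expert's batch. A rough helper, not part of vLLM, for estimating what such a workspace costs in memory:

```python
# Rough, illustrative estimate of a batched workspace's footprint, e.g. for
# workspace13 = (local_num_experts, max_num_tokens * num_dp, max(K, N)).
import math

import torch


def workspace_bytes(shape: tuple[int, ...], dtype: torch.dtype) -> int:
    # Multiply the dimensions out and scale by the element size of the dtype.
    return math.prod(shape) * torch.empty((), dtype=dtype).element_size()


# Example: 8 local experts, 2048 rows per expert batch (1024 tokens * 2 DP
# peers), hidden size 4096, bf16 -> 8 * 2048 * 4096 * 2 bytes = 128 MiB.
print(workspace_bytes((8, 1024 * 2, 4096), torch.bfloat16) // (1 << 20))
```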
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1553,7 +1553,8 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1 = (M, topk, max(N * 2, K))
         workspace2 = (M, topk, N)
15 changes: 10 additions & 5 deletions vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -194,7 +194,8 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms
@@ -372,8 +373,9 @@ def forward(
         a1 = hidden_states
         output = a1 if inplace else torch.zeros_like(a1)
 
+        local_num_experts = w1.size(0)
         if global_num_experts == -1:
-            global_num_experts = w1.size(0)
+            global_num_experts = local_num_experts
 
         (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
@@ -408,16 +410,19 @@ def forward(
         if num_chunks == 1:
             (workspace13_shape, workspace2_shape, fused_out_shape,
              workspace_dtype) = self.fused_experts.workspace_shapes(
-                a1, a1q, M, N, K, top_k, global_num_experts)
+                a1, a1q, M, N, K, top_k, global_num_experts,
+                local_num_experts)
         else:
             # Use the full M to get the final output shape.
             _, _, fused_out_shape, _ = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, M, N, K, top_k, global_num_experts))
+                    a1, a1q, M, N, K, top_k, global_num_experts,
+                    local_num_experts))
             # Use the CHUNK_SIZE to get the workspace shapes.
             workspace13_shape, workspace2_shape, _, workspace_dtype = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))
+                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
+                    local_num_experts))
 
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
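The forward() change above is the companion to the signature change: the caller now derives the local expert count from the per-rank weight tensor and threads both counts into workspace_shapes, so each implementation can size its buffers on whichever count it needs (global for block alignment, local for per-expert batched buffers). A minimal sketch of the derivation, mirroring the hunk:

```python
# Minimal sketch of how the caller derives both expert counts; w1 is the
# per-rank expert weight tensor, so its leading dimension is the local count.
import torch


def resolve_expert_counts(w1: torch.Tensor,
                          global_num_experts: int) -> tuple[int, int]:
    local_num_experts = w1.size(0)
    if global_num_experts == -1:
        # No expert-parallel sharding was specified: every expert is local.
        global_num_experts = local_num_experts
    return global_num_experts, local_num_experts
```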
6 changes: 6 additions & 0 deletions vllm/model_executor/layers/fused_moe/moe_align_block_size.py
@@ -159,6 +159,12 @@ def moe_align_block_size(
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
 
+    Note: In the case of expert parallelism, moe_align_block_size initially
+    considers all experts as valid and aligns all tokens appropriately.
+    Before the function returns, it marks the expert_ids that are not on the
+    current GPU rank as -1 so the MoE matmuls can skip those blocks. This
+    requires the num_experts input arg to be the number of global experts.
+
     Parameters:
     - topk_ids: A tensor of shape [total_tokens, top_k] representing the
       top-k expert indices for each token.
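The effect described in the note can be pictured as a remap over the per-block expert ids. This is a hypothetical illustration only; the real masking happens inside moe_align_block_size, and `expert_map` here is assumed to map global expert ids to -1 when the expert does not live on the current rank:

```python
# Hypothetical illustration of the masking described in the docstring note.
import torch


def mask_non_local_expert_ids(expert_ids: torch.Tensor,
                              expert_map: torch.Tensor) -> torch.Tensor:
    # Blocks whose expert does not live on this rank are marked -1 so the MoE
    # matmuls skip them; blocks for local experts keep their ids.
    return torch.where(expert_map[expert_ids] == -1,
                       torch.full_like(expert_ids, -1), expert_ids)
```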
8 changes: 5 additions & 3 deletions vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -48,18 +48,20 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:
             return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
-                                                       num_experts)
+                                                       global_num_experts,
+                                                       local_num_experts)

def apply(
self,
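Across all of the files touched here, workspace_shapes keeps the same return contract (two workspace shapes, an output shape, and a dtype); only the trailing expert-count arguments changed. Below is a hedged sketch of how a caller consumes that contract under the new signature; the allocation itself is simplified compared to vLLM, which reuses and resizes these buffers rather than allocating them fresh.

```python
# Sketch of consuming the (shapes, dtype) contract with the new signature.
import torch


def allocate_workspaces(experts, a, aq, M, N, K, topk,
                        global_num_experts, local_num_experts,
                        device="cuda"):
    # Argument order follows the diff above.
    ws13_shape, ws2_shape, out_shape, dtype = experts.workspace_shapes(
        a, aq, M, N, K, topk, global_num_experts, local_num_experts)
    workspace13 = torch.empty(ws13_shape, dtype=dtype, device=device)
    workspace2 = torch.empty(ws2_shape, dtype=dtype, device=device)
    output = torch.empty(out_shape, dtype=dtype, device=device)
    return workspace13, workspace2, output
```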