
Commit e3b1266

Authored by Varun Sundar Rabindranath (varun-sundar-rabindranath)

[BugFix] : Fix Batched DeepGemm Experts (#19515)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
1 parent e6aab5d commit e3b1266

File tree: 9 files changed, +52 -32 lines changed


vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 11 additions & 8 deletions
@@ -47,15 +47,21 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
-        num_dp = self.world_size // self.dp_size
+        # FIXME (varun): We should be able to dispatch only from the leader
+        # DP ranks in the case of TP > 1. At the moment, all the Ranks
+        # end up sending their tokens. This needs to be fixed.
+        num_dispatchers = self.world_size
+        num_experts = local_num_experts
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
-        workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
-        workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
-        output = (num_experts, max_num_tokens * num_dp, K)
+        workspace13 = (num_experts, max_num_tokens * num_dispatchers,
+                       max(K, N))
+        workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
+        output = (num_experts, max_num_tokens * num_dispatchers, K)
         return (workspace13, workspace2, output, a.dtype)

     def apply(
@@ -84,9 +90,6 @@ def apply(
         a1q = hidden_states
         _, N, K = w1.size()

-        if global_num_experts == -1:
-            global_num_experts = w1.size(0)
-
         assert w2.size(1) == K

         E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
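A minimal sketch of the new sizing logic in isolation (made-up shapes, not vLLM's API), showing how the batched DeepGemm workspaces now scale with the number of dispatching ranks and the rank's local expert count:

# Minimal sketch (illustrative numbers, hypothetical helper): each rank
# currently dispatches its tokens, so the per-expert batch dimension is
# sized for world_size * max_num_tokens rows.
def batched_workspace_shapes(local_num_experts: int, max_num_tokens: int,
                             world_size: int, N: int, K: int):
    num_dispatchers = world_size
    rows = max_num_tokens * num_dispatchers
    workspace13 = (local_num_experts, rows, max(K, N))
    workspace2 = (local_num_experts, rows, N // 2)
    output = (local_num_experts, rows, K)
    return workspace13, workspace2, output

print(batched_workspace_shapes(local_num_experts=8, max_num_tokens=64,
                               world_size=4, N=14336, K=4096))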

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 4 additions & 3 deletions
@@ -81,18 +81,19 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
             return self.batched_deep_gemm_experts.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:
             assert self.batched_triton_experts is not None
             return self.batched_triton_experts.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)

     def apply(
         self,

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 2 additions & 1 deletion
@@ -230,7 +230,8 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1: tuple[int, ...] = ()
         workspace2: tuple[int, ...] = ()

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 5 additions & 8 deletions
@@ -74,15 +74,12 @@ def supports_chunking(self) -> bool:
         return True

     def workspace_shapes(
-        self,
-        a: torch.Tensor,
-        aq: torch.Tensor,
-        M: int,
-        N: int,
-        K: int,
-        topk: int,
-        num_experts: int,
+        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+        topk: int, global_num_experts: int, local_num_experts: int
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # We use global_num_experts due to how moe_align_block_size handles
+        # expert_maps.
+        num_experts = global_num_experts
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
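A small worked example (illustrative numbers only) of why this path sizes its workspace by the global expert count: the block padding introduced by moe_align_block_size grows with every global expert, not just the experts local to this rank:

# Minimal sketch, assuming M=100 tokens, topk=8, block_m=128 and 256
# global experts. The padded row count covers the worst case where
# every global expert's token block is rounded up to block_m.
def round_up(x: int, m: int) -> int:
    return (x + m - 1) // m * m

M, topk, block_m = 100, 8, 128
global_num_experts = 256

M_sum = M * topk + global_num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
print(M_sum)  # 33408 workspace rows, sized for all 256 global experts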

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 7 additions & 3 deletions
@@ -521,10 +521,12 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
-        num_dp = self.world_size // self.dp_size
+        num_dp = self.dp_size
+        num_experts = local_num_experts
         workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
         workspace2 = (self.max_num_tokens * num_dp, N)
         return (workspace13, workspace2, workspace13, a.dtype)
@@ -624,10 +626,12 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
         num_dp = self.world_size // self.dp_size
+        num_experts = local_num_experts
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
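For reference, a minimal sketch (illustrative values, not vLLM code) of the buffer sizes the first hunk now produces, with the batch dimension driven by dp_size and the expert dimension by the rank-local expert count:

# Minimal sketch, assuming 8 local experts, max_num_tokens=64, dp_size=2.
local_num_experts, max_num_tokens, dp_size, K, N = 8, 64, 2, 4096, 14336

workspace13 = (local_num_experts, max_num_tokens * dp_size, K)
workspace2 = (max_num_tokens * dp_size, N)
print(workspace13, workspace2)  # (8, 128, 4096) (128, 14336)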

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 2 additions & 1 deletion
@@ -1553,7 +1553,8 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1 = (M, topk, max(N * 2, K))
         workspace2 = (M, topk, N)

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 10 additions & 5 deletions
@@ -194,7 +194,8 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms
@@ -372,8 +373,9 @@ def forward(
         a1 = hidden_states
         output = a1 if inplace else torch.zeros_like(a1)

+        local_num_experts = w1.size(0)
         if global_num_experts == -1:
-            global_num_experts = w1.size(0)
+            global_num_experts = local_num_experts

         (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
@@ -408,16 +410,19 @@
         if num_chunks == 1:
             (workspace13_shape, workspace2_shape, fused_out_shape,
              workspace_dtype) = self.fused_experts.workspace_shapes(
-                 a1, a1q, M, N, K, top_k, global_num_experts)
+                 a1, a1q, M, N, K, top_k, global_num_experts,
+                 local_num_experts)
         else:
             # Use the full M to get the final output shape.
             _, _, fused_out_shape, _ = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, M, N, K, top_k, global_num_experts))
+                    a1, a1q, M, N, K, top_k, global_num_experts,
+                    local_num_experts))
             # Use the CHUNK_SIZE to get the workspace shapes.
             workspace13_shape, workspace2_shape, _, workspace_dtype = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))
+                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
+                    local_num_experts))

         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
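A minimal sketch (made-up shapes, not vLLM code) of the distinction the new arguments make explicit: under expert parallelism w1.size(0) is the per-rank (local) expert count, while router indices still span the global expert space, so both counts are now passed to workspace_shapes:

# Minimal sketch, assuming 16 global experts sharded over 4 EP ranks
# and toy weight dimensions.
import torch

global_num_experts = 16
ep_size = 4                                   # expert-parallel ranks
local_num_experts = global_num_experts // ep_size

# Per-rank expert weights: only the local experts live on this GPU.
w1 = torch.empty(local_num_experts, 256, 128)

assert w1.size(0) == local_num_experts        # 4, not 16
# workspace_shapes() now receives both counts, and each implementation
# picks whichever one its kernels actually need.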

vllm/model_executor/layers/fused_moe/moe_align_block_size.py

Lines changed: 6 additions & 0 deletions
@@ -159,6 +159,12 @@ def moe_align_block_size(
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.

+    Note: In the case of expert_parallel, moe_align_block_size initially
+    considers all experts as valid and aligns all tokens appropriately.
+    Before the function returns it marks the experts_ids that are not in
+    the current GPU rank as -1 so the MoE matmuls could skip those blocks.
+    This requires the num_experts input arg to be the num global experts.
+
     Parameters:
     - topk_ids: A tensor of shape [total_tokens, top_k] representing the
         top-k expert indices for each token.
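A toy illustration (not the actual kernel) of the masking described in the added note: block-to-expert ids that belong to other ranks are set to -1 so the MoE matmul can skip those blocks:

# Toy illustration with made-up ids: blocks are assigned over all global
# experts first, then ids owned by other ranks are masked to -1.
import torch

expert_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])   # one id per block
local_experts = {4, 5, 6, 7}                           # owned by this rank

mask = torch.tensor([int(e) not in local_experts for e in expert_ids])
expert_ids[mask] = -1
print(expert_ids)  # tensor([-1, -1, -1, -1,  4,  5,  6,  7])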

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 5 additions & 3 deletions
@@ -48,18 +48,20 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:
             return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
-                                                       num_experts)
+                                                       global_num_experts,
+                                                       local_num_experts)

     def apply(
         self,
