Skip to content

Commit 5ec1086

Browse files
committed
redo workspace allocation logic a bit
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 8f5170e commit 5ec1086

File tree

11 files changed

+113
-101
lines changed

11 files changed

+113
-101
lines changed

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
250250

251251
def workspace_shapes(
252252
self,
253-
curr_M: int,
254-
M: int,
253+
M_chunk: int,
254+
M_full: int,
255255
N: int,
256256
K: int,
257257
topk: int,
@@ -264,7 +264,7 @@ def workspace_shapes(
264264
# end up sending their tokens. This needs to be fixed.
265265
num_dispatchers = self.num_dispatchers
266266
num_experts = local_num_experts
267-
max_num_tokens = (curr_M if self.max_num_tokens is None else
267+
max_num_tokens = (M_chunk if self.max_num_tokens is None else
268268
self.max_num_tokens)
269269
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
270270
max(K, N))

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
9393

9494
def workspace_shapes(
9595
self,
96-
curr_M: int,
97-
M: int,
96+
M_chunk: int,
97+
M_full: int,
9898
N: int,
9999
K: int,
100100
topk: int,
@@ -108,13 +108,13 @@ def workspace_shapes(
108108
if self.allow_deep_gemm:
109109
assert self.batched_deep_gemm_experts is not None
110110
return self.batched_deep_gemm_experts.workspace_shapes(
111-
curr_M, M, N, K, topk, global_num_experts, local_num_experts,
112-
expert_tokens_metadata)
111+
M_chunk, M_full, N, K, topk, global_num_experts,
112+
local_num_experts, expert_tokens_metadata)
113113
else:
114114
assert self.batched_triton_experts is not None
115115
return self.batched_triton_experts.workspace_shapes(
116-
curr_M, M, N, K, topk, global_num_experts, local_num_experts,
117-
expert_tokens_metadata)
116+
M_chunk, M_full, N, K, topk, global_num_experts,
117+
local_num_experts, expert_tokens_metadata)
118118

119119
def apply(
120120
self,

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -313,18 +313,18 @@ def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
313313

314314
def workspace_shapes(
315315
self,
316-
curr_M: int,
317-
M: int,
316+
M_chunk: int,
317+
M_full: int,
318318
N: int,
319319
K: int,
320320
topk: int,
321321
global_num_experts: int,
322322
local_num_experts: int,
323323
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
324324
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
325-
workspace1 = (curr_M * topk, max(N, K))
326-
workspace2 = (curr_M * topk, max(N // 2, K))
327-
output = (M, K)
325+
workspace1 = (M_chunk * topk, max(N, K))
326+
workspace2 = (M_chunk * topk, max(N // 2, K))
327+
output = (M_full, K)
328328
return (workspace1, workspace2, output)
329329

330330

@@ -371,8 +371,8 @@ def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
371371

372372
def workspace_shapes(
373373
self,
374-
curr_M: int,
375-
M: int,
374+
M_chunk: int,
375+
M_full: int,
376376
N: int,
377377
K: int,
378378
topk: int,
@@ -382,9 +382,11 @@ def workspace_shapes(
382382
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
383383
num_dp = self.num_dispatchers
384384
assert num_dp is not None
385-
workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
386-
workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K))
387-
output = (self.max_experts_per_worker, M, K)
385+
assert M_chunk == M_full
386+
workspace1 = (self.max_experts_per_worker, M_full * num_dp, max(N, K))
387+
workspace2 = (self.max_experts_per_worker, M_full * num_dp,
388+
max(N // 2, K))
389+
output = (self.max_experts_per_worker, M_full, K)
388390
return (workspace1, workspace2, output)
389391

390392

@@ -670,8 +672,8 @@ def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
670672

671673
def workspace_shapes(
672674
self,
673-
curr_M: int,
674-
M: int,
675+
M_chunk: int,
676+
M_full: int,
675677
N: int,
676678
K: int,
677679
topk: int,
@@ -683,13 +685,14 @@ def workspace_shapes(
683685
workspace2: tuple[int, ...] = ()
684686
output: tuple[int, ...] = ()
685687
if self.use_batched_format:
686-
workspace1 = (self.max_experts_per_worker, M, max(N, K))
687-
workspace2 = (self.max_experts_per_worker, M, (N // 2))
688-
output = (self.max_experts_per_worker, M, K)
688+
assert M_chunk == M_full
689+
workspace1 = (self.max_experts_per_worker, M_full, max(N, K))
690+
workspace2 = (self.max_experts_per_worker, M_full, (N // 2))
691+
output = (self.max_experts_per_worker, M_full, K)
689692
else:
690-
workspace1 = (curr_M * topk, max(2 * N, K))
691-
workspace2 = (curr_M * topk, N)
692-
output = (M, K)
693+
workspace1 = (M_chunk * topk, max(2 * N, K))
694+
workspace2 = (M_chunk * topk, N)
695+
output = (M_full, K)
693696
return (workspace1, workspace2, output)
694697

695698
def apply(

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
181181

182182
def workspace_shapes(
183183
self,
184-
curr_M: int,
185-
M: int,
184+
M_chunk: int,
185+
M_full: int,
186186
N: int,
187187
K: int,
188188
topk: int,
@@ -192,13 +192,13 @@ def workspace_shapes(
192192
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
193193
assert self.block_shape is not None
194194
block_m = self.block_shape[0]
195-
M_sum = compute_aligned_M(curr_M, topk, local_num_experts, block_m,
195+
M_sum = compute_aligned_M(M_chunk, topk, local_num_experts, block_m,
196196
expert_tokens_meta)
197197
assert M_sum % block_m == 0
198198

199199
workspace1 = (M_sum, max(N, K))
200200
workspace2 = (M_sum, max(N // 2, K))
201-
output = (M, K)
201+
output = (M_full, K)
202202
return (workspace1, workspace2, output)
203203

204204
def apply(

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
8181

8282
def workspace_shapes(
8383
self,
84-
curr_M: int,
85-
M: int,
84+
M_chunk: int,
85+
M_full: int,
8686
N: int,
8787
K: int,
8888
topk: int,
@@ -108,9 +108,9 @@ def workspace_shapes(
108108
- Note: in order for activation chunking to work, the first dimension
109109
of each tuple must be the number of tokens.
110110
"""
111-
workspace1 = (curr_M, K)
111+
workspace1 = (M_chunk, K)
112112
workspace2 = (0, )
113-
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
113+
output_shape = (M_full, K * 2 if self.quant_dtype == "nvfp4" else K)
114114
# The workspace is determined by `aq`, since it comes after any
115115
# potential communication op and is involved in the expert computation.
116116
return (workspace1, workspace2, output_shape)

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -652,8 +652,8 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
652652

653653
def workspace_shapes(
654654
self,
655-
curr_M: int,
656-
M: int,
655+
M_chunk: int,
656+
M_full: int,
657657
N: int,
658658
K: int,
659659
topk: int,
@@ -850,8 +850,8 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
850850

851851
def workspace_shapes(
852852
self,
853-
curr_M: int,
854-
M: int,
853+
M_chunk: int,
854+
M_full: int,
855855
N: int,
856856
K: int,
857857
topk: int,

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1729,18 +1729,18 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
17291729

17301730
def workspace_shapes(
17311731
self,
1732-
curr_M: int,
1733-
M: int,
1732+
M_chunk: int,
1733+
M_full: int,
17341734
N: int,
17351735
K: int,
17361736
topk: int,
17371737
global_num_experts: int,
17381738
local_num_experts: int,
17391739
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
17401740
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
1741-
workspace1 = (curr_M, topk, max(N // 2, K))
1742-
workspace2 = (curr_M, topk, max(N, K))
1743-
output = (M, K)
1741+
workspace1 = (M_chunk, topk, max(N // 2, K))
1742+
workspace2 = (M_chunk, topk, max(N, K))
1743+
output = (M_full, K)
17441744
return (workspace1, workspace2, output)
17451745

17461746
def apply(

vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,8 @@ def supports_chunking(self) -> bool:
252252

253253
def workspace_shapes(
254254
self,
255-
curr_M: int,
256-
M: int,
255+
M_chunk: int,
256+
M_full: int,
257257
N: int,
258258
K: int,
259259
topk: int,
@@ -262,9 +262,9 @@ def workspace_shapes(
262262
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
263263
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
264264
# workspace are allocated inside the kernel
265-
workspace1 = (M, K)
265+
workspace1 = (M_chunk, K)
266266
workspace2 = (0, 0)
267-
output = (M, K)
267+
output = (M_full, K)
268268
return (workspace1, workspace2, output)
269269

270270
def apply(

0 commit comments

Comments (0)