Commit 2155672

cleanup
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent b3d2ea5 commit 2155672

11 files changed, +42 -30 lines

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 1 addition & 0 deletions
@@ -284,6 +284,7 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 2 additions & 1 deletion
@@ -128,6 +128,7 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@@ -138,5 +139,5 @@ def apply(
         assert experts is not None
         experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
                       activation, global_num_experts, expert_map, a1q_scale,
-                      workspace13, workspace2, expert_tokens_meta,
+                      a2_scale, workspace13, workspace2, expert_tokens_meta,
                       apply_router_weight_on_input)
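The pattern above repeats across every expert implementation touched by this commit: apply() gains an explicit a2_scale parameter right after a1q_scale, and wrapper classes forward it to their inner implementation instead of letting each kernel read self.a2_scale. A minimal sketch of that calling convention, using illustrative stand-in classes rather than the real vLLM types:

from typing import Optional

import torch


class ExpertsBase:
    """Illustrative stand-in for the fused-MoE experts interface."""

    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
              a1q_scale: Optional[torch.Tensor],
              a2_scale: Optional[torch.Tensor],
              workspace13: torch.Tensor, workspace2: torch.Tensor) -> None:
        raise NotImplementedError


class DelegatingExperts(ExpertsBase):
    """Wrapper that selects an inner implementation and forwards every
    argument, including a2_scale, instead of reading it off self."""

    def __init__(self, inner: ExpertsBase):
        self.inner = inner

    def apply(self, output, hidden_states, a1q_scale, a2_scale,
              workspace13, workspace2) -> None:
        # The scale now travels with the call, so delegation is a pure
        # pass-through and the inner kernel needs no a2_scale attribute.
        self.inner.apply(output, hidden_states, a1q_scale, a2_scale,
                         workspace13, workspace2)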

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 3 additions & 1 deletion
@@ -241,6 +241,7 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@@ -262,7 +263,7 @@ def apply(
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
             global_num_experts, expert_map, self.w1_scale, self.w2_scale,
-            a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2,
+            a1q_scale, a2_scale, self.ab_strides1, self.ab_strides2,
             self.c_strides1, self.c_strides2, workspace13, workspace2,
             expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
@@ -703,6 +704,7 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],  # unused
+        a2_scale: Optional[torch.Tensor],  # unused
         workspace13: Optional[torch.Tensor],
         workspace2: Optional[torch.Tensor],
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 2 additions & 1 deletion
@@ -213,13 +213,14 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ):
         assert a1q_scale is not None
-        assert self.a2_scale is None
+        assert a2_scale is None
         assert self.block_shape is not None
         assert self.w1_scale is not None
         assert self.w2_scale is not None

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py

Lines changed: 1 addition & 0 deletions
@@ -125,6 +125,7 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: Optional[torch.Tensor],
         workspace2: Optional[torch.Tensor],
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 3 additions & 1 deletion
@@ -686,6 +686,7 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@@ -875,6 +876,7 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@@ -966,7 +968,7 @@ def apply(
                                  intermediate_cache1.view(-1, N))
 
         qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
-            intermediate_cache2, self.a2_scale, max_num_tokens, E, N,
+            intermediate_cache2, a2_scale, max_num_tokens, E, N,
             expert_num_tokens, self.quant_dtype, self.per_act_token_quant,
             self.block_shape)
 
vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 2 additions & 1 deletion
@@ -1597,6 +1597,7 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@@ -1689,7 +1690,7 @@ def apply(
         a2q_scale: Optional[torch.Tensor] = None
 
         qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
-            intermediate_cache2, self.a2_scale, self.quant_dtype,
+            intermediate_cache2, a2_scale, self.quant_dtype,
             self.per_act_token_quant, self.block_shape)
 
         invoke_fused_moe_kernel(
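In both the batched and standard Triton paths, the a2_scale that is now passed in feeds the quantization of the intermediate activations (intermediate_cache2) before the second grouped GEMM. The helper below is a hypothetical stand-in for what such a quantization step typically does, assuming a per-tensor fp8 scheme; it is not the real moe_kernel_quantize_input, whose signature also covers per-token and block-quantized layouts:

from typing import Optional

import torch


def quantize_intermediate(
    x: torch.Tensor,
    a2_scale: Optional[torch.Tensor],
    quant_dtype: Optional[torch.dtype],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Hypothetical sketch: quantize activations between the two MoE GEMMs.

    If a2_scale is given it acts as a static per-tensor scale; otherwise a
    dynamic scale is derived from the tensor itself.
    """
    if quant_dtype is None:
        # Unquantized path: pass the activations through untouched.
        return x, None

    finfo = torch.finfo(quant_dtype)
    if a2_scale is None:
        # Dynamic quantization: compute the scale from the activation range.
        a2_scale = x.abs().amax().to(torch.float32) / finfo.max

    xq = (x.to(torch.float32) / a2_scale).clamp(finfo.min, finfo.max)
    return xq.to(quant_dtype), a2_scale

Read this way, the hunks above change only where the scale comes from: the caller (for example the chunked loop in modular_kernel.py below) now decides which a2_scale each apply() call sees.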

vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py

Lines changed: 1 addition & 0 deletions
@@ -184,6 +184,7 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 24 additions & 25 deletions
@@ -85,7 +85,7 @@ def _moe_problem_size(
         M = a1.size(0)
     else:
         assert a1.dim() == 3
-        #assert a1.size(0) == E, f"{a1.size(0)} == {E}"
+        assert a1.size(0) == E, f"{a1.size(0)} == {E}"
         M = a1.size(1)  # This is max_num_tokens
 
     assert topk_ids.dim() == 2
@@ -536,11 +536,12 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-    ):
+    ) -> None:
         """
         This function computes the intermediate result of a Mixture of Experts
         (MoE) layer using two sets of weights, w1 and w2.
@@ -674,22 +675,22 @@ def _allocate_buffers(
 
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = torch.zeros(prod(workspace13_shape),
-                                  device=device,
-                                  dtype=workspace_dtype)
-        workspace2 = torch.zeros(prod(workspace2_shape),
-                                 device=device,
-                                 dtype=workspace_dtype)
+        workspace13 = self.workspace13_buffer.get(workspace13_shape,
+                                                  device=device,
+                                                  dtype=workspace_dtype)
+        workspace2 = self.workspace2_buffer.get(workspace2_shape,
+                                                device=device,
+                                                dtype=workspace_dtype)
 
         # Construct the entire output that can then be processed in chunks.
         if num_chunks == 1 and prod(workspace13_shape) >= prod(
                 fused_out_shape):
             # Reuse workspace13 for the output in the non-chunked case.
             fused_out = _resize_cache(workspace13, fused_out_shape)
         else:
-            fused_out = torch.empty(fused_out_shape,
-                                    device=device,
-                                    dtype=out_dtype)
+            fused_out = self.fused_out_buffer.get(fused_out_shape,
+                                                  device=device,
+                                                  dtype=out_dtype)
 
         return workspace13, workspace2, fused_out
 
@@ -785,7 +786,10 @@ def forward(
         - torch.Tensor: The output tensor after applying the MoE layer.
         """
 
-        output = hidden_states if inplace else torch.zeros_like(hidden_states)
+        if inplace and self.shared_experts is None:
+            output = hidden_states
+        else:
+            output = torch.zeros_like(hidden_states)
 
         local_num_experts = w1.size(0)
         if global_num_experts == -1:
@@ -799,8 +803,6 @@ def forward(
             (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
              _expert_topk_weights) = self.prepare_finalize.prepare(
                  hidden_states,
-                 a1_scale,
-                 a2_scale,
                  topk_weights,
                  topk_ids,
                  global_num_experts,
@@ -810,10 +812,9 @@ def forward(
             )
         else:
             # Overlap shared expert compute with all2all dispatch.
-            receiver = self.prepare_finalize.prepare_async(
+            dbo_maybe_run_recv_hook()
+            hook, receiver = self.prepare_finalize.prepare_async(
                 hidden_states,
-                a1_scale,
-                a2_scale,
                 topk_weights,
                 topk_ids,
                 global_num_experts,
@@ -838,6 +839,8 @@ def forward(
         topk_weights = (topk_weights if _expert_topk_weights is None else
                         _expert_topk_weights)
 
+        fused_out = None
+
         if a1q.numel() == 0:
             # This happens when none of the tokens from the all2all reach this
             # EP rank. Also, note that this is only relevant for CUDAGraph
@@ -853,7 +856,7 @@ def forward(
             CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
             num_chunks = cdiv(M, CHUNK_SIZE)
         else:
-            CHUNK_SIZE = M #a1q.size(0)
+            CHUNK_SIZE = M #a1q.size(0)
             num_chunks = 1
 
         def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
@@ -892,12 +895,8 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
                 activation=activation,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
-                w1_scale=w1_scale,
-                w2_scale=w2_scale,
-                w1_zp=w1_zp,
-                w2_zp=w2_zp,
                 a1q_scale=_chunk_scales(a1q_scale, s, e),
-                a2_scale=_chunk_scales(a2_scale, e, e),
+                a2_scale=_chunk_scales(self.fused_experts.a2_scale, e, e),
                 workspace13=workspace13,
                 workspace2=workspace2,
                 expert_tokens_meta=c_expert_tokens_meta,
@@ -918,7 +917,7 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
                 self.fused_experts.finalize_weight_and_reduce_impl(),
             )
             if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
+                shared_output = self.shared_experts(hidden_states)
         else:
             recv_hook = self.prepare_finalize.finalize_async(
                 output,
@@ -930,7 +929,7 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
             )
 
             if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
+                shared_output = self.shared_experts(hidden_states)
 
             assert recv_hook is not None
             dbo_register_recv_hook(recv_hook)
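The _allocate_buffers hunk above replaces the per-call torch.zeros/torch.empty allocations with lookups on self.workspace13_buffer, self.workspace2_buffer and self.fused_out_buffer. The diff does not show how those buffer objects are defined; the sketch below is an assumption about what a grow-only cached workspace with a .get(shape, device=..., dtype=...) method could look like, not the actual vLLM helper. Note also that the old path zero-initialized the workspaces, and whether the cached buffers are re-zeroed on reuse is not visible in this diff.

from math import prod
from typing import Optional

import torch


class ReusableWorkspace:
    """Illustrative grow-only buffer: keep one flat tensor and hand out
    reshaped views so repeated forward passes stop reallocating memory."""

    def __init__(self) -> None:
        self._buf: Optional[torch.Tensor] = None

    def get(self, shape: tuple[int, ...], device: torch.device,
            dtype: torch.dtype) -> torch.Tensor:
        numel = prod(shape)
        if (self._buf is None or self._buf.numel() < numel
                or self._buf.device != device or self._buf.dtype != dtype):
            # Grow (or retype) the backing storage; existing contents are
            # not preserved, matching scratch-workspace semantics.
            self._buf = torch.empty(numel, device=device, dtype=dtype)
        return self._buf[:numel].view(*shape)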

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 2 additions & 0 deletions
@@ -110,6 +110,7 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@@ -133,6 +134,7 @@ def apply(
             global_num_experts,
             expert_map,
             a1q_scale,
+            a2_scale,
             workspace13,
             workspace2,
             expert_tokens_meta,
