
Commit a931b70

Merge pull request vllm-project#11 from dcmaddix/fused_moe_lora_cleanup
Fused moe lora cleanup
2 parents 9f68dca + e5eec7b commit a931b70

9 files changed (+25, -36 lines)

csrc/moe/moe_lora_align_sum_kernels.cu
Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
 
 }  // namespace
 
+// TODO: Refactor common parts with moe_align_sum_kernels
 template <typename scalar_t, typename token_cnts_t>
 __global__ void moe_lora_align_sum_kernel(
     scalar_t* __restrict__ topk_ids, scalar_t* __restrict__ token_lora_mapping,

csrc/ops.h
Lines changed: 0 additions & 2 deletions

@@ -133,8 +133,6 @@ void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
                         torch::Tensor& scale);
 
 #ifndef USE_ROCM
-// #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
-//     (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
 void silu_and_mul_nvfp4_quant(torch::Tensor& out,
                               torch::Tensor& output_block_scale,
                               torch::Tensor& input,

csrc/torch_bindings.cpp
Lines changed: 1 addition & 2 deletions

@@ -122,8 +122,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
   ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
 
-#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
-    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
+#ifndef USE_ROCM
   ops.def(
       "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
       "Tensor input, Tensor input_global_scale) -> ()");

vllm/lora/models.py
Lines changed: 1 addition & 1 deletion

@@ -417,7 +417,7 @@ def activate_adapter(
             if module_lora:
                 module_lora.optimize()
                 # Note (gnovack) - If MOE lora weights are not split into
-                # um_experts chunks, we split them here
+                # num_experts chunks, we split them here
                 if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
                     module_lora.lora_a
                 ):
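For context on the corrected comment: if an adapter ships its MoE LoRA weights as a single stacked tensor, activate_adapter splits them into num_experts chunks at this point. A rough sketch of the splitting idea; the sizes and split dimension here are assumptions, not taken from this diff:

import torch

num_experts, rank, hidden = 8, 16, 4096              # illustrative sizes
stacked_lora_a = torch.randn(num_experts * rank, hidden)

# Split the stacked weight into num_experts equal per-expert chunks.
lora_a_per_expert = list(torch.chunk(stacked_lora_a, num_experts, dim=0))
assert all(w.shape == (rank, hidden) for w in lora_a_per_expert)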

vllm/lora/utils.py
Lines changed: 2 additions & 1 deletion

@@ -286,7 +286,8 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
             return packed_modules_mapping
         else:
             raise AttributeError(
-                "To support LoRA for MoE model, 'get_expert_mapping' must be implemented"
+                "To support LoRA for MoE model, " \
+                "'get_expert_mapping' must be implemented"
             )
     else:
         return get_packed_modules_mapping(model)
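The reworded error spells out the contract: an MoE model has to expose get_expert_mapping() before its experts can be targeted by LoRA (the gpt_oss.py hunk at the end of this diff shows a real implementation). A bare sketch of the expected method shape, with the body left as a placeholder:

import torch.nn as nn

class MyMoEModel(nn.Module):
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Each entry is (param_name, weight_name, expert_id, shard_id);
        # models typically build the list via FusedMoE.make_expert_params_mapping.
        raise NotImplementedError  # see gpt_oss.py below for a real example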

vllm/lora/worker_manager.py
Lines changed: 0 additions & 1 deletion

@@ -94,7 +94,6 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                 expected_lora_modules.extend(packed_modules_mapping[module])
             else:
                 expected_lora_modules.append(module)
-                # TODO(gnovack) - Attempting to load full-layer moe adapter
                 if module == "experts":
                     expected_lora_modules.append(module)
         expected_lora_modules = list(set(expected_lora_modules))

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Lines changed: 17 additions & 20 deletions

@@ -192,25 +192,23 @@ def fused_marlin_moe(
         is_zp_float=False,
     )
 
-    if activation_func is not None:
-        activation_func(
-            activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
-        )
-    else:
-        if activation == "silu":
-            torch.ops._C.silu_and_mul(
-                intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
-            )
-        elif activation == "swigluoai":
-            # alpha = 1.702, limit = 7.0
-            torch.ops._C.swigluoai_and_mul(
-                intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
-            )
-        else:
-            raise ValueError(
-                f"Unsupported activation: {activation}. "
-                "Only silu and swigluoai activations are supported."
-            )
+    if activation_func is None:
+        def activation_func(activation: str, output: torch.Tensor, input: torch.Tensor) -> None:
+            if activation == "silu":
+                torch.ops._C.silu_and_mul(
+                    output, input
+                )
+            elif activation == "swigluoai":
+                # alpha = 1.702, limit = 7.0
+                torch.ops._C.swigluoai_and_mul(
+                    output, input
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported activation: {activation}. "
+                    "Only silu and swigluoai activations are supported."
+                )
+    activation_func(activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
 
     if expert_map is not None:
         intermediate_cache3.zero_()

@@ -425,7 +423,6 @@ def apply(
     def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
         ops.moe_sum(input, output)
 
-
 def modular_marlin_fused_moe(
     quant_config: FusedMoEQuantConfig, shared_experts: Optional[torch.nn.Module] = None
 ) -> mk.FusedMoEModularKernel:
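The refactor turns the activation into an injection point: when no activation_func is supplied, fused_marlin_moe defines the silu/swigluoai dispatch above as the default, and the final call always goes through the same (activation, output, input) signature. A hedged sketch of a caller-supplied replacement, e.g. for a MoE + LoRA path; this is a pure-PyTorch stand-in, not the fused kernels vLLM actually invokes:

import torch
import torch.nn.functional as F

def my_activation_func(activation: str, output: torch.Tensor, input: torch.Tensor) -> None:
    # `input` is intermediate_cache1 viewed as (-1, 2 * N): the first N columns
    # are activated and the last N multiply the result, mirroring silu_and_mul.
    if activation != "silu":
        raise ValueError(f"Unsupported activation: {activation}")
    gate, up = input.chunk(2, dim=-1)
    output.copy_(F.silu(gate) * up)

# Used as: fused_marlin_moe(..., activation_func=my_activation_func)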

vllm/model_executor/layers/fused_moe/fused_moe.py
Lines changed: 1 addition & 7 deletions

@@ -2047,12 +2047,6 @@ def apply(
         )
 
         invoke_fused_moe_kernel(
-            # The code `hidden_states` is not performing any specific action in
-            # the provided snippet. It seems to be a variable name or
-            # placeholder without any associated code or context.
-            # The code `hidden_states` is not performing any specific action in
-            # the provided snippet. It seems to be a variable or placeholder
-            # that has been declared but not used or assigned any value.
             hidden_states,
             w1,
             intermediate_cache1,

@@ -2114,7 +2108,7 @@ def apply(
             B_bias=self.w2_bias,
         )
 
-        # ops.moe_sum(intermediate_cache3, output)
+        # separate function is required for MoE + LoRA
         self.moe_sum(intermediate_cache3, output)
 
     def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
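The new comment documents the indirection: the final reduction is routed through self.moe_sum instead of calling ops.moe_sum directly so that a MoE + LoRA implementation can override that one step. A hedged sketch of what an override could look like; the class names and the LoRA hook are illustrative, and torch.sum stands in for ops.moe_sum:

import torch

class PlainExperts:
    def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
        # Reduce per-top-k expert outputs, roughly (num_tokens, top_k, hidden),
        # into (num_tokens, hidden); ops.moe_sum does this with a fused kernel.
        torch.sum(input, dim=1, out=output)

class LoRAExperts(PlainExperts):
    def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
        super().moe_sum(input, output)
        # Hypothetical hook: fold the LoRA correction into the reduced output here.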

vllm/model_executor/models/gpt_oss.py
Lines changed: 2 additions & 2 deletions

@@ -697,13 +697,13 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return logits
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        # Params for weights, fp8 weight scales, fp8 activation scales
+        # Params for weights, weight scales, activation scales
         # (param_name, weight_name, expert_id, shard_id)
         return FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_local_experts,  # FIXME: self.config.n_routed_experts if in config
+            num_experts=self.config.num_local_experts,
             num_redundant_experts=0,
         )
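For reference, each tuple produced by get_expert_mapping follows the (param_name, weight_name, expert_id, shard_id) layout noted in the comment above. A hedged sketch of how a weight loader might consume such a mapping; the function and matching rule are illustrative, not vLLM's actual loader:

def find_expert_target(ckpt_name: str,
                       expert_mapping: list[tuple[str, str, int, str]]):
    # Return the model parameter, expert index, and shard that a checkpoint
    # tensor maps onto, or None if it is not an expert weight.
    for param_name, weight_name, expert_id, shard_id in expert_mapping:
        if weight_name in ckpt_name:
            return param_name, expert_id, shard_id
    return None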
