
Commit 0687656

Fix the incorrect retrieval of max_loras.
Signed-off-by: Chen Wu <cntryroa@gmail.com>
1 parent 6e245f8 commit 0687656

File tree

2 files changed: +20 -9 lines changed


vllm/lora/layers/fused_moe.py

Lines changed: 8 additions & 9 deletions
@@ -79,6 +79,7 @@ def wrapper(*args, **kwargs):
     moe_state_dict["apply_router_weight_on_input"] = kwargs[
         "apply_router_weight_on_input"
     ]
+    moe_state_dict["max_loras"] = layer.w1_lora_a_stacked.shape[0]
     result = func(*args, **kwargs)
     return result
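The new line above records the adapter-slot count up front: the stacked LoRA A weight tensor is allocated with one leading slot per adapter, so max_loras is simply its first dimension. A minimal sketch of that relationship, using illustrative shapes rather than vLLM's real allocation:

import torch

# Illustrative dimensions (assumptions, not vLLM's actual defaults).
max_loras, num_slices, lora_rank, hidden_size = 4, 1, 16, 2048

# Stacked LoRA A weights: one slot per adapter along dim 0.
w1_lora_a_stacked = torch.zeros(max_loras, num_slices, lora_rank, hidden_size)

# The fix stores the slot count read off that leading dimension so later
# steps no longer have to infer it from an unrelated quantity.
moe_state_dict = {"max_loras": w1_lora_a_stacked.shape[0]}
assert moe_state_dict["max_loras"] == max_loras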

@@ -93,6 +94,7 @@ def wrapper(*args, **kwargs):
     curr_topk_ids = moe_state_dict["topk_ids"]
     global_num_experts = moe_state_dict["global_num_experts"]
     expert_map = moe_state_dict["expert_map"]
+    max_loras = moe_state_dict["max_loras"]

     (token_lora_mapping, _, _, _, _, _) = (
         self.punica_wrapper.token_mapping_meta.meta_args(

@@ -128,7 +130,7 @@ def wrapper(*args, **kwargs):
         token_lora_mapping,
         config["BLOCK_SIZE_M"],
         global_num_experts,
-        curr_topk_ids.shape[-1],
+        max_loras,
         expert_map,
     )

@@ -141,10 +143,8 @@ def wrapper(*args, **kwargs):
     w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
     w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
     max_lora_rank = self.w1_lora_a_stacked.shape[-2]
-    expert_ids_lora = expert_ids_lora.view(curr_topk_ids.shape[-1], -1)
-    sorted_token_ids_lora = sorted_token_ids_lora.view(
-        curr_topk_ids.shape[-1], -1
-    )
+    expert_ids_lora = expert_ids_lora.view(max_loras, -1)
+    sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)

     self.punica_wrapper.add_lora_fused_moe(
         input.view(-1, top_k, input.shape[-1]),
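This hunk is the core of the fix. The LoRA-aligned expert_ids_lora and sorted_token_ids_lora buffers are laid out as one contiguous segment per adapter slot, so they must be viewed as (max_loras, -1); the old code reshaped by curr_topk_ids.shape[-1] (top_k), which only yields the right grouping when top_k happens to equal max_loras. A minimal sketch with assumed toy sizes (the real buffers come from the MoE block-alignment step):

import torch

max_loras = 4   # number of LoRA adapter slots
top_k = 2       # experts selected per token; unrelated to the slot count
padded_len = 8  # padded entries per adapter slot (illustrative)

# Toy stand-in for sorted_token_ids_lora: padded_len entries per slot,
# concatenated slot by slot into one flat buffer.
sorted_token_ids_lora = torch.arange(max_loras * padded_len)

# Correct: one row per adapter slot.
per_slot = sorted_token_ids_lora.view(max_loras, -1)
print(per_slot.shape)  # torch.Size([4, 8])

# Old behaviour: viewing by top_k packs entries from two different slots
# into each row whenever top_k != max_loras.
wrong = sorted_token_ids_lora.view(top_k, -1)
print(wrong.shape)     # torch.Size([2, 16])

The same correction is applied again in the second kernel path below (the @@ -200,10 +201,8 @@ hunk).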
@@ -172,6 +172,7 @@ def wrapper(*args, **kwargs):
     hidden_states = moe_state_dict["hidden_states"]
     topk_weights = moe_state_dict["topk_weights"]
     curr_topk_ids = moe_state_dict["topk_ids"]
+    max_loras = moe_state_dict["max_loras"]

     config_dtype = _get_config_dtype_str(
         dtype=hidden_states.dtype,

@@ -200,10 +201,8 @@ def wrapper(*args, **kwargs):
         "num_tokens_post_padded_lora"
     ]

-    expert_ids_lora = expert_ids_lora.view(curr_topk_ids.shape[-1], -1)
-    sorted_token_ids_lora = sorted_token_ids_lora.view(
-        curr_topk_ids.shape[-1], -1
-    )
+    expert_ids_lora = expert_ids_lora.view(max_loras, -1)
+    sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
     intermediate_cache2 = moe_state_dict["intermediate_cache2"]
     intermediate_cache3 = args[0]
     max_lora_rank = self.w1_lora_a_stacked.shape[-2]

vllm/lora/layers/replicated_linear.py

Lines changed: 12 additions & 0 deletions
@@ -57,3 +57,15 @@ def can_replace_layer(
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return type(source_layer) is ReplicatedLinear
+
+    def slice_lora_a(
+        self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
+    ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
+        """Slice lora a if splitting for tensor parallelism."""
+        return lora_a
+
+    def slice_lora_b(
+        self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
+    ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
+        """Slice lora b if splitting with tensor parallelism."""
+        return lora_b
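For the replicated_linear.py addition: a ReplicatedLinear layer keeps the full weight on every tensor-parallel rank, so its LoRA A and B matrices need no sharding and the new slice_lora_a / slice_lora_b overrides return their inputs unchanged. A small self-contained sketch of that identity behaviour (toy class and shapes, not vLLM's actual classes):

import torch

class ReplicatedLinearWithLoRASketch:
    """Toy stand-in: a replicated layer is not sharded across TP ranks."""

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        # Nothing to slice: every rank already holds the full LoRA A matrix.
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Nothing to slice: every rank already holds the full LoRA B matrix.
        return lora_b

layer = ReplicatedLinearWithLoRASketch()
lora_a = torch.randn(16, 4096)  # (rank, input_size), illustrative
lora_b = torch.randn(4096, 16)  # (output_size, rank), illustrative
assert layer.slice_lora_a(lora_a) is lora_a
assert layer.slice_lora_b(lora_b) is lora_b

Column- and row-parallel LoRA layers, by contrast, override these hooks to keep only the shard owned by the local rank.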
