@@ -79,6 +79,7 @@ def wrapper(*args, **kwargs):
                 moe_state_dict["apply_router_weight_on_input"] = kwargs[
                     "apply_router_weight_on_input"
                 ]
+                moe_state_dict["max_loras"] = layer.w1_lora_a_stacked.shape[0]
                 result = func(*args, **kwargs)
                 return result

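The new `max_loras` entry is read off the leading dimension of the stacked LoRA weights. A minimal sketch of the shape assumption (dimensions are illustrative, not the exact vLLM layout):

```python
import torch

# Hypothetical stacked LoRA-A tensor for the w1 projection: one slot
# per loadable adapter, so dim 0 is the number of LoRA slots.
max_loras, num_experts, rank, hidden = 4, 8, 16, 1024
w1_lora_a_stacked = torch.zeros(max_loras, num_experts, rank, hidden)

# What the diff stashes into moe_state_dict["max_loras"]:
assert w1_lora_a_stacked.shape[0] == max_loras
```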
@@ -93,6 +94,7 @@ def wrapper(*args, **kwargs):
                 curr_topk_ids = moe_state_dict["topk_ids"]
                 global_num_experts = moe_state_dict["global_num_experts"]
                 expert_map = moe_state_dict["expert_map"]
+                max_loras = moe_state_dict["max_loras"]

                 (token_lora_mapping, _, _, _, _, _) = (
                     self.punica_wrapper.token_mapping_meta.meta_args(
@@ -128,7 +130,7 @@ def wrapper(*args, **kwargs):
                     token_lora_mapping,
                     config["BLOCK_SIZE_M"],
                     global_num_experts,
-                    curr_topk_ids.shape[-1],
+                    max_loras,
                     expert_map,
                 )
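Passing `max_loras` here matters because the block-size alignment pads each LoRA slot's token list independently, so the metadata buffers scale with the number of LoRA slots rather than with the router's top-k. A rough sizing sketch under that assumption (hypothetical counts, not the kernel's API):

```python
import math

BLOCK_SIZE_M = 16
tokens_per_lora = [5, 0, 12, 3]  # illustrative token counts, len == max_loras

# Each slot is padded to a multiple of BLOCK_SIZE_M before the kernel runs.
padded = [math.ceil(n / BLOCK_SIZE_M) * BLOCK_SIZE_M for n in tokens_per_lora]
print(padded)       # [16, 0, 16, 16]
print(sum(padded))  # total rows in the sorted-token-id buffer
```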
@@ -141,10 +143,8 @@ def wrapper(*args, **kwargs):
                 w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
                 w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
                 max_lora_rank = self.w1_lora_a_stacked.shape[-2]
-                expert_ids_lora = expert_ids_lora.view(curr_topk_ids.shape[-1], -1)
-                sorted_token_ids_lora = sorted_token_ids_lora.view(
-                    curr_topk_ids.shape[-1], -1
-                )
+                expert_ids_lora = expert_ids_lora.view(max_loras, -1)
+                sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)

                 self.punica_wrapper.add_lora_fused_moe(
                     input.view(-1, top_k, input.shape[-1]),
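This reshape is the crux of the fix: the old code used `curr_topk_ids.shape[-1]` (the router's top-k) as the leading dimension, which only happened to work when top-k equaled the number of LoRA slots. A small repro of the distinction, assuming a torch-style flat metadata buffer:

```python
import torch

max_loras, top_k = 4, 2        # these only coincide by accident
padded_per_lora = 24           # illustrative post-alignment length per slot
buf = torch.arange(max_loras * padded_per_lora)

# Correct: one row of sorted token ids per LoRA slot.
per_lora = buf.view(max_loras, -1)  # shape (4, 24)

# Old behavior: reshaping by top_k silently mixes two slots per row
# whenever top_k != max_loras.
mixed = buf.view(top_k, -1)         # shape (2, 48)
assert per_lora.shape != mixed.shape
```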
@@ -172,6 +172,7 @@ def wrapper(*args, **kwargs):
                 hidden_states = moe_state_dict["hidden_states"]
                 topk_weights = moe_state_dict["topk_weights"]
                 curr_topk_ids = moe_state_dict["topk_ids"]
+                max_loras = moe_state_dict["max_loras"]

                 config_dtype = _get_config_dtype_str(
                     dtype=hidden_states.dtype,
@@ -200,10 +201,8 @@ def wrapper(*args, **kwargs):
                     "num_tokens_post_padded_lora"
                 ]

-                expert_ids_lora = expert_ids_lora.view(curr_topk_ids.shape[-1], -1)
-                sorted_token_ids_lora = sorted_token_ids_lora.view(
-                    curr_topk_ids.shape[-1], -1
-                )
+                expert_ids_lora = expert_ids_lora.view(max_loras, -1)
+                sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
                 intermediate_cache2 = moe_state_dict["intermediate_cache2"]
                 intermediate_cache3 = args[0]
                 max_lora_rank = self.w1_lora_a_stacked.shape[-2]