@@ -332,6 +332,31 @@ def get_default_config(
332332 return config
333333
334334
def try_get_optimal_moe_config(
    w1_shape: Tuple[int, ...],
    w2_shape: Tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    override_config: Optional[Dict[str, Any]] = None,
):
    """Pick a kernel launch config for a fused-MoE call.

    Resolution order: an explicit ``override_config`` wins outright;
    otherwise consult the tuned-config file via ``get_moe_configs`` and,
    if entries exist, take the one tuned for the batch size closest to
    ``M``; as a last resort fall back to ``get_default_config``.
    """
    if override_config:
        # Caller supplied a config explicitly — use it as-is.
        return override_config

    # NOTE(review): w2_shape unpacks as (E, _, N) here — presumably
    # (num_experts, hidden, intermediate-shard); confirm against callers.
    num_experts, _, n = w2_shape

    # First try to load an optimal config from the tuned-config file.
    tuned_configs = get_moe_configs(num_experts, n, dtype)
    if tuned_configs:
        # Tuned entries are keyed by batch size; pick the nearest key to M.
        nearest_key = min(tuned_configs.keys(), key=lambda bucket: abs(bucket - M))
        return tuned_configs[nearest_key]

    # No tuned entry available — fall back to the heuristic default.
    return get_default_config(M, num_experts, n, w1_shape[2], top_k, dtype)
359+
335360def fused_topk (
336361 hidden_states : torch .Tensor ,
337362 gating_output : torch .Tensor ,
@@ -428,22 +453,16 @@ def fused_experts(hidden_states: torch.Tensor,
428453 CHUNK_SIZE = envs .VLLM_FUSED_MOE_CHUNK_SIZE
429454 M = min (num_tokens , CHUNK_SIZE )
430455
431- if override_config :
432- config = override_config
433- else :
434- # First try to load optimal config from the file
435- configs = get_moe_configs (E , w2 .shape [2 ],
436- "float8" if use_fp8 else None )
456+ get_config_func = functools .partial (
457+ try_get_optimal_moe_config ,
458+ w1 .shape ,
459+ w2 .shape ,
460+ topk_ids .shape [1 ],
461+ "float8" if use_fp8 else None ,
462+ override_config = override_config ,
463+ )
437464
438- if configs :
439- # If an optimal configuration map has been found, look up the
440- # optimal config
441- config = configs [min (configs .keys (), key = lambda x : abs (x - M ))]
442- else :
443- # Else use the default config
444- config = get_default_config (M , E , N , w1 .shape [2 ],
445- topk_ids .shape [1 ],
446- "float8" if use_fp8 else None )
465+ config = get_config_func (M )
447466
448467 intermediate_cache1 = torch .empty ((M , topk_ids .shape [1 ], N ),
449468 device = hidden_states .device ,
@@ -478,6 +497,8 @@ def fused_experts(hidden_states: torch.Tensor,
478497 intermediate_cache1 = intermediate_cache1 [:tokens_in_chunk ]
479498 intermediate_cache2 = intermediate_cache2 [:tokens_in_chunk ]
480499 intermediate_cache3 = intermediate_cache3 [:tokens_in_chunk ]
500+ # reload config to get better performance on the last chunk
501+ config = get_config_func (tokens_in_chunk )
481502
482503 curr_topk_ids = topk_ids [begin_chunk_idx :end_chunk_idx ]
483504 curr_topk_weights = topk_weights [begin_chunk_idx :end_chunk_idx ]
0 commit comments