
Commit 160b60a

avshalomman authored and jimpang committed
[Kernel] reloading fused_moe config on the last chunk (vllm-project#6210)
1 parent 82f9898 commit 160b60a

File tree: 1 file changed, +36, -15 lines

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 36 additions & 15 deletions
@@ -332,6 +332,31 @@ def get_default_config(
     return config
 
 
+def try_get_optimal_moe_config(
+    w1_shape: Tuple[int, ...],
+    w2_shape: Tuple[int, ...],
+    top_k: int,
+    dtype: Optional[str],
+    M: int,
+    override_config: Optional[Dict[str, Any]] = None,
+):
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        E, _, N = w2_shape
+        configs = get_moe_configs(E, N, dtype)
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
+    return config
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
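
The nearest-batch-size lookup inside try_get_optimal_moe_config can be illustrated in isolation. A minimal sketch with made-up kernel parameters (the real map is produced by get_moe_configs from tuned JSON files keyed by batch size):

# Toy stand-in for the tuned-config map; keys are the benchmarked batch
# sizes M, values are invented Triton launch parameters.
configs = {
    1: {"BLOCK_SIZE_M": 16, "num_warps": 4},
    64: {"BLOCK_SIZE_M": 32, "num_warps": 4},
    1024: {"BLOCK_SIZE_M": 64, "num_warps": 8},
}

M = 200  # number of tokens actually present in this batch
# Same expression as in the diff: pick the config benchmarked at the
# batch size closest to M.
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
print(config)  # {'BLOCK_SIZE_M': 32, 'num_warps': 4} -- 64 is closest to 200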
@@ -428,22 +453,16 @@ def fused_experts(hidden_states: torch.Tensor,
     CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
     M = min(num_tokens, CHUNK_SIZE)
 
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
-                                  "float8" if use_fp8 else None)
+    get_config_func = functools.partial(
+        try_get_optimal_moe_config,
+        w1.shape,
+        w2.shape,
+        topk_ids.shape[1],
+        "float8" if use_fp8 else None,
+        override_config=override_config,
+    )
 
-    if configs:
-        # If an optimal configuration map has been found, look up the
-        # optimal config
-        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-    else:
-        # Else use the default config
-        config = get_default_config(M, E, N, w1.shape[2],
-                                    topk_ids.shape[1],
-                                    "float8" if use_fp8 else None)
+    config = get_config_func(M)
 
     intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                       device=hidden_states.device,
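
The functools.partial call above binds every argument except M, so the config can later be re-resolved from just a token count. A self-contained sketch of the same pattern, using a hypothetical stub in place of try_get_optimal_moe_config and invented shapes:

import functools

def get_optimal_moe_config_stub(w1_shape, w2_shape, top_k, dtype, M,
                                override_config=None):
    # Stand-in: the real function consults tuned config files.
    return f"config chosen for M={M}"

get_config_func = functools.partial(
    get_optimal_moe_config_stub,
    (8, 14336, 4096),  # w1.shape (invented example values)
    (8, 4096, 14336),  # w2.shape (invented example values)
    2,                 # topk_ids.shape[1]
    None,              # dtype
    override_config=None,
)

print(get_config_func(2048))  # config chosen for M=2048 (full chunk)
print(get_config_func(37))    # config chosen for M=37 (short final chunk)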
@@ -478,6 +497,8 @@ def fused_experts(hidden_states: torch.Tensor,
             intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
             intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
             intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+            # reload config to get better performance on the last chunk
+            config = get_config_func(tokens_in_chunk)
 
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
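
To see why the reload matters, here is a simplified sketch of the chunked loop in fused_experts, with the kernel launches and intermediate caches omitted and get_config_func replaced by a stand-in:

import math

num_tokens, CHUNK_SIZE = 5000, 2048

def get_config_func(m):  # stand-in for the functools.partial above
    return f"config tuned near M={m}"

config = get_config_func(min(num_tokens, CHUNK_SIZE))
for chunk in range(math.ceil(num_tokens / CHUNK_SIZE)):
    begin_chunk_idx = chunk * CHUNK_SIZE
    end_chunk_idx = min(begin_chunk_idx + CHUNK_SIZE, num_tokens)
    tokens_in_chunk = end_chunk_idx - begin_chunk_idx
    if tokens_in_chunk < CHUNK_SIZE:
        # Before this patch the last chunk reused the config chosen for
        # CHUNK_SIZE tokens; reloading picks one tuned for the smaller M.
        config = get_config_func(tokens_in_chunk)
    print(f"chunk {chunk}: {tokens_in_chunk} tokens, {config}")

Running this prints two full 2048-token chunks using the config chosen up front, then a final 904-token chunk that switches to a config tuned near M=904.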
