diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bd13d8fecbb96..a0cb4337f9dee 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -410,6 +410,7 @@ def fused_topk( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids @@ -443,7 +444,8 @@ def grouped_topk(hidden_states: torch.Tensor, if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + + return topk_weights, topk_ids.to(torch.int32) def get_config_dtype_str(dtype: torch.dtype,