Skip to content

[Bugfix] Add padding for block-scale fused-moe weights for AITER lib #19234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,113 @@ def __init__(self, quant_config: Fp8Config):
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm)

def _maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
self,
w2_weight,
w2_weight_scale_inv,
w13_weight,
w13_weight_scale_inv,
block_k=128,
block_n=128):
"""
Pads the MoE weights and scales to align with block quantization
requirements.

aiter.fmoe_fp8_blockscale_g1u1 only support out dtype = bf16,
inter_dim % 256 = 0 and fc_scale_blkn and fc_scale_blkk is 128
"""

if (not self.rocm_aiter_moe_enabled):
return (w2_weight, w2_weight_scale_inv, w13_weight,
w13_weight_scale_inv)

if (self.rocm_aiter_moe_enabled
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: you don't need to check self.rocm_aiter_mode_enabled here

and (w2_weight.shape[-1] % 256 == 0
and w13_weight.shape[-2] % 256 == 0)):
return (w2_weight, w2_weight_scale_inv, w13_weight,
w13_weight_scale_inv)

logger.info_once(
"ROCm AITER Padding MoE weights and scales for block quantization."
)
# for now this is enabled for DeepSeekV3 and Qwen3
assert block_k == 128, "block_k must be 128"
assert block_n == 128, "block_n must be 128"
assert block_k == block_n, (
"block_k and block_n must be the same value: 128")

num_experts, hidden_size, inter_dim = w2_weight.shape
padded_inter_dim = ((inter_dim + 255) // 256) * 256
# inter_dim_block_scale = layer.w2_weight_scale_inv.shape[2]
# = ((intermediate_size_per_partition + block_n - 1) // block_n)
inter_dim_block_scale = (inter_dim + block_n - 1) // block_n
padded_inter_dim_block_scale = ((padded_inter_dim + block_n - 1) //
block_n)

# k_block_scale is also known as hidden_size_block
# Pad w2_weight to
# [num_experts, hidden_size, inter_dim]
# Padding Logic:
# [expert(local_expert:EP), hidden_size, inter_dim]
# after padding inter_dim with 0.0 to multiple of 256
# [expert(local_expert:EP), hidden_size, padded_inter_dim]
if padded_inter_dim > inter_dim:
pad_size = padded_inter_dim - inter_dim
w2_weight = F.pad(w2_weight, (0, pad_size), value=0.0)

# Pad w2_weight_scale_inv to
# [num_experts, k_block_scale, inter_dim_block_scale]
# Padding Logic:
# [expert(local_expert:EP), k_block_scale, inter_dim_block_scale]
# after padding inter_dim with 1.0
# [expert(local_expert:EP), k_block_scale, padded_inter_dim_block_scale] # noqa: E501
if padded_inter_dim_block_scale > inter_dim_block_scale:
pad_size = padded_inter_dim_block_scale - inter_dim_block_scale
w2_weight_scale_inv = F.pad(w2_weight_scale_inv, (0, pad_size),
value=1.0)

# Pad w13_weight to
# [num_experts, 2 * inter_dim, hidden_size]
# Padding Logic:
# [expert(local_expert:EP), inter_dim*2, dim]
# after reshape
# [expert(local_expert:EP), 2, inter_dim, dim]
# after right padding
# [expert(local_expert:EP), 2, padded_inter_dim, dim]
# after reshape
# [expert(local_expert:EP), 2 * padded_inter_dim, dim]
w13_weight = w13_weight.view(num_experts, 2, inter_dim, hidden_size)
if padded_inter_dim > inter_dim:
pad_size = padded_inter_dim - inter_dim
w13_weight = F.pad(w13_weight, (0, 0, 0, pad_size), value=0.0)
w13_weight = w13_weight.view(num_experts, 2 * padded_inter_dim,
hidden_size)

# Pad w13_weight_scale_inv to
# [num_experts, 2 * inter_dim_block_scale, k_block_scale]
# Padding Logic:
# k_block_scale = ((hidden_size + block_k - 1) // block_k)
# [expert(local_expert:EP), inter_dim_block_scale*2, k_block_scale] # noqa: E501
# after reshape
# [expert(local_expert:EP), 2, inter_dim_block_scale, k_block_scale] # noqa: E501
# after right padding with 1.0
# [expert(local_expert:EP), 2, padded_inter_dim_block_scale, k_block_scale] # noqa: E501
# after reshape
# [expert(local_expert:EP), 2 * padded_inter_dim_block_scale, k_block_scale] # noqa: E501
k_block_scale = w13_weight_scale_inv.shape[
2] # k_block_scale = (hidden_size + block_k - 1) // block_k
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Can you move the comment to the line above so that the code is all on the same line?

w13_weight_scale_inv = w13_weight_scale_inv.view(
num_experts, 2, inter_dim_block_scale, k_block_scale)
if padded_inter_dim_block_scale > inter_dim_block_scale:
pad_size = padded_inter_dim_block_scale - inter_dim_block_scale
w13_weight_scale_inv = F.pad(w13_weight_scale_inv,
(0, 0, 0, pad_size),
value=1.0)
w13_weight_scale_inv = w13_weight_scale_inv.view(
num_experts, 2 * padded_inter_dim_block_scale, k_block_scale)

return w2_weight, w2_weight_scale_inv, w13_weight, w13_weight_scale_inv

def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
Expand Down Expand Up @@ -623,6 +730,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
w2_weight = layer.w2_weight
w2_weight_scale_inv = layer.w2_weight_scale_inv

if self.quant_config.weight_block_size is not None:
(w2_weight, w2_weight_scale_inv, w13_weight,
w13_weight_scale_inv
) = self._maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
w2_weight,
w2_weight_scale_inv,
w13_weight,
w13_weight_scale_inv,
block_n=self.quant_config.weight_block_size[0],
block_k=self.quant_config.weight_block_size[1])

# torch.compile() cannot use Parameter subclasses.
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
Expand Down