Commit

[Misc][Breaking] Change FP8 checkpoint format from act_scale -> input_scale (vllm-project#5353)

mgoin authored and jimpang committed Jul 24, 2024
1 parent 33c0a24 commit 77891b8
Showing 2 changed files with 23 additions and 23 deletions.
vllm/model_executor/layers/quantization/fp8.py (15 additions, 15 deletions)
@@ -171,10 +171,10 @@ def create_weights(
             output_partition_sizes=output_partition_sizes,
             **extra_weight_attrs)
 
-        # ACTIVATION SCALE
+        # INPUT ACTIVATION SCALE
         if self.quant_config.activation_scheme == "static":
             self._create_scale_param(
-                scale_name="act_scale",
+                scale_name="input_scale",
                 layer=layer,
                 output_partition_sizes=output_partition_sizes,
                 **extra_weight_attrs)
@@ -207,7 +207,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             layer.weight = Parameter(qweight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.logical_widths = None
-            layer.act_scale = None
+            layer.input_scale = None
             return
 
         # If checkpoint is fp8, requantize the separately quantized logical
@@ -232,18 +232,18 @@ def process_weights_after_loading(self, layer: Module) -> None:
         weight = layer.weight
         layer.weight = Parameter(weight.t(), requires_grad=False)
 
-        # ACT_SCALE
+        # INPUT ACTIVATION SCALE
         # Dynamic: set to None (required input to ops.scaled_fp8_quant).
-        # Static: set to max of the act_scales (since they are equal).
+        # Static: set to max of the input_scales (since they are equal).
         if self.quant_config.activation_scheme == "dynamic":
-            layer.act_scale = None
+            layer.input_scale = None
         elif self.quant_config.activation_scheme == "static":
-            if not all_close_1d(layer.act_scale):
+            if not all_close_1d(layer.input_scale):
                 raise ValueError(
-                    "All the act_scales for the logical weights of a layer "
-                    f"must be equal. But got {layer.act_scale}")
-            layer.act_scale = Parameter(layer.act_scale.max(),
-                                        requires_grad=False)
+                    "All the input_scales for the logical weights of a "
+                    f"layer must be equal. But got {layer.input_scale}")
+            layer.input_scale = Parameter(layer.input_scale.max(),
+                                          requires_grad=False)
         else:
             raise ValueError(
                 f"Unknown scheme {self.quant_config.activation_scheme}")
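
For context on the static branch above: a layer can pack several logical weights (for example a fused QKV projection), and the checkpoint carries one input scale per logical weight, so they are only merged into a single scalar after checking that they agree. Below is a minimal standalone sketch of that check-and-merge step; the all_close_1d helper and the example tensor are illustrative assumptions, not the vLLM implementation itself.

import torch
from torch.nn import Parameter

def all_close_1d(x: torch.Tensor) -> bool:
    # True when every element of a 1-D tensor equals the first element.
    assert x.dim() == 1
    return torch.allclose(x, x[0].expand_as(x))

# Per-logical-weight input scales as loaded from a static FP8 checkpoint
# (e.g. three shards of a fused QKV projection).
input_scale = torch.tensor([0.021, 0.021, 0.021])

if not all_close_1d(input_scale):
    raise ValueError(f"All input_scales must be equal, got {input_scale}")

# Collapse to a single scalar that the quantization op can consume.
input_scale = Parameter(input_scale.max(), requires_grad=False)
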
@@ -254,11 +254,11 @@ def apply(self,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
         # ops.scaled_fp8_quant supports both dynamic and static quant.
-        # If dynamic, layer.act_scale is None and x_scale computed from x.
-        # If static, layer.act_scale is scalar and x_scale set to act_scale.
+        # If dynamic, layer.input_scale is None and x_scale computed from x.
+        # If static, layer.input_scale is scalar and x_scale is input_scale.
 
         if bias is None and self.cutlass_fp8_supported:
-            qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
+            qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
 
             # Fused GEMM_DQ
             output = ops.cutlass_scaled_mm_dq(
@@ -271,7 +271,7 @@
 
         else:
             qinput, x_scale = ops.scaled_fp8_quant(x,
-                                                   layer.act_scale,
+                                                   layer.input_scale,
                                                    batch_dim_padding=17)
 
             # Fused GEMM_DQ -- note we padded the input above because
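
As the comments in apply() describe, ops.scaled_fp8_quant covers both schemes: with no scale it derives one from the current activations (dynamic), and with layer.input_scale it reuses the calibrated value (static). The plain-PyTorch sketch below only illustrates that arithmetic; it is not the fused CUDA path vLLM calls, and the E4M3 max value and simple clamping are simplifying assumptions.

from typing import Optional, Tuple

import torch

FP8_E4M3_MAX = 448.0  # max representable magnitude of torch.float8_e4m3fn

def quantize_activation(
        x: torch.Tensor,
        input_scale: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    if input_scale is None:
        # Dynamic: derive the scale from the current batch of activations.
        x_scale = x.abs().max().float() / FP8_E4M3_MAX
    else:
        # Static: reuse the calibrated scale stored in the checkpoint.
        x_scale = input_scale.float()
    q = (x.float() / x_scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.to(torch.float8_e4m3fn), x_scale

x = torch.randn(4, 16, dtype=torch.float16)
q_dyn, s_dyn = quantize_activation(x, None)                # dynamic scheme
q_sta, s_sta = quantize_activation(x, torch.tensor(0.05))  # static scheme
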
vllm/model_executor/models/mixtral.py (8 additions, 8 deletions)
@@ -147,7 +147,7 @@ def __init__(
             "weight_loader": self.weight_loader,
         })
 
-        # ACT_SCALE (for fp8)
+        # INPUT_SCALE (for fp8)
         if quant_config.activation_scheme == "static":
             if not quant_config.is_checkpoint_fp8_serialized:
                 raise ValueError(
@@ -182,11 +182,11 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
             param_data[expert_id, :, :] = loaded_weight[:, shard]
 
         # Loading scales
-        if "act_scale" in weight_name or "w2.weight_scale" in weight_name:
+        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
             if param_data[expert_id] != 1 and (param_data[expert_id] -
                                                loaded_weight).abs() > 1e-5:
                 raise ValueError(
-                    "act_scales of w1 and w3 of a layer "
+                    "input_scales of w1 and w3 of a layer "
                     f"must be equal. But got {param_data[expert_id]} "
                     f"vs. {loaded_weight}")
             param_data[expert_id] = loaded_weight
@@ -225,9 +225,9 @@ def process_weights_after_loading(self):
             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
 
         else:
-            # If checkpoint is fp8 + static, cleanup act_scales.
-            # Since state_dict has an act_scale per expert but our kernels
-            # are passed one act_scale shared across all experts.
+            # If checkpoint is fp8 + static, cleanup input_scales.
+            # Since state_dict has an input_scale per expert but our kernels
+            # are passed one input_scale shared across all experts.
             if self.quant_config.activation_scheme == "static":
                 if self.a13_scale is None or self.a2_scale is None:
                     raise ValueError(
@@ -237,7 +237,7 @@
                 if (not all_close_1d(self.a13_scale)
                         or not all_close_1d(self.a2_scale)):
                     print_warning_once(
-                        "Found act_scales that are not equal for "
+                        "Found input_scales that are not equal for "
                         "fp8 MoE layer. Using the maximum across experts "
                         "for each layer. ")
 
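The warning above covers the case where the per-expert calibrated scales disagree; they are still reduced to one shared value per layer because the fused MoE kernel consumes a single activation scale for all experts. A small standalone illustration of that max reduction follows (the a13_scale / a2_scale names follow the diff, but the snippet is a sketch rather than the surrounding code of the file).

import torch
from torch import nn

num_experts = 8

# Per-expert activation scales as loaded from a static FP8 MoE checkpoint.
a13_scale = nn.Parameter(torch.rand(num_experts) * 1e-2, requires_grad=False)
a2_scale = nn.Parameter(torch.rand(num_experts) * 1e-2, requires_grad=False)

# Taking the maximum keeps every expert's activations inside FP8 range,
# trading a little precision for experts calibrated with smaller scales.
a13_scale = nn.Parameter(a13_scale.max(), requires_grad=False)
a2_scale = nn.Parameter(a2_scale.max(), requires_grad=False)
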
@@ -576,7 +576,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
             ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
         ]
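
Because this commit changes the serialized FP8 checkpoint format, checkpoints produced before it still store act_scale tensors and need their keys renamed before they match the new loader. A rough migration sketch for a safetensors shard is shown below; the file names are placeholders and the script itself is not part of this commit.

from safetensors.torch import load_file, save_file

def migrate_fp8_checkpoint(path_in: str, path_out: str) -> None:
    # Rename every `...act_scale` tensor to `...input_scale`; leave the rest untouched.
    state_dict = load_file(path_in)
    renamed = {
        key.replace(".act_scale", ".input_scale"): tensor
        for key, tensor in state_dict.items()
    }
    save_file(renamed, path_out)

migrate_fp8_checkpoint("model-00001-of-00002.safetensors",
                       "migrated/model-00001-of-00002.safetensors")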
